This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,138 @@
import logging
import time
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
AppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueErrorEvent,
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
PingStreamResponse,
)
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from models.enums import MessageStatus
from models.model import Message
logger = logging.getLogger(__name__)
class BasedGenerateTaskPipeline:
"""
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
):
self._application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.start_at = time.perf_counter()
self.output_moderation_handler = self._init_output_moderation()
self.stream = stream
def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
logger.debug("error: %s", event.error)
e = event.error
err: Exception
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))
if not message_id or not session:
return err
stmt = select(Message).where(Message.id == message_id)
message = session.scalar(stmt)
if not message:
return err
err_desc = self._error_to_desc(err)
message.status = MessageStatus.ERROR
message.error = err_desc
return err
def _error_to_desc(self, e: Exception) -> str:
"""
Error to desc.
:param e: exception
:return:
"""
if isinstance(e, QuotaExceededError):
return (
"Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials."
)
message = getattr(e, "description", str(e))
if not message:
message = "Internal Server Error, please contact support."
return message
def error_to_stream_response(self, e: Exception):
"""
Error to stream response.
:param e: exception
:return:
"""
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
def ping_stream_response(self) -> PingStreamResponse:
"""
Ping stream response.
:return:
"""
return PingStreamResponse(task_id=self._application_generate_entity.task_id)
def _init_output_moderation(self) -> OutputModeration | None:
"""
Init output moderation.
:return:
"""
app_config = self._application_generate_entity.app_config
sensitive_word_avoidance = app_config.sensitive_word_avoidance
if sensitive_word_avoidance:
return OutputModeration(
tenant_id=app_config.tenant_id,
app_id=app_config.app_id,
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
queue_manager=self.queue_manager,
)
return None
def handle_output_moderation_when_task_finished(self, completion: str) -> str | None:
"""
Handle output moderation when task finished.
:param completion: completion
:return:
"""
# response moderation
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
completion, flagged = self.output_moderation_handler.moderation_completion(
completion=completion, public_event=False
)
self.output_moderation_handler = None
if flagged:
return completion
return None

View File

@@ -0,0 +1,526 @@
import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAgentMessageEvent,
QueueAgentThoughtEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueLLMChunkEvent,
QueueMessageEndEvent,
QueueMessageFileEvent,
QueueMessageReplaceEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
)
from core.app.entities.task_entities import (
AgentMessageStreamResponse,
AgentThoughtStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
CompletionAppBlockingResponse,
CompletionAppStreamResponse,
EasyUITaskState,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
def __init__(
self,
application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
stream: bool,
):
super().__init__(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
)
self._model_config = application_generate_entity.model_conf
self._app_config = application_generate_entity.app_config
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._task_state = EasyUITaskState(
llm_result=LLMResult(
model=self._model_config.model,
prompt_messages=[],
message=AssistantPromptMessage(content=""),
usage=LLMUsage.empty_usage(),
)
)
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity,
task_state=self._task_state,
)
self._conversation_name_generate_thread: Thread | None = None
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
"""
Process blocking response.
:return:
"""
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content),
created_at=self._message_created_at,
**extras,
),
)
else:
response = ChatbotAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=ChatbotAppBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content),
created_at=self._message_created_at,
**extras,
),
)
return response
else:
continue
raise RuntimeError("queue listening stopped unexpectedly.")
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
"""
To stream response.
:return:
"""
for stream_response in generator:
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
yield CompletionAppStreamResponse(
message_id=self._message_id,
created_at=self._message_created_at,
stream_response=stream_response,
)
else:
yield ChatbotAppStreamResponse(
conversation_id=self._conversation_id,
message_id=self._message_id,
created_at=self._message_created_at,
stream_response=stream_response,
)
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if publisher is None:
return None
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
publisher = None
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
if (
text_to_speech_dict
and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled")
):
publisher = AppGeneratorTTSPublisher(
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
)
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(publisher, task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
if publisher is None:
break
audio = publisher.check_and_get_audio()
if audio is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio.status == "finish":
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
if publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: AppGeneratorTTSPublisher | None, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self.queue_manager.listen():
if publisher:
publisher.publish(message)
event = message.event
if isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self.error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
if event.llm_result:
self._task_state.llm_result = event.llm_result
else:
self._handle_stop(event)
# handle output moderation
output_moderation_answer = self.handle_output_moderation_when_task_finished(
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer
)
with Session(db.engine) as session:
# Save message
self._save_message(session=session, trace_manager=trace_manager)
session.commit()
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):
self._message_cycle_manager.handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
annotation = self._message_cycle_manager.handle_annotation_reply(event)
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
agent_thought_response = self._agent_thought_to_stream_response(event)
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_cycle_manager.message_file_to_stream_response(event)
if response:
yield response
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
chunk = event.chunk
delta_text = chunk.delta.message.content
if delta_text is None:
continue
if isinstance(chunk.delta.message.content, list):
delta_text = ""
for content in chunk.delta.message.content:
logger.debug(
"The content type %s in LLM chunk delta message content.: %r", type(content), content
)
if isinstance(content, TextPromptMessageContent):
delta_text += content.data
elif isinstance(content, str):
delta_text += content # failback to str
else:
logger.warning(
"Unsupported content type %s in LLM chunk delta message content.: %r",
type(content),
content,
)
continue
if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
if should_direct_answer:
continue
current_content = cast(str, self._task_state.llm_result.message.content)
current_content += cast(str, delta_text)
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
else:
yield self._agent_message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self.ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
"""
Save message.
:return:
"""
llm_result = self._task_state.llm_result
usage = llm_result.usage
message_stmt = select(Message).where(Message.id == self._message_id)
message = session.scalar(message_stmt)
if not message:
raise ValueError(f"message {self._message_id} not found")
conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id)
conversation = session.scalar(conversation_stmt)
if not conversation:
raise ValueError(f"Conversation {self._conversation_id} not found")
message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages
)
message.message_tokens = usage.prompt_tokens
message.message_unit_price = usage.prompt_unit_price
message.message_price_unit = usage.prompt_price_unit
message.answer = (
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content
else ""
)
message.updated_at = naive_utc_now()
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit
message.provider_response_latency = time.perf_counter() - self.start_at
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
)
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _handle_stop(self, event: QueueStopEvent):
"""
Handle stop.
:return:
"""
model_config = self._model_config
model = model_config.model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
# calculate num tokens
prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)
completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])
credentials = model_config.credentials
# transform usage
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""
Message end to stream response.
:return:
"""
self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=metadata_dict,
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
"""
Agent message to stream response.
:param answer: answer
:param message_id: message id
:return:
"""
return AgentMessageStreamResponse(
task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
)
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> AgentThoughtStreamResponse | None:
"""
Agent thought to stream response.
:param event: agent thought event
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: MessageAgentThought | None = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought:
return AgentThoughtStreamResponse(
task_id=self._application_generate_entity.task_id,
id=agent_thought.id,
position=agent_thought.position,
thought=agent_thought.thought,
observation=agent_thought.observation,
tool=agent_thought.tool,
tool_labels=agent_thought.tool_labels,
tool_input=agent_thought.tool_input,
message_files=agent_thought.files,
)
return None
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self.output_moderation_handler:
if self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
self.queue_manager.publish(
QueueLLMChunkEvent(
chunk=LLMResultChunk(
model=self._task_state.llm_result.model,
prompt_messages=self._task_state.llm_result.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content),
),
)
),
PublishFrom.TASK_PIPELINE,
)
self.queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
self.output_moderation_handler.append_new_token(text)
return False

View File

@@ -0,0 +1,12 @@
class TaskPipelineError(ValueError):
pass
class RecordNotFoundError(TaskPipelineError):
def __init__(self, record_name: str, record_id: str):
super().__init__(f"{record_name} with id {record_id} not found")
class WorkflowRunNotFoundError(RecordNotFoundError):
def __init__(self, workflow_run_id: str):
super().__init__("WorkflowRun", workflow_run_id)

View File

@@ -0,0 +1,232 @@
import logging
from threading import Thread
from typing import Union
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueMessageFileEvent,
QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
AnnotationReply,
AnnotationReplyAccount,
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
StreamEvent,
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
logger = logging.getLogger(__name__)
class MessageCycleManager:
def __init__(
self,
*,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
):
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
"""
Generate conversation name.
:param conversation_id: conversation id
:param query: query
:return: thread
"""
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
return None
is_first_message = self._application_generate_entity.conversation_id is None
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
if auto_generate_conversation_name and is_first_message:
# start generate thread
thread = Thread(
target=self._generate_conversation_name_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"conversation_id": conversation_id,
"query": query,
},
)
thread.start()
return thread
return None
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)
if not conversation:
return
if conversation.mode != AppMode.COMPLETION:
app_model = conversation.app
if not app_model:
return
# generate conversation name
try:
name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, query, conversation_id, conversation.app_id
)
conversation.name = name
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
db.session.commit()
db.session.close()
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> MessageAnnotation | None:
"""
Handle annotation reply.
:param event: event
:return:
"""
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata.annotation_reply = AnnotationReply(
id=annotation.id,
account=AnnotationReplyAccount(
id=annotation.account_id,
name=account.name if account else "Dify user",
),
)
return annotation
return None
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent):
"""
Handle retriever resources.
:param event: event
:return:
"""
if not self._application_generate_entity.app_config.additional_features:
raise ValueError("Additional features not found")
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
merged_resources = [r for r in self._task_state.metadata.retriever_resources or [] if r]
existing_ids = {(r.dataset_id, r.document_id) for r in merged_resources if r.dataset_id and r.document_id}
# Add new unique resources from the event
for resource in event.retriever_resources or []:
if not resource:
continue
is_duplicate = (
resource.dataset_id
and resource.document_id
and (resource.dataset_id, resource.document_id) in existing_ids
)
if not is_duplicate:
merged_resources.append(resource)
for i, resource in enumerate(merged_resources, 1):
resource.position = i
self._task_state.metadata.retriever_resources = merged_resources
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
"""
Message file to stream response.
:param event: event
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
# get extension
if "." in message_file.url:
extension = f".{message_file.url.split('.')[-1]}"
if len(extension) > 10:
extension = ".bin"
else:
extension = ".bin"
# add sign url to local file
if message_file.url.startswith("http"):
url = message_file.url
else:
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
return MessageFileStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_file.id,
type=message_file.type,
belongs_to=message_file.belongs_to or "user",
url=url,
)
return None
def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
) -> MessageStreamResponse:
"""
Message to stream response.
:param answer: answer
:param message_id: message id
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
event=event_type,
)
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.
:param answer: answer
:return:
"""
return MessageReplaceStreamResponse(
task_id=self._application_generate_entity.task_id, answer=answer, reason=reason
)