dify
This commit is contained in:
0
dify/api/core/app/task_pipeline/__init__.py
Normal file
0
dify/api/core/app/task_pipeline/__init__.py
Normal file
138
dify/api/core/app/task_pipeline/based_generate_task_pipeline.py
Normal file
138
dify/api/core/app/task_pipeline/based_generate_task_pipeline.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from models.enums import MessageStatus
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasedGenerateTaskPipeline:
|
||||
"""
|
||||
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self.queue_manager = queue_manager
|
||||
self.start_at = time.perf_counter()
|
||||
self.output_moderation_handler = self._init_output_moderation()
|
||||
self.stream = stream
|
||||
|
||||
def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
|
||||
logger.debug("error: %s", event.error)
|
||||
e = event.error
|
||||
err: Exception
|
||||
|
||||
if isinstance(e, InvokeAuthorizationError):
|
||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||
elif isinstance(e, InvokeError | ValueError):
|
||||
err = e
|
||||
else:
|
||||
description = getattr(e, "description", None)
|
||||
err = Exception(description if description is not None else str(e))
|
||||
|
||||
if not message_id or not session:
|
||||
return err
|
||||
|
||||
stmt = select(Message).where(Message.id == message_id)
|
||||
message = session.scalar(stmt)
|
||||
if not message:
|
||||
return err
|
||||
|
||||
err_desc = self._error_to_desc(err)
|
||||
message.status = MessageStatus.ERROR
|
||||
message.error = err_desc
|
||||
return err
|
||||
|
||||
def _error_to_desc(self, e: Exception) -> str:
|
||||
"""
|
||||
Error to desc.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
if isinstance(e, QuotaExceededError):
|
||||
return (
|
||||
"Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||
)
|
||||
|
||||
message = getattr(e, "description", str(e))
|
||||
if not message:
|
||||
message = "Internal Server Error, please contact support."
|
||||
|
||||
return message
|
||||
|
||||
def error_to_stream_response(self, e: Exception):
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
|
||||
|
||||
def ping_stream_response(self) -> PingStreamResponse:
|
||||
"""
|
||||
Ping stream response.
|
||||
:return:
|
||||
"""
|
||||
return PingStreamResponse(task_id=self._application_generate_entity.task_id)
|
||||
|
||||
def _init_output_moderation(self) -> OutputModeration | None:
|
||||
"""
|
||||
Init output moderation.
|
||||
:return:
|
||||
"""
|
||||
app_config = self._application_generate_entity.app_config
|
||||
sensitive_word_avoidance = app_config.sensitive_word_avoidance
|
||||
|
||||
if sensitive_word_avoidance:
|
||||
return OutputModeration(
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_id=app_config.app_id,
|
||||
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
|
||||
queue_manager=self.queue_manager,
|
||||
)
|
||||
return None
|
||||
|
||||
def handle_output_moderation_when_task_finished(self, completion: str) -> str | None:
|
||||
"""
|
||||
Handle output moderation when task finished.
|
||||
:param completion: completion
|
||||
:return:
|
||||
"""
|
||||
# response moderation
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.stop_thread()
|
||||
|
||||
completion, flagged = self.output_moderation_handler.moderation_completion(
|
||||
completion=completion, public_event=False
|
||||
)
|
||||
|
||||
self.output_moderation_handler = None
|
||||
if flagged:
|
||||
return completion
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,526 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
from typing import Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AgentChatAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentMessageEvent,
|
||||
QueueAgentThoughtEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueLLMChunkEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentMessageStreamResponse,
|
||||
AgentThoughtStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
CompletionAppStreamResponse,
|
||||
EasyUITaskState,
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
"""
|
||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
stream: bool,
|
||||
):
|
||||
super().__init__(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream,
|
||||
)
|
||||
self._model_config = application_generate_entity.model_conf
|
||||
self._app_config = application_generate_entity.app_config
|
||||
|
||||
self._conversation_id = conversation.id
|
||||
self._conversation_mode = conversation.mode
|
||||
|
||||
self._message_id = message.id
|
||||
self._message_created_at = int(message.created_at.timestamp())
|
||||
|
||||
self._task_state = EasyUITaskState(
|
||||
llm_result=LLMResult(
|
||||
model=self._model_config.model,
|
||||
prompt_messages=[],
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=LLMUsage.empty_usage(),
|
||||
)
|
||||
)
|
||||
|
||||
self._message_cycle_manager = MessageCycleManager(
|
||||
application_generate_entity=application_generate_entity,
|
||||
task_state=self._task_state,
|
||||
)
|
||||
|
||||
self._conversation_name_generate_thread: Thread | None = None
|
||||
|
||||
def process(
|
||||
self,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
|
||||
]:
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
if self.stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
||||
def _to_blocking_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
"""
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||
if self._conversation_mode == AppMode.COMPLETION:
|
||||
response = CompletionAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
message_id=self._message_id,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
else:
|
||||
response = ChatbotAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
else:
|
||||
continue
|
||||
|
||||
raise RuntimeError("queue listening stopped unexpectedly.")
|
||||
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
|
||||
"""
|
||||
To stream response.
|
||||
:return:
|
||||
"""
|
||||
for stream_response in generator:
|
||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
yield CompletionAppStreamResponse(
|
||||
message_id=self._message_id,
|
||||
created_at=self._message_created_at,
|
||||
stream_response=stream_response,
|
||||
)
|
||||
else:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
created_at=self._message_created_at,
|
||||
stream_response=stream_response,
|
||||
)
|
||||
|
||||
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
|
||||
if publisher is None:
|
||||
return None
|
||||
audio_msg = publisher.check_and_get_audio()
|
||||
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
|
||||
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
def _wrapper_process_stream_response(
|
||||
self, trace_manager: TraceQueueManager | None = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
task_id = self._application_generate_entity.task_id
|
||||
publisher = None
|
||||
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
|
||||
if (
|
||||
text_to_speech_dict
|
||||
and text_to_speech_dict.get("autoPlay") == "enabled"
|
||||
and text_to_speech_dict.get("enabled")
|
||||
):
|
||||
publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
|
||||
)
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listen_audio_msg(publisher, task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
break
|
||||
yield response
|
||||
|
||||
start_listener_time = time.time()
|
||||
# timeout
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
if publisher is None:
|
||||
break
|
||||
audio = publisher.check_and_get_audio()
|
||||
if audio is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
|
||||
continue
|
||||
if audio.status == "finish":
|
||||
break
|
||||
else:
|
||||
start_listener_time = time.time()
|
||||
yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
|
||||
if publisher:
|
||||
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
|
||||
|
||||
def _process_stream_response(
|
||||
self, publisher: AppGeneratorTTSPublisher | None, trace_manager: TraceQueueManager | None = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self.queue_manager.listen():
|
||||
if publisher:
|
||||
publisher.publish(message)
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
with Session(db.engine) as session:
|
||||
err = self.handle_error(event=event, session=session, message_id=self._message_id)
|
||||
session.commit()
|
||||
yield self.error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
if event.llm_result:
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
self._handle_stop(event)
|
||||
|
||||
# handle output moderation
|
||||
output_moderation_answer = self.handle_output_moderation_when_task_finished(
|
||||
cast(str, self._task_state.llm_result.message.content)
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=output_moderation_answer
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Save message
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
session.commit()
|
||||
message_end_resp = self._message_end_to_stream_response()
|
||||
yield message_end_resp
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
annotation = self._message_cycle_manager.handle_annotation_reply(event)
|
||||
if annotation:
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, QueueAgentThoughtEvent):
|
||||
agent_thought_response = self._agent_thought_to_stream_response(event)
|
||||
if agent_thought_response is not None:
|
||||
yield agent_thought_response
|
||||
elif isinstance(event, QueueMessageFileEvent):
|
||||
response = self._message_cycle_manager.message_file_to_stream_response(event)
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
|
||||
chunk = event.chunk
|
||||
delta_text = chunk.delta.message.content
|
||||
if delta_text is None:
|
||||
continue
|
||||
if isinstance(chunk.delta.message.content, list):
|
||||
delta_text = ""
|
||||
for content in chunk.delta.message.content:
|
||||
logger.debug(
|
||||
"The content type %s in LLM chunk delta message content.: %r", type(content), content
|
||||
)
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
delta_text += content.data
|
||||
elif isinstance(content, str):
|
||||
delta_text += content # failback to str
|
||||
else:
|
||||
logger.warning(
|
||||
"Unsupported content type %s in LLM chunk delta message content.: %r",
|
||||
type(content),
|
||||
content,
|
||||
)
|
||||
continue
|
||||
|
||||
if not self._task_state.llm_result.prompt_messages:
|
||||
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
current_content = cast(str, self._task_state.llm_result.message.content)
|
||||
current_content += cast(str, delta_text)
|
||||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
)
|
||||
else:
|
||||
yield self._agent_message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self.ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
|
||||
"""
|
||||
Save message.
|
||||
:return:
|
||||
"""
|
||||
llm_result = self._task_state.llm_result
|
||||
usage = llm_result.usage
|
||||
|
||||
message_stmt = select(Message).where(Message.id == self._message_id)
|
||||
message = session.scalar(message_stmt)
|
||||
if not message:
|
||||
raise ValueError(f"message {self._message_id} not found")
|
||||
conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id)
|
||||
conversation = session.scalar(conversation_stmt)
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation {self._conversation_id} not found")
|
||||
|
||||
message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
self._model_config.mode, self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
message.message_tokens = usage.prompt_tokens
|
||||
message.message_unit_price = usage.prompt_unit_price
|
||||
message.message_price_unit = usage.prompt_price_unit
|
||||
message.answer = (
|
||||
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
|
||||
if llm_result.message.content
|
||||
else ""
|
||||
)
|
||||
message.updated_at = naive_utc_now()
|
||||
message.answer_tokens = usage.completion_tokens
|
||||
message.answer_unit_price = usage.completion_unit_price
|
||||
message.answer_price_unit = usage.completion_price_unit
|
||||
message.provider_response_latency = time.perf_counter() - self.start_at
|
||||
message.total_price = usage.total_price
|
||||
message.currency = usage.currency
|
||||
self._task_state.llm_result.usage.latency = message.provider_response_latency
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
|
||||
)
|
||||
)
|
||||
|
||||
message_was_created.send(
|
||||
message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
)
|
||||
|
||||
def _handle_stop(self, event: QueueStopEvent):
|
||||
"""
|
||||
Handle stop.
|
||||
:return:
|
||||
"""
|
||||
model_config = self._model_config
|
||||
model = model_config.model
|
||||
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = 0
|
||||
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)
|
||||
|
||||
completion_tokens = 0
|
||||
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
|
||||
completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])
|
||||
|
||||
credentials = model_config.credentials
|
||||
|
||||
# transform usage
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
|
||||
model, credentials, prompt_tokens, completion_tokens
|
||||
)
|
||||
|
||||
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
|
||||
"""
|
||||
Message end to stream response.
|
||||
:return:
|
||||
"""
|
||||
self._task_state.metadata.usage = self._task_state.llm_result.usage
|
||||
metadata_dict = self._task_state.metadata.model_dump()
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
"""
|
||||
Agent message to stream response.
|
||||
:param answer: answer
|
||||
:param message_id: message id
|
||||
:return:
|
||||
"""
|
||||
return AgentMessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
|
||||
)
|
||||
|
||||
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> AgentThoughtStreamResponse | None:
|
||||
"""
|
||||
Agent thought to stream response.
|
||||
:param event: agent thought event
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
agent_thought: MessageAgentThought | None = (
|
||||
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
)
|
||||
|
||||
if agent_thought:
|
||||
return AgentThoughtStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=agent_thought.id,
|
||||
position=agent_thought.position,
|
||||
thought=agent_thought.thought,
|
||||
observation=agent_thought.observation,
|
||||
tool=agent_thought.tool,
|
||||
tool_labels=agent_thought.tool_labels,
|
||||
tool_input=agent_thought.tool_input,
|
||||
message_files=agent_thought.files,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
"""
|
||||
Handle output moderation chunk.
|
||||
:param text: text
|
||||
:return: True if output moderation should direct output, otherwise False
|
||||
"""
|
||||
if self.output_moderation_handler:
|
||||
if self.output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
|
||||
self.queue_manager.publish(
|
||||
QueueLLMChunkEvent(
|
||||
chunk=LLMResultChunk(
|
||||
model=self._task_state.llm_result.model,
|
||||
prompt_messages=self._task_state.llm_result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content),
|
||||
),
|
||||
)
|
||||
),
|
||||
PublishFrom.TASK_PIPELINE,
|
||||
)
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
return True
|
||||
else:
|
||||
self.output_moderation_handler.append_new_token(text)
|
||||
|
||||
return False
|
||||
12
dify/api/core/app/task_pipeline/exc.py
Normal file
12
dify/api/core/app/task_pipeline/exc.py
Normal file
@@ -0,0 +1,12 @@
|
||||
class TaskPipelineError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class RecordNotFoundError(TaskPipelineError):
|
||||
def __init__(self, record_name: str, record_id: str):
|
||||
super().__init__(f"{record_name} with id {record_id} not found")
|
||||
|
||||
|
||||
class WorkflowRunNotFoundError(RecordNotFoundError):
|
||||
def __init__(self, workflow_run_id: str):
|
||||
super().__init__("WorkflowRun", workflow_run_id)
|
||||
232
dify/api/core/app/task_pipeline/message_cycle_manager.py
Normal file
232
dify/api/core/app/task_pipeline/message_cycle_manager.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import logging
|
||||
from threading import Thread
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AnnotationReply,
|
||||
AnnotationReplyAccount,
|
||||
EasyUITaskState,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
MessageStreamResponse,
|
||||
StreamEvent,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageCycleManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
],
|
||||
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
||||
):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
|
||||
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
|
||||
"""
|
||||
Generate conversation name.
|
||||
:param conversation_id: conversation id
|
||||
:param query: query
|
||||
:return: thread
|
||||
"""
|
||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
return None
|
||||
|
||||
is_first_message = self._application_generate_entity.conversation_id is None
|
||||
extras = self._application_generate_entity.extras
|
||||
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
|
||||
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
thread = Thread(
|
||||
target=self._generate_conversation_name_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"conversation_id": conversation_id,
|
||||
"query": query,
|
||||
},
|
||||
)
|
||||
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
|
||||
return None
|
||||
|
||||
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
|
||||
with flask_app.app_context():
|
||||
# get conversation and message
|
||||
stmt = select(Conversation).where(Conversation.id == conversation_id)
|
||||
conversation = db.session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
return
|
||||
|
||||
if conversation.mode != AppMode.COMPLETION:
|
||||
app_model = conversation.app
|
||||
if not app_model:
|
||||
return
|
||||
|
||||
# generate conversation name
|
||||
try:
|
||||
name = LLMGenerator.generate_conversation_name(
|
||||
app_model.tenant_id, query, conversation_id, conversation.app_id
|
||||
)
|
||||
conversation.name = name
|
||||
except Exception:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> MessageAnnotation | None:
|
||||
"""
|
||||
Handle annotation reply.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||
if annotation:
|
||||
account = annotation.account
|
||||
self._task_state.metadata.annotation_reply = AnnotationReply(
|
||||
id=annotation.id,
|
||||
account=AnnotationReplyAccount(
|
||||
id=annotation.account_id,
|
||||
name=account.name if account else "Dify user",
|
||||
),
|
||||
)
|
||||
|
||||
return annotation
|
||||
|
||||
return None
|
||||
|
||||
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent):
|
||||
"""
|
||||
Handle retriever resources.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
if not self._application_generate_entity.app_config.additional_features:
|
||||
raise ValueError("Additional features not found")
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
merged_resources = [r for r in self._task_state.metadata.retriever_resources or [] if r]
|
||||
existing_ids = {(r.dataset_id, r.document_id) for r in merged_resources if r.dataset_id and r.document_id}
|
||||
|
||||
# Add new unique resources from the event
|
||||
for resource in event.retriever_resources or []:
|
||||
if not resource:
|
||||
continue
|
||||
|
||||
is_duplicate = (
|
||||
resource.dataset_id
|
||||
and resource.document_id
|
||||
and (resource.dataset_id, resource.document_id) in existing_ids
|
||||
)
|
||||
|
||||
if not is_duplicate:
|
||||
merged_resources.append(resource)
|
||||
|
||||
for i, resource in enumerate(merged_resources, 1):
|
||||
resource.position = i
|
||||
|
||||
self._task_state.metadata.retriever_resources = merged_resources
|
||||
|
||||
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
|
||||
"""
|
||||
Message file to stream response.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
|
||||
|
||||
if message_file and message_file.url is not None:
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
|
||||
# get extension
|
||||
if "." in message_file.url:
|
||||
extension = f".{message_file.url.split('.')[-1]}"
|
||||
if len(extension) > 10:
|
||||
extension = ".bin"
|
||||
else:
|
||||
extension = ".bin"
|
||||
# add sign url to local file
|
||||
if message_file.url.startswith("http"):
|
||||
url = message_file.url
|
||||
else:
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
|
||||
return MessageFileStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_file.id,
|
||||
type=message_file.type,
|
||||
belongs_to=message_file.belongs_to or "user",
|
||||
url=url,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def message_to_stream_response(
|
||||
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
|
||||
) -> MessageStreamResponse:
|
||||
"""
|
||||
Message to stream response.
|
||||
:param answer: answer
|
||||
:param message_id: message id
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
|
||||
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
|
||||
|
||||
return MessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer,
|
||||
from_variable_selector=from_variable_selector,
|
||||
event=event_type,
|
||||
)
|
||||
|
||||
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||
"""
|
||||
Message replace to stream response.
|
||||
:param answer: answer
|
||||
:return:
|
||||
"""
|
||||
return MessageReplaceStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id, answer=answer, reason=reason
|
||||
)
|
||||
Reference in New Issue
Block a user