This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

View File

@@ -0,0 +1,91 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import (
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.model import App, AppMode
from models.workflow import Workflow
class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
"""
Advanced Chatbot App Config Entity.
"""
pass
class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
additional_features=cls.convert_features(features_dict, app_mode),
)
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
"""
Validate for advanced chat app model config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: if True, only structure validation will be performed
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# opening_statement
config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config
)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# return retriever resource
config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@@ -0,0 +1,611 @@
import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
DraftVarLoader,
WorkflowDraftVariableService,
)
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param streaming: is stream
"""
if not args.get("query"):
raise ValueError("query is required")
query = args["query"]
if not isinstance(query, str):
raise ValueError("query must be a string")
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {
"auto_generate_conversation_name": args.get("auto_generate_name", False),
**extract_external_trace_id_from_args(args),
}
# get conversation
conversation = None
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
workflow_run_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=workflow_run_id,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
return self._generate(
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
)
def single_iteration_generate(
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=None,
inputs={},
query="",
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate(
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
variable_loader=var_loader,
)
def single_loop_generate(
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is stream
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=None,
inputs={},
query="",
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate(
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
variable_loader=var_loader,
)
def _generate(
self,
*,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Conversation | None = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
:param workflow: Workflow
:param user: account or end user
:param invoke_from: invoke from source
:param application_generate_entity: application generate entity
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = False
if not conversation:
is_first_conversation = True
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
# get conversation dialogue count
# NOTE: dialogue_count should not start from 0,
# because during the first conversation, dialogue_count should be 1.
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context and contextvars
context = contextvars.copy_context()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
app = session.scalar(select(App).where(App.id == application_generate_entity.app_config.app_id))
if app is None:
raise ValueError("App not found")
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
app=app,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
try:
runner.run()
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _handle_advanced_chat_response(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
:param user: account or end user
:param stream: is stream
:return:
"""
# init generate task pipeline
generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
dialogue_count=self._dialogue_count,
stream=stream,
draft_var_saver_factory=draft_var_saver_factory,
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
raise e

View File

@@ -0,0 +1,405 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
logger = logging.getLogger(__name__)
class AdvancedChatAppRunner(WorkflowBasedAppRunner):
"""
AdvancedChat Application Runner
"""
def __init__(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
dialogue_count: int,
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
app: App,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
graph_engine_layers=graph_engine_layers,
)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
self._dialogue_count = dialogue_count
self._workflow = workflow
self.system_user_id = system_user_id
self._app = app
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def run(self):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
system_inputs = SystemVariable(
query=self.application_generate_entity.query,
files=self.application_generate_entity.files,
conversation_id=self.conversation.id,
user_id=self.system_user_id,
dialogue_count=self._dialogue_count,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_run_id,
)
with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
if not app_record:
raise ValueError("App not found")
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
# moderation
if self.handle_input_moderation(
app_record=self._app,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id,
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=self._app,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity,
):
return
# Initialize conversation variables
conversation_variables = self._initialize_conversation_variables()
# Create a variable pool.
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=conversation_variables,
)
# init graph
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
)
db.session.close()
# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
self._queue_manager.graph_runtime_state = graph_runtime_state
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)
generator = workflow_entry.run()
for event in generator:
self._handle_event(workflow_entry, event)
def handle_input_moderation(
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> bool:
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
app_id=app_record.id,
tenant_id=app_generate_entity.app_config.tenant_id,
app_generate_entity=app_generate_entity,
inputs=inputs,
query=query,
message_id=message_id,
)
except ModerationError as e:
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
return True
return False
def handle_annotation_reply(
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
) -> bool:
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
query=query,
user_id=app_generate_entity.user_id,
invoke_from=app_generate_entity.invoke_from,
)
if annotation_reply:
self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))
self._complete_with_stream_output(
text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy):
"""
Direct output
"""
self._publish_event(QueueTextChunkEvent(text=text))
self._publish_event(QueueStopEvent(stopped_by=stopped_by))
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record
:param message: message
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:return:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
)
def moderation_for_inputs(
self,
*,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str | None = None,
message_id: str,
) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_generate_entity: app generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
moderation_feature = InputModeration()
return moderation_feature.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=dict(inputs),
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
)
def _initialize_conversation_variables(self) -> list[VariableUnion]:
"""
Initialize conversation variables for the current conversation.
This method:
1. Loads existing variables from the database
2. Creates new variables if none exist
3. Syncs missing variables from the workflow definition
:return: List of conversation variables ready for use
"""
with Session(db.engine) as session:
existing_variables = self._load_existing_conversation_variables(session)
if not existing_variables:
# First time initialization - create all variables
existing_variables = self._create_all_conversation_variables(session)
else:
# Check and add any missing variables from the workflow
existing_variables = self._sync_missing_conversation_variables(session, existing_variables)
# Convert to Variable objects for use in the workflow
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[VariableUnion], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""
Load existing conversation variables from the database.
:param session: Database session
:return: List of existing conversation variables
"""
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id,
ConversationVariable.conversation_id == self.conversation.id,
)
return list(session.scalars(stmt).all())
def _create_all_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""
Create all conversation variables for a new conversation.
:param session: Database session
:return: List of created conversation variables
"""
new_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in self._workflow.conversation_variables
]
if new_variables:
session.add_all(new_variables)
return new_variables
def _sync_missing_conversation_variables(
self, session: Session, existing_variables: list[ConversationVariable]
) -> list[ConversationVariable]:
"""
Sync missing conversation variables from the workflow definition.
This handles the case where new variables are added to a workflow
after conversations have already been created.
:param session: Database session
:param existing_variables: List of existing conversation variables
:return: Updated list including any newly created variables
"""
# Get IDs of existing and workflow variables
existing_ids = {var.id for var in existing_variables}
workflow_variables = {var.id: var for var in self._workflow.conversation_variables}
# Find missing variable IDs
missing_ids = set(workflow_variables.keys()) - existing_ids
if not missing_ids:
return existing_variables
# Create missing variables with their default values
new_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id,
conversation_id=self.conversation.id,
variable=workflow_variables[var_id],
)
for var_id in missing_ids
]
session.add_all(new_variables)
# Return combined list
return existing_variables + new_variables

View File

@@ -0,0 +1,125 @@
from collections.abc import Generator
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppBlockingResponse,
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
)
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
response = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(ChatbotAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(ChatbotAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -0,0 +1,891 @@
import json
import logging
import re
import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from threading import Thread
from typing import Any, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
StreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
draft_var_saver_factory: DraftVariableSaverFactory,
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
)
if isinstance(user, EndUser):
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_system_variables = SystemVariable(
query=message.query,
files=application_generate_entity.files,
conversation_id=conversation.id,
user_id=user_session_id,
dialogue_count=dialogue_count,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_run_id,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=self._workflow_system_variables,
)
self._task_state = WorkflowTaskState()
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
self._seed_graph_runtime_state_from_queue_manager()
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
"""
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
"""
Process blocking response.
:return:
"""
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
extras["metadata"] = stream_response.metadata
return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
data=ChatbotAppBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
answer=self._task_state.answer,
created_at=self._message_created_at,
**extras,
),
)
else:
continue
raise ValueError("queue listening stopped unexpectedly.")
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:
"""
To stream response.
:return:
"""
for stream_response in generator:
yield ChatbotAppStreamResponse(
conversation_id=self._conversation_id,
message_id=self._message_id,
created_at=self._message_created_at,
stream_response=stream_response,
)
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow_features_dict
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception:
logger.exception("Failed to listen audio message, task_id: %s", task_id)
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@contextmanager
def _database_session(self):
"""Context manager for database sessions."""
with Session(db.engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline.ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
with self._database_session() as session:
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(
self,
event: QueueWorkflowStartedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
runtime_state = self._resolve_graph_runtime_state()
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_run_id = run_id
with self._database_session() as session:
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = run_id
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow_id,
)
yield workflow_start_resp
def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle node retry events."""
self._ensure_workflow_initialized()
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
)
if node_retry_resp:
yield node_retry_resp
def _handle_node_started_event(
self, event: QueueNodeStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node started events."""
self._ensure_workflow_initialized()
node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
)
if node_start_resp:
yield node_start_resp
def _handle_node_succeeded_event(
self, event: QueueNodeSucceededEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
self._recorded_files.extend(
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
)
self._save_output_for_event(event, event.node_execution_id)
if node_finish_resp:
yield node_finish_resp
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, event.node_execution_id)
if node_finish_resp:
yield node_finish_resp
def _handle_text_chunk_event(
self,
event: QueueTextChunkEvent,
*,
tts_publisher: AppGeneratorTTSPublisher | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
delta_text = event.text
if delta_text is None:
return
# Handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
return
current_time = time.perf_counter()
if self._task_state.first_token_time is None and delta_text.strip():
self._task_state.first_token_time = current_time
self._task_state.is_streaming_response = True
if delta_text.strip():
self._task_state.last_token_time = current_time
# Only publish tts message at text chunk streaming
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._message_cycle_manager.message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle iteration start events."""
self._ensure_workflow_initialized()
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_start_resp
def _handle_iteration_next_event(
self, event: QueueIterationNextEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle iteration next events."""
self._ensure_workflow_initialized()
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_next_resp
def _handle_iteration_completed_event(
self, event: QueueIterationCompletedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle iteration completed events."""
self._ensure_workflow_initialized()
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_finish_resp
def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle loop start events."""
self._ensure_workflow_initialized()
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_start_resp
def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle loop next events."""
self._ensure_workflow_initialized()
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_next_resp
def _handle_loop_completed_event(
self, event: QueueLoopCompletedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle loop completed events."""
self._ensure_workflow_initialized()
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_finish_resp
def _handle_workflow_succeeded_event(
self,
event: QueueWorkflowSucceededEvent,
*,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.SUCCEEDED,
graph_runtime_state=validated_state,
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_partial_success_event(
self,
event: QueueWorkflowPartialSuccessEvent,
*,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count,
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_failed_event(
self,
event: QueueWorkflowFailedEvent,
*,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.FAILED,
graph_runtime_state=validated_state,
error=event.error,
exceptions_count=event.exceptions_count,
)
with self._database_session() as session:
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {event.error}"))
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
yield workflow_finish_resp
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_stop_event(
self,
event: QueueStopEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle stop events."""
_ = trace_manager
resolved_state = None
if self._workflow_run_id:
resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
if self._workflow_run_id and resolved_state:
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.STOPPED,
graph_runtime_state=resolved_state,
error=event.get_stop_reason(),
)
with self._database_session() as session:
# Save message
self._save_message(session=session, graph_runtime_state=resolved_state)
yield workflow_finish_resp
elif event.stopped_by in (
QueueStopEvent.StopBy.INPUT_MODERATION,
QueueStopEvent.StopBy.ANNOTATION_REPLY,
):
# When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session:
# Save message
self._save_message(session=session)
yield self._message_end_to_stream_response()
def _handle_advanced_chat_message_end_event(
self,
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
resolved_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer,
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
def _handle_retriever_resources_event(
self, event: QueueRetrieverResourcesEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle retriever resources events."""
self._message_cycle_manager.handle_retriever_resources(event)
return
yield # Make this a generator
def _handle_annotation_reply_event(
self, event: QueueAnnotationReplyEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle annotation reply events."""
self._message_cycle_manager.handle_annotation_reply(event)
return
yield # Make this a generator
def _handle_message_replace_event(
self, event: QueueMessageReplaceEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle message replace events."""
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
yield self._workflow_response_converter.handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
def _get_event_handlers(self) -> dict[type, Callable]:
"""Get mapping of event types to their handlers using fluent pattern."""
return {
# Basic events
QueuePingEvent: self._handle_ping_event,
QueueErrorEvent: self._handle_error_event,
QueueTextChunkEvent: self._handle_text_chunk_event,
# Workflow events
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
QueueIterationCompletedEvent: self._handle_iteration_completed_event,
# Loop events
QueueLoopStartEvent: self._handle_loop_start_event,
QueueLoopNextEvent: self._handle_loop_next_event,
QueueLoopCompletedEvent: self._handle_loop_completed_event,
# Control events
QueueStopEvent: self._handle_stop_event,
# Message events
QueueRetrieverResourcesEvent: self._handle_retriever_resources_event,
QueueAnnotationReplyEvent: self._handle_annotation_reply_event,
QueueMessageReplaceEvent: self._handle_message_replace_event,
QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
QueueAgentLogEvent: self._handle_agent_log_event,
}
def _dispatch_event(
self,
event: Any,
*,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
) -> Generator[StreamResponse, None, None]:
"""Dispatch events using elegant pattern matching."""
handlers = self._get_event_handlers()
event_type = type(event)
# Direct handler lookup
if handler := handlers.get(event_type):
yield from handler(
event,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle node failure events with isinstance check
if isinstance(
event,
(
QueueNodeFailedEvent,
QueueNodeExceptionEvent,
),
):
yield from self._handle_node_failed_events(
event,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# For unhandled events, we continue (original behavior)
return
def _process_stream_response(
self,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 57-if-statement version.
"""
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
match event:
case QueueWorkflowStartedEvent():
self._resolve_graph_runtime_state()
yield from self._handle_workflow_started_event(event)
case QueueErrorEvent():
yield from self._handle_error_event(event)
break
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
break
case QueueStopEvent():
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
break
# Handle all other events through elegant dispatch
case _:
if responses := list(
self._dispatch_event(
event,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
):
yield from responses
if tts_publisher:
tts_publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer
if self._recorded_files:
# Remove markdown image links since we're storing files separately
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
message.answer = answer_text
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
# Set usage first before dumping metadata
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
message.message_tokens = usage.prompt_tokens
message.message_unit_price = usage.prompt_unit_price
message.message_price_unit = usage.prompt_price_unit
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.metadata.usage = usage
else:
usage = LLMUsage.empty_usage()
self._task_state.metadata.usage = usage
# Add streaming metrics to usage if available
if self._task_state.is_streaming_response and self._task_state.first_token_time:
start_time = self._base_task_pipeline.start_at
first_token_time = self._task_state.first_token_time
last_token_time = self._task_state.last_token_time or first_token_time
usage.time_to_first_token = round(first_token_time - start_time, 3)
usage.time_to_generate = round(last_token_time - first_token_time, 3)
metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))
message_files = [
MessageFile(
message_id=message.id,
type=file["type"],
transfer_method=file["transfer_method"],
url=file["remote_url"],
belongs_to="assistant",
upload_file_id=file["related_id"],
created_by_role=CreatorUserRole.ACCOUNT
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatorUserRole.END_USER,
created_by=message.from_account_id or message.from_end_user_id or "",
)
for file in self._recorded_files
]
session.add_all(message_files)
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
if candidate is not None:
self._graph_runtime_state = candidate
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""
Message end to stream response.
:return:
"""
extras = self._task_state.metadata.model_dump()
if self._task_state.metadata.annotation_reply:
del extras["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
files=self._recorded_files,
metadata=extras,
)
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._base_task_pipeline.output_moderation_handler:
if self._base_task_pipeline.output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output()
self._base_task_pipeline.queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
self._base_task_pipeline.queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
self._base_task_pipeline.output_moderation_handler.append_new_token(text)
return False
def _get_message(self, *, session: Session):
stmt = select(Message).where(Message.id == self._message_id)
message = session.scalar(stmt)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
return message
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
with Session(db.engine) as session, session.begin():
saver = self._draft_var_saver_factory(
session=session,
app_id=self._application_generate_entity.app_config.app_id,
node_id=event.node_id,
node_type=event.node_type,
node_execution_id=node_execution_id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
)
saver.save(event.process_data, event.outputs)

View File

@@ -0,0 +1,236 @@
import uuid
from collections.abc import Mapping
from typing import Any, cast
from core.agent.entities import AgentEntity
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import (
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.entities.agent_entities import PlanningStrategy
from models.model import App, AppMode, AppModelConfig, Conversation
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
class AgentChatAppConfig(EasyUIBasedAppConfig):
"""
Agent Chatbot App Config Entity.
"""
agent: AgentEntity | None = None
class AgentChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
:param app_model: app model
:param app_model_config: app model config
:param conversation: conversation
:param override_config_dict: app model config dict
:return:
"""
if override_config_dict:
config_from = EasyUIBasedAppModelConfigFrom.ARGS
elif conversation:
config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = AgentChatAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
agent=AgentConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
config=config_dict
)
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
"""
Validate for agent chat app model config
:param tenant_id: tenant id
:param config: app model config args
"""
app_mode = AppMode.AGENT_CHAT
related_config_keys = []
# model
config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# user_input_form
config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# prompt
config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config)
related_config_keys.extend(current_related_config_keys)
# agent_mode
config, current_related_config_keys = cls.validate_agent_mode_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# opening_statement
config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config
)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# return retriever resource
config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# dataset configs
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
@classmethod
def validate_agent_mode_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate agent_mode and set defaults for agent feature
:param tenant_id: tenant ID
:param config: app model config args
"""
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}
agent_mode = config["agent_mode"]
if not isinstance(agent_mode, dict):
raise ValueError("agent_mode must be of object type")
# FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
agent_mode = cast(dict[str, Any], agent_mode)
if "enabled" not in agent_mode or not agent_mode["enabled"]:
agent_mode["enabled"] = False
if not isinstance(agent_mode["enabled"], bool):
raise ValueError("enabled in agent_mode must be of boolean type")
if not agent_mode.get("strategy"):
agent_mode["strategy"] = PlanningStrategy.ROUTER
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")
if not agent_mode.get("tools"):
agent_mode["tools"] = []
if not isinstance(agent_mode["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects")
for tool in agent_mode["tools"]:
key = list(tool.keys())[0]
if key in OLD_TOOLS:
# old style, use tool name as key
tool_item = tool[key]
if "enabled" not in tool_item or not tool_item["enabled"]:
tool_item["enabled"] = False
if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if key == "dataset":
if "id" not in tool_item:
raise ValueError("id is required in dataset")
try:
uuid.UUID(tool_item["id"])
except ValueError:
raise ValueError("id in dataset must be of UUID type")
if not DatasetConfigManager.is_dataset_exists(tenant_id, tool_item["id"]):
raise ValueError("Dataset ID does not exist, please check your permission.")
else:
# latest style, use key-value pair
if "enabled" not in tool or not tool["enabled"]:
tool["enabled"] = False
if "provider_type" not in tool:
raise ValueError("provider_type is required in agent_mode.tools")
if "provider_id" not in tool:
raise ValueError("provider_id is required in agent_mode.tools")
if "tool_name" not in tool:
raise ValueError("tool_name is required in agent_mode.tools")
if "tool_parameters" not in tool:
raise ValueError("tool_parameters is required in agent_mode.tools")
return config, ["agent_mode"]

View File

@@ -0,0 +1,266 @@
import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser
from services.conversation_service import ConversationService
logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
"""
Generate App response.
:param app_model: App
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param streaming: is stream
"""
if not streaming:
raise ValueError("Agent Chat App does not support blocking mode")
if not args.get("query"):
raise ValueError("query is required")
query = args["query"]
if not isinstance(query, str):
raise ValueError("query must be a string")
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
# get conversation
conversation = None
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
# validate override model config
override_model_config_dict = None
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError("Only in App debug mode can override model config")
# validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args["model_config"],
)
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
)
else:
file_objs = []
# convert to app config
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict,
)
# get tracing instance
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
# init application generate entity
application_generate_entity = AgentChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras=extras,
call_depth=0,
trace_manager=trace_manager,
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context and contextvars
context = contextvars.copy_context()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,
flask_app: Flask,
context: contextvars.Context,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
):
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = AgentChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
)
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()

View File

@@ -0,0 +1,239 @@
import logging
from typing import cast
from sqlalchemy import select
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError
from extensions.ext_database import db
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)
class AgentChatAppRunner(AppRunner):
"""
Agent Application Runner
"""
def run(
self,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
):
"""
Run assistant application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
:return:
"""
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
if not app_record:
raise ValueError("App not found")
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model,
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, _ = self.organize_prompt_messages(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query,
memory=memory,
)
# moderation
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
app_id=app_record.id,
tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity,
inputs=dict(inputs),
query=query or "",
message_id=message.id,
)
except ModerationError as e:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream,
)
return
if query:
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER,
)
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream,
)
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_config.external_data_variables
if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id,
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query,
)
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, _ = self.organize_prompt_messages(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query,
memory=memory,
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages,
)
if hosting_moderation_result:
return
agent_entity = app_config.agent
assert agent_entity is not None
# init model instance
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model,
)
prompt_message, _ = self.organize_prompt_messages(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query,
memory=memory,
)
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not model_schema:
raise ValueError("Model schema not found")
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
runner_cls = FunctionCallAgentRunner
else:
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
runner = runner_cls(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
conversation=conversation_result,
app_config=app_config,
model_config=application_generate_entity.model_conf,
config=agent_entity,
queue_manager=queue_manager,
message=message_result,
user_id=application_generate_entity.user_id,
memory=memory,
prompt_messages=prompt_message,
model_instance=model_instance,
)
invoke_result = runner.run(
message=message,
query=query,
inputs=inputs,
)
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
agent=True,
)

View File

@@ -0,0 +1,122 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
)
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
response = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(ChatbotAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(ChatbotAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -0,0 +1,132 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
logger = logging.getLogger(__name__)
class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_full_response(response)
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
return _generate_simple_response()
@classmethod
@abstractmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod
def _get_simple_metadata(cls, metadata: dict[str, Any]):
"""
Get simple metadata.
:param metadata: metadata
:return:
"""
# show_retrieve_source
updated_resources = []
if "retriever_resources" in metadata:
for resource in metadata["retriever_resources"]:
updated_resources.append(
{
"segment_id": resource.get("segment_id", ""),
"position": resource["position"],
"document_name": resource["document_name"],
"score": resource["score"],
"content": resource["content"],
}
)
metadata["retriever_resources"] = updated_resources
# show annotation reply
if "annotation_reply" in metadata:
del metadata["annotation_reply"]
# show usage
if "usage" in metadata:
del metadata["usage"]
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception):
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
"code": "provider_quota_exceeded",
"message": "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
"status": 400,
},
ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
InvokeError: {"code": "completion_request_error", "status": 400},
}
# Determine the response based on the type of exception
data = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v
if data:
data.setdefault("message", getattr(e, "description", str(e)))
else:
logger.error(e)
data = {
"code": "internal_server_error",
"message": "Internal Server Error, please contact support.",
"status": 500,
}
return data

View File

@@ -0,0 +1,232 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileUploadConfig
from core.workflow.enums import NodeType
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaver,
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from factories import file_factory
from libs.orjson import orjson_dumps
from models import Account, EndUser
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
from core.app.app_config.entities import VariableEntity
class BaseAppGenerator:
def _prepare_user_inputs(
self,
*,
user_inputs: Mapping[str, Any] | None,
variables: Sequence["VariableEntity"],
tenant_id: str,
strict_type_validation: bool = False,
) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
user_inputs = {
var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
for var in variables
}
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
# Convert files in inputs to File
entity_dictionary = {item.variable: item for item in variables}
# Convert single file to File
files_inputs = {
k: file_factory.build_from_mapping(
mapping=v,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types or [],
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
strict_type_validation=strict_type_validation,
)
for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
}
# Convert list of files to File
file_list_inputs = {
k: file_factory.build_from_mappings(
mappings=v,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types or [],
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
)
for k, v in user_inputs.items()
if isinstance(v, list)
# Ensure skip List<File>
and all(isinstance(item, dict) for item in v)
and entity_dictionary[k].type == VariableEntityType.FILE_LIST
}
# Merge all inputs
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
return user_inputs
def _validate_inputs(
self,
*,
variable_entity: "VariableEntity",
value: Any,
):
if value is None:
if variable_entity.required:
raise ValueError(f"{variable_entity.variable} is required in input form")
# Use default value and continue validation to ensure type conversion
value = variable_entity.default
# If default is also None, return None directly
if value is None:
return None
if variable_entity.type in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
} and not isinstance(value, str):
raise ValueError(
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
)
if variable_entity.type == VariableEntityType.NUMBER:
if isinstance(value, (int, float)):
return value
elif isinstance(value, str):
# handle empty string case
if not value.strip():
return None
# may raise ValueError if user_input_value is not a valid number
try:
if "." in value:
return float(value)
else:
return int(value)
except ValueError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
else:
raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}")
match variable_entity.type:
case VariableEntityType.SELECT:
if value not in variable_entity.options:
raise ValueError(
f"{variable_entity.variable} in input form must be one of the following: "
f"{variable_entity.options}"
)
case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH:
if variable_entity.max_length and len(value) > variable_entity.max_length:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} "
"characters"
)
case VariableEntityType.FILE:
if not isinstance(value, dict) and not isinstance(value, File):
raise ValueError(f"{variable_entity.variable} in input form must be a file")
case VariableEntityType.FILE_LIST:
# if number of files exceeds the limit, raise ValueError
if not (
isinstance(value, list)
and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value))
):
raise ValueError(f"{variable_entity.variable} in input form must be a list of files")
if variable_entity.max_length and len(value) > variable_entity.max_length:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
)
case VariableEntityType.CHECKBOX:
if isinstance(value, str):
normalized_value = value.strip().lower()
if normalized_value in {"true", "1", "yes", "on"}:
value = True
elif normalized_value in {"false", "0", "no", "off"}:
value = False
elif isinstance(value, (int, float)):
if value == 1:
value = True
elif value == 0:
value = False
case _:
raise AssertionError("this statement should be unreachable.")
return value
def _sanitize_value(self, value: Any):
if isinstance(value, str):
return value.replace("\x00", "")
return value
@classmethod
def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]):
"""
Convert messages into event stream
"""
if isinstance(generator, dict):
return generator
else:
def gen():
for message in generator:
if isinstance(message, Mapping | dict):
yield f"data: {orjson_dumps(message)}\n\n"
else:
yield f"event: {message}\n\n"
return gen()
@final
@staticmethod
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER:
assert isinstance(account, Account)
def draft_var_saver_factory(
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> DraftVariableSaver:
return DraftVariableSaverImpl(
session=session,
app_id=app_id,
node_id=node_id,
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
user=account,
)
else:
def draft_var_saver_factory(
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> DraftVariableSaver:
return NoopDraftVariableSaver()
return draft_var_saver_factory

View File

@@ -0,0 +1,219 @@
import logging
import queue
import threading
import time
from abc import abstractmethod
from enum import IntEnum, auto
from typing import Any
from cachetools import TTLCache, cachedmethod
from redis.exceptions import RedisError
from sqlalchemy.orm import DeclarativeMeta
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueErrorEvent,
QueuePingEvent,
QueueStopEvent,
WorkflowQueueMessage,
)
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class PublishFrom(IntEnum):
APPLICATION_MANAGER = auto()
TASK_PIPELINE = auto()
class AppQueueManager:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom):
if not user_id:
raise ValueError("user is required")
self._task_id = task_id
self._user_id = user_id
self._invoke_from = invoke_from
self.invoke_from = invoke_from # Public accessor for invoke_from
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
self._task_belong_cache_key = AppQueueManager._generate_task_belong_cache_key(self._task_id)
redis_client.setex(self._task_belong_cache_key, 1800, f"{user_prefix}-{self._user_id}")
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self._q = q
self._graph_runtime_state: GraphRuntimeState | None = None
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
self._cache_lock = threading.Lock()
def listen(self):
"""
Listen to queue
:return:
"""
# wait for APP_MAX_EXECUTION_TIME seconds to stop listen
listen_timeout = dify_config.APP_MAX_EXECUTION_TIME
start_time = time.time()
last_ping_time: int | float = 0
while True:
try:
message = self._q.get(timeout=1)
if message is None:
break
yield message
except queue.Empty:
continue
finally:
elapsed_time = time.time() - start_time
if elapsed_time >= listen_timeout or self._is_stopped():
# publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed
self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
)
if elapsed_time // 10 > last_ping_time:
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
last_ping_time = elapsed_time // 10
def stop_listen(self):
"""
Stop listen to queue
:return:
"""
self._clear_task_belong_cache()
self._q.put(None)
def _clear_task_belong_cache(self) -> None:
"""
Remove the task belong cache key once listening is finished.
"""
try:
redis_client.delete(self._task_belong_cache_key)
except RedisError:
logger.exception(
"Failed to clear task belong cache for task %s (key: %s)", self._task_id, self._task_belong_cache_key
)
def publish_error(self, e, pub_from: PublishFrom) -> None:
"""
Publish error
:param e: error
:param pub_from: publish from
:return:
"""
self.publish(QueueErrorEvent(error=e), pub_from)
@property
def graph_runtime_state(self) -> GraphRuntimeState | None:
"""Retrieve the attached graph runtime state, if available."""
return self._graph_runtime_state
@graph_runtime_state.setter
def graph_runtime_state(self, graph_runtime_state: GraphRuntimeState | None) -> None:
"""Attach the live graph runtime state reference for downstream consumers."""
self._graph_runtime_state = graph_runtime_state
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
self._check_for_sqlalchemy_models(event.model_dump())
self._publish(event, pub_from)
@abstractmethod
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
raise NotImplementedError
@classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str):
"""
Set task stop flag
:return:
"""
result: Any | None = redis_client.get(cls._generate_task_belong_cache_key(task_id))
if result is None:
return
user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
@classmethod
def set_stop_flag_no_user_check(cls, task_id: str) -> None:
"""
Set task stop flag without user permission check.
This method allows stopping workflows without user context.
:param task_id: The task ID to stop
:return:
"""
if not task_id:
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
@cachedmethod(lambda self: self._stopped_cache, lock=lambda self: self._cache_lock)
def _is_stopped(self) -> bool:
"""
Check if task is stopped
:return:
"""
stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id)
result = redis_client.get(stopped_cache_key)
if result is not None:
return True
return False
@classmethod
def _generate_task_belong_cache_key(cls, task_id: str) -> str:
"""
Generate task belong cache key
:param task_id: task id
:return:
"""
return f"generate_task_belong:{task_id}"
@classmethod
def _generate_stopped_cache_key(cls, task_id: str) -> str:
"""
Generate stopped cache key
:param task_id: task id
:return:
"""
return f"generate_task_stopped:{task_id}"
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
if isinstance(data, dict):
for value in data.values():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:
self._check_for_sqlalchemy_models(item)
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
)

View File

@@ -0,0 +1,388 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AppGenerateEntity,
EasyUIBasedAppGenerateEntity,
InvokeFrom,
ModelConfigWithCredentialsEntity,
)
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
from core.file.models import File
_logger = logging.getLogger(__name__)
class AppRunner:
def recalc_llm_max_tokens(
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
if model_context_tokens is None:
return -1
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16)
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
model_config.parameters[parameter_rule.name] = max_tokens
def organize_prompt_messages(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: str = "",
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
"""
Organize prompt messages
:param context:
:param app_record: app record
:param model_config: model config entity
:param prompt_template_entity: prompt template entity
:param inputs: inputs
:param files: files
:param query: query
:param memory: memory
:param image_detail_config: the image quality config
:return:
"""
# get prompt without memory and context
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform]
prompt_transform = SimplePromptTransform()
prompt_messages, stop = prompt_transform.get_prompt(
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
)
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode(model_config.mode)
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
if not advanced_completion_prompt_template:
raise InvokeBadRequestError("Advanced completion prompt template is required.")
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
if advanced_completion_prompt_template.role_prefix:
memory_config.role_prefix = MemoryConfig.RolePrefix(
user=advanced_completion_prompt_template.role_prefix.user,
assistant=advanced_completion_prompt_template.role_prefix.assistant,
)
else:
if not prompt_template_entity.advanced_chat_prompt_template:
raise InvokeBadRequestError("Advanced chat prompt template is required.")
prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
query=query or "",
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config,
image_detail_config=image_detail_config,
)
stop = model_config.stop
return prompt_messages, stop
def direct_output(
self,
queue_manager: AppQueueManager,
app_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: LLMUsage | None = None,
):
"""
Direct output
:param queue_manager: application queue manager
:param app_generate_entity: app generate entity
:param prompt_messages: prompt messages
:param text: text
:param stream: stream
:param usage: usage
:return:
"""
if stream:
index = 0
for token in text:
chunk = LLMResultChunk(
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
)
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage or LLMUsage.empty_usage(),
),
),
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_result(
self,
invoke_result: Union[LLMResult, Generator[Any, None, None]],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
):
"""
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param stream: stream
:param agent: agent
:return:
"""
if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
"""
Handle invoke result direct
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:return:
"""
queue_manager.publish(
QueueMessageEndEvent(
llm_result=invoke_result,
),
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_result_stream(
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
):
"""
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:return:
"""
model: str = ""
prompt_messages: list[PromptMessage] = []
text = ""
usage = None
for result in invoke_result:
if not agent:
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
else:
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
message = result.delta.message
if isinstance(message.content, str):
text += message.content
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, str):
# TODO(QuantumGhost): Add multimodal output support for easy ui.
_logger.warning("received multimodal output, type=%s", type(content))
text += content.data
else:
text += content # failback to str
if not model:
model = result.model
if not prompt_messages:
prompt_messages = list(result.prompt_messages)
if result.delta.usage:
usage = result.delta.usage
if usage is None:
usage = LLMUsage.empty_usage()
llm_result = LLMResult(
model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
)
queue_manager.publish(
QueueMessageEndEvent(
llm_result=llm_result,
),
PublishFrom.APPLICATION_MANAGER,
)
def moderation_for_inputs(
self,
*,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str | None = None,
message_id: str,
) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_generate_entity: app generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
moderation_feature = InputModeration()
return moderation_feature.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=dict(inputs),
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
)
def check_hosting_moderation(
self,
application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage],
) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param prompt_messages: prompt messages
:return:
"""
hosting_moderation_feature = HostingModerationFeature()
moderation_result = hosting_moderation_feature.check(
application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
)
if moderation_result:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream,
)
return moderation_result
def fill_in_inputs_from_external_data_tools(
self,
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: Mapping[str, Any],
query: str,
) -> Mapping[str, Any]:
"""
Fill in variable inputs from external data tools if exists.
:param tenant_id: workspace id
:param app_id: app id
:param external_data_tools: external data tools configs
:param inputs: the inputs
:param query: the query
:return: the filled inputs
"""
external_data_fetch_feature = ExternalDataFetch()
return external_data_fetch_feature.fetch(
tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
)
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record
:param message: message
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:return:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
)

View File

View File

@@ -0,0 +1,148 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import (
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import App, AppMode, AppModelConfig, Conversation
class ChatAppConfig(EasyUIBasedAppConfig):
"""
Chatbot App Config Entity.
"""
pass
class ChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
) -> ChatAppConfig:
"""
Convert app model config to chat app config
:param app_model: app model
:param app_model_config: app model config
:param conversation: conversation
:param override_config_dict: app model config dict
:return:
"""
if override_config_dict:
config_from = EasyUIBasedAppModelConfigFrom.ARGS
elif conversation:
config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
if not override_config_dict:
raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = ChatAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
config=config_dict
)
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict):
"""
Validate for chat app model config
:param tenant_id: tenant id
:param config: app model config args
"""
app_mode = AppMode.CHAT
related_config_keys = []
# model
config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# user_input_form
config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# prompt
config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config)
related_config_keys.extend(current_related_config_keys)
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)
# opening_statement
config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config
)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# return retriever resource
config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@@ -0,0 +1,255 @@
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models import Account
from models.model import App, EndUser
from services.conversation_service import ConversationService
logger = logging.getLogger(__name__)
class ChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Generate App response.
:param app_model: App
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param streaming: is stream
"""
if not args.get("query"):
raise ValueError("query is required")
query = args["query"]
if not isinstance(query, str):
raise ValueError("query must be a string")
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
# get conversation
conversation = None
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
# validate override model config
override_model_config_dict = None
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError("Only in App debug mode can override model config")
# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
)
else:
file_objs = []
# convert to app config
app_config = ChatAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict,
)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# init application generate entity
application_generate_entity = ChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
stream=streaming,
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context
@copy_current_request_context
def worker_with_context():
return self._generate_worker(
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation_id=conversation.id,
message_id=message.id,
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
):
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation_id: conversation ID
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = ChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
)
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()

View File

@@ -0,0 +1,223 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
from core.app.entities.app_invoke_entities import (
ChatAppGenerateEntity,
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)
class ChatAppRunner(AppRunner):
"""
Chat Application Runner
"""
def run(
self,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
):
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
:return:
"""
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
image_detail_config = (
application_generate_entity.file_upload_config.image_config.detail
if (
application_generate_entity.file_upload_config
and application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model,
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
memory=memory,
image_detail_config=image_detail_config,
)
# moderation
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
app_id=app_record.id,
tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
)
except ModerationError as e:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream,
)
return
if query:
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER,
)
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream,
)
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_config.external_data_variables
if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id,
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query,
)
# get context from datasets
context = None
if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from,
)
dataset_retrieval = DatasetRetrieval(application_generate_entity)
context = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
model_config=application_generate_entity.model_conf,
config=app_config.dataset,
query=query,
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=(
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
),
hit_callback=hit_callback,
memory=memory,
message_id=message.id,
inputs=inputs,
)
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=context,
memory=memory,
image_detail_config=image_detail_config,
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages,
)
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model,
)
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=application_generate_entity.model_conf.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
)

View File

@@ -0,0 +1,122 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
)
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
response = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(ChatbotAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(ChatbotAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -0,0 +1,55 @@
"""Shared helpers for managing GraphRuntimeState across task pipelines."""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.workflow.runtime import GraphRuntimeState
if TYPE_CHECKING:
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
class GraphRuntimeStateSupport:
"""
Mixin that centralises common GraphRuntimeState access patterns used by task pipelines.
Subclasses are expected to provide:
* `_base_task_pipeline` exposing the queue manager with an optional cached runtime state.
* `_graph_runtime_state` attribute used as the local cache for the runtime state.
"""
_base_task_pipeline: BasedGenerateTaskPipeline
_graph_runtime_state: GraphRuntimeState | None = None
def _ensure_graph_runtime_initialized(
self,
graph_runtime_state: GraphRuntimeState | None = None,
) -> GraphRuntimeState:
"""Validate and return the active graph runtime state."""
return self._resolve_graph_runtime_state(graph_runtime_state)
def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str:
system_variables = graph_runtime_state.variable_pool.system_variables
if not system_variables or not system_variables.workflow_execution_id:
raise ValueError("workflow_execution_id missing from runtime state")
return str(system_variables.workflow_execution_id)
def _resolve_graph_runtime_state(
self,
graph_runtime_state: GraphRuntimeState | None = None,
) -> GraphRuntimeState:
"""Return the cached runtime state or bootstrap it from the queue manager."""
if graph_runtime_state is not None:
self._graph_runtime_state = graph_runtime_state
return graph_runtime_state
if self._graph_runtime_state is None:
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
if candidate is not None:
self._graph_runtime_state = candidate
if self._graph_runtime_state is None:
raise ValueError("graph runtime state not initialized.")
return self._graph_runtime_state

View File

@@ -0,0 +1,677 @@
import time
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
LoopNodeCompletedStreamResponse,
LoopNodeNextStreamResponse,
LoopNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import (
NodeType,
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
@dataclass(slots=True)
class _NodeSnapshot:
"""In-memory cache for node metadata between start and completion events."""
title: str
index: int
start_at: datetime
iteration_id: str = ""
"""Empty string means the node is not executing inside an iteration."""
loop_id: str = ""
"""Empty string means the node is not executing inside a loop."""
class WorkflowResponseConverter:
_truncator: BaseTruncator
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
system_variables: SystemVariable,
):
self._application_generate_entity = application_generate_entity
self._user = user
self._system_variables = system_variables
self._workflow_inputs = self._prepare_workflow_inputs()
# Disable truncation for SERVICE_API calls to keep backward compatibility.
if application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
self._truncator = DummyVariableTruncator()
else:
self._truncator = VariableTruncator.default()
self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
self._workflow_execution_id: str | None = None
self._workflow_started_at: datetime | None = None
# ------------------------------------------------------------------
# Workflow lifecycle helpers
# ------------------------------------------------------------------
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
inputs = dict(self._application_generate_entity.inputs)
for field_name, value in self._system_variables.to_dict().items():
# TODO(@future-refactor): store system variables separately from user inputs so we don't
# need to flatten `sys.*` entries into the input payload just for rerun/export tooling.
if field_name == SystemVariableKey.CONVERSATION_ID:
# Conversation IDs are session-scoped; omitting them keeps workflow inputs
# reusable without pinning new runs to a prior conversation.
continue
inputs[f"sys.{field_name}"] = value
handled = WorkflowEntry.handle_special_values(inputs)
return dict(handled or {})
def _ensure_workflow_run_id(self, workflow_run_id: str | None = None) -> str:
"""Return the memoized workflow run id, optionally seeding it during start events."""
if workflow_run_id is not None:
self._workflow_execution_id = workflow_run_id
if not self._workflow_execution_id:
raise ValueError("workflow_run_id missing before streaming workflow events")
return self._workflow_execution_id
# ------------------------------------------------------------------
# Node snapshot helpers
# ------------------------------------------------------------------
def _store_snapshot(self, event: QueueNodeStartedEvent) -> _NodeSnapshot:
snapshot = _NodeSnapshot(
title=event.node_title,
index=event.node_run_index,
start_at=event.start_at,
iteration_id=event.in_iteration_id or "",
loop_id=event.in_loop_id or "",
)
node_execution_id = NodeExecutionId(event.node_execution_id)
self._node_snapshots[node_execution_id] = snapshot
return snapshot
def _get_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
return self._node_snapshots.get(NodeExecutionId(node_execution_id))
def _pop_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
return self._node_snapshots.pop(NodeExecutionId(node_execution_id), None)
@staticmethod
def _merge_metadata(
base_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None,
snapshot: _NodeSnapshot | None,
) -> Mapping[WorkflowNodeExecutionMetadataKey, Any] | None:
if not base_metadata and not snapshot:
return base_metadata
merged: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
if base_metadata:
merged.update(base_metadata)
if snapshot:
if snapshot.iteration_id:
merged[WorkflowNodeExecutionMetadataKey.ITERATION_ID] = snapshot.iteration_id
if snapshot.loop_id:
merged[WorkflowNodeExecutionMetadataKey.LOOP_ID] = snapshot.loop_id
return merged or None
def _truncate_mapping(
self,
mapping: Mapping[str, Any] | None,
) -> tuple[Mapping[str, Any] | None, bool]:
if mapping is None:
return None, False
if not mapping:
return {}, False
normalized = WorkflowEntry.handle_special_values(dict(mapping))
if normalized is None:
return None, False
truncated, is_truncated = self._truncator.truncate_variable_mapping(dict(normalized))
return truncated, is_truncated
@staticmethod
def _encode_outputs(outputs: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
if outputs is None:
return None
converter = WorkflowRuntimeTypeConverter()
return converter.to_json_encodable(outputs)
def workflow_start_to_stream_response(
self,
*,
task_id: str,
workflow_run_id: str,
workflow_id: str,
) -> WorkflowStartStreamResponse:
run_id = self._ensure_workflow_run_id(workflow_run_id)
started_at = naive_utc_now()
self._workflow_started_at = started_at
return WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowStartStreamResponse.Data(
id=run_id,
workflow_id=workflow_id,
inputs=self._workflow_inputs,
created_at=int(started_at.timestamp()),
),
)
def workflow_finish_to_stream_response(
self,
*,
task_id: str,
workflow_id: str,
status: WorkflowExecutionStatus,
graph_runtime_state: GraphRuntimeState,
error: str | None = None,
exceptions_count: int = 0,
) -> WorkflowFinishStreamResponse:
run_id = self._ensure_workflow_run_id()
started_at = self._workflow_started_at
if started_at is None:
raise ValueError(
"workflow_finish_to_stream_response called before workflow_start_to_stream_response",
)
finished_at = naive_utc_now()
elapsed_time = (finished_at - started_at).total_seconds()
outputs_mapping = graph_runtime_state.outputs or {}
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
created_by: Mapping[str, object] | None
user = self._user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_id,
status=status.value,
outputs=encoded_outputs,
error=error,
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
created_by=created_by,
created_at=int(started_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(outputs_mapping),
exceptions_count=exceptions_count,
),
)
def workflow_node_start_to_stream_response(
self,
*,
event: QueueNodeStartedEvent,
task_id: str,
) -> NodeStartStreamResponse | None:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._store_snapshot(event)
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=NodeStartStreamResponse.Data(
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=snapshot.title,
index=snapshot.index,
created_at=int(snapshot.start_at.timestamp()),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
agent_strategy=event.agent_strategy,
),
)
if event.node_type == NodeType.TOOL:
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
)
elif event.node_type == NodeType.DATASOURCE:
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
)
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
self._application_generate_entity.app_config.tenant_id
)
elif event.node_type == NodeType.TRIGGER_PLUGIN:
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
)
return response
def workflow_node_finish_to_stream_response(
self,
*,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
) -> NodeFinishStreamResponse | None:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._pop_snapshot(event.node_execution_id)
start_at = snapshot.start_at if snapshot else event.start_at
finished_at = naive_utc_now()
elapsed_time = (finished_at - start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
process_data, process_data_truncated = self._truncate_mapping(event.process_data)
encoded_outputs = self._encode_outputs(event.outputs)
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
metadata = self._merge_metadata(event.execution_metadata, snapshot)
if isinstance(event, QueueNodeSucceededEvent):
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
error_message = event.error
elif isinstance(event, QueueNodeFailedEvent):
status = WorkflowNodeExecutionStatus.FAILED.value
error_message = event.error
else:
status = WorkflowNodeExecutionStatus.EXCEPTION.value
error_message = event.error
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=NodeFinishStreamResponse.Data(
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
index=snapshot.index if snapshot else 0,
title=snapshot.title if snapshot else "",
inputs=inputs,
inputs_truncated=inputs_truncated,
process_data=process_data,
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=status,
error=error_message,
elapsed_time=elapsed_time,
execution_metadata=metadata,
created_at=int(start_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
)
def workflow_node_retry_to_stream_response(
self,
*,
event: QueueNodeRetryEvent,
task_id: str,
) -> NodeRetryStreamResponse | None:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._get_snapshot(event.node_execution_id)
if snapshot is None:
raise AssertionError("node retry event arrived without a stored snapshot")
finished_at = naive_utc_now()
elapsed_time = (finished_at - event.start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
process_data, process_data_truncated = self._truncate_mapping(event.process_data)
encoded_outputs = self._encode_outputs(event.outputs)
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
metadata = self._merge_metadata(event.execution_metadata, snapshot)
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=NodeRetryStreamResponse.Data(
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
index=snapshot.index,
title=snapshot.title,
inputs=inputs,
inputs_truncated=inputs_truncated,
process_data=process_data,
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=WorkflowNodeExecutionStatus.RETRY.value,
error=event.error,
elapsed_time=elapsed_time,
execution_metadata=metadata,
created_at=int(snapshot.start_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
retry_index=event.retry_index,
),
)
def workflow_iteration_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueIterationStartEvent,
) -> IterationNodeStartStreamResponse:
new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
created_at=int(time.time()),
extras={},
inputs=new_inputs,
inputs_truncated=truncated,
metadata=event.metadata or {},
),
)
def workflow_iteration_next_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueIterationNextEvent,
) -> IterationNodeNextStreamResponse:
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
index=event.index,
created_at=int(time.time()),
extras={},
),
)
def workflow_iteration_completed_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
json_converter.to_json_encodable(event.outputs) or {}
)
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,
created_at=int(time.time()),
extras={},
inputs=new_inputs,
inputs_truncated=inputs_truncated,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
),
)
def workflow_loop_start_to_stream_response(
self, *, task_id: str, workflow_execution_id: str, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse:
new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
return LoopNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=LoopNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
created_at=int(time.time()),
extras={},
inputs=new_inputs,
inputs_truncated=truncated,
metadata=event.metadata or {},
),
)
def workflow_loop_next_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueLoopNextEvent,
) -> LoopNodeNextStreamResponse:
return LoopNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=LoopNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
index=event.index,
# The `pre_loop_output` field is not utilized by the frontend.
# Previously, it was assigned the value of `event.output`.
pre_loop_output={},
created_at=int(time.time()),
extras={},
),
)
def workflow_loop_completed_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueLoopCompletedEvent,
) -> LoopNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
json_converter.to_json_encodable(event.outputs) or {}
)
return LoopNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=LoopNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,
created_at=int(time.time()),
extras={},
inputs=new_inputs,
inputs_truncated=inputs_truncated,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
),
)
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
:return:
"""
if not outputs_dict:
return []
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
# Remove None
files = [file for file in files if file]
# Flatten list
# Flatten the list of sequences into a single list of mappings
flattened_files = [file for sublist in files if sublist for file in sublist]
# Convert to tuple to match Sequence type
return tuple(flattened_files)
@classmethod
def _fetch_files_from_variable_value(cls, value: Union[dict, list, Segment]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from variable value
:param value: variable value
:return:
"""
if not value:
return []
files: list[Mapping[str, Any]] = []
if isinstance(value, FileSegment):
files.append(value.value.to_dict())
elif isinstance(value, ArrayFileSegment):
files.extend([i.to_dict() for i in value.value])
elif isinstance(value, File):
files.append(value.to_dict())
elif isinstance(value, list):
for item in value:
file = cls._get_file_var_from_value(item)
if file:
files.append(file)
elif isinstance(
value,
dict,
):
file = cls._get_file_var_from_value(value)
if file:
files.append(file)
return files
@classmethod
def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value
:return:
"""
if not value:
return None
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
return value
elif isinstance(value, File):
return value.to_dict()
return None
def handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
Handle agent log
:param task_id: task id
:param event: agent log event
:return:
"""
return AgentLogStreamResponse(
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=event.node_execution_id,
id=event.id,
parent_id=event.parent_id,
label=event.label,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
node_id=event.node_id,
),
)

View File

@@ -0,0 +1,119 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from models.model import App, AppMode, AppModelConfig
class CompletionAppConfig(EasyUIBasedAppConfig):
"""
Completion App Config Entity.
"""
pass
class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config
:param app_model: app model
:param app_model_config: app model config
:param override_config_dict: app model config dict
:return:
"""
if override_config_dict:
config_from = EasyUIBasedAppModelConfigFrom.ARGS
else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
config_dict = override_config_dict or {}
app_mode = AppMode.value_of(app_model.mode)
app_config = CompletionAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=app_mode,
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
config=config_dict
)
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict):
"""
Validate for completion app model config
:param tenant_id: tenant id
:param config: app model config args
"""
app_mode = AppMode.COMPLETION
related_config_keys = []
# model
config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# user_input_form
config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config)
related_config_keys.extend(current_related_config_keys)
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# prompt
config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config)
related_config_keys.extend(current_related_config_keys)
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# more_like_this
config, current_related_config_keys = MoreLikeThisConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@@ -0,0 +1,350 @@
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from sqlalchemy import select
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Message
from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
class CompletionAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str | Mapping[str, Any], None, None]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = False,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param app_model: App
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param streaming: is stream
"""
query = args["query"]
if not isinstance(query, str):
raise ValueError("query must be a string")
query = query.replace("\x00", "")
inputs = args["inputs"]
# get conversation
conversation = None
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
# validate override model config
override_model_config_dict = None
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError("Only in App debug mode can override model config")
# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config", {})
)
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# init application generate entity
application_generate_entity = CompletionAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
files=list(file_objs),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
extras={},
trace_manager=trace_manager,
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context
@copy_current_request_context
def worker_with_context():
return self._generate_worker(
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message_id=message.id,
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=streaming,
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
):
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param message_id: message ID
:return:
"""
with flask_app.app_context():
try:
# get message
message = self._get_message(message_id)
# chatbot app
runner = CompletionAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message,
)
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def generate_more_like_this(
self,
app_model: App,
message_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
"""
Generate App response.
:param app_model: App
:param message_id: message ID
:param user: account or end user
:param invoke_from: invoke from source
:param stream: is stream
"""
stmt = select(Message).where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
message = db.session.scalar(stmt)
if not message:
raise MessageNotExistsError()
current_app_model_config = app_model.app_model_config
if not current_app_model_config:
raise MoreLikeThisDisabledError()
more_like_this = current_app_model_config.more_like_this_dict
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
if not app_model_config:
raise ValueError("Message app_model_config is None")
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")
completion_params["temperature"] = 0.9
model_dict["completion_params"] = completion_params
override_model_config_dict["model"] = model_dict
# parse files
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=message.message_files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)
# init application generate entity
application_generate_entity = CompletionAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
inputs=message.inputs,
query=message.query,
files=list(file_objs),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras={},
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
)
# new thread with request context
@copy_current_request_context
def worker_with_context():
return self._generate_worker(
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message_id=message.id,
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=stream,
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

View File

@@ -0,0 +1,181 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
from core.app.entities.app_invoke_entities import (
CompletionAppGenerateEntity,
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Message
logger = logging.getLogger(__name__)
class CompletionAppRunner(AppRunner):
"""
Completion Application Runner
"""
def run(
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
):
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param message: message
:return:
"""
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
image_detail_config = (
application_generate_entity.file_upload_config.image_config.detail
if (
application_generate_entity.file_upload_config
and application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
image_detail_config=image_detail_config,
)
# moderation
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
app_id=app_record.id,
tenant_id=app_config.tenant_id,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query or "",
message_id=message.id,
)
except ModerationError as e:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream,
)
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_config.external_data_variables
if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id,
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query,
)
# get context from datasets
context = None
if app_config.dataset and app_config.dataset.dataset_ids:
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from,
)
dataset_config = app_config.dataset
if dataset_config and dataset_config.retrieve_config.query_variable:
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrieval(application_generate_entity)
context = dataset_retrieval.retrieve(
app_id=app_record.id,
user_id=application_generate_entity.user_id,
tenant_id=app_record.tenant_id,
model_config=application_generate_entity.model_conf,
config=dataset_config,
query=query or "",
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source
if app_config.additional_features
else False,
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,
)
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=context,
image_detail_config=image_detail_config,
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages,
)
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model,
)
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=application_generate_entity.model_conf.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
)

View File

@@ -0,0 +1,121 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
CompletionAppBlockingResponse,
CompletionAppStreamResponse,
ErrorStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
)
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
response = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(CompletionAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(CompletionAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -0,0 +1,2 @@
class GenerateTaskStoppedError(Exception):
pass

View File

@@ -0,0 +1,279 @@
import json
import logging
from collections.abc import Generator
from typing import Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
AppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
CompletionAppBlockingResponse,
CompletionAppStreamResponse,
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
class MessageBasedAppGenerator(BaseAppGenerator):
def _handle_response(
self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
:param user: user
:param stream: is stream
:return:
"""
# init generate task pipeline
generate_task_pipeline = EasyUIBasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
stream=stream,
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
raise e
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
if conversation:
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
app_model_config = db.session.scalar(stmt)
if not app_model_config:
raise AppModelConfigBrokenError()
else:
if app_model.app_model_config_id is None:
raise AppModelConfigBrokenError()
app_model_config = app_model.app_model_config
if not app_model_config:
raise AppModelConfigBrokenError()
return app_model_config
def _init_generate_records(
self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
conversation: Conversation | None = None,
) -> tuple[Conversation, Message]:
"""
Initialize generate records
:param application_generate_entity: application generate entity
:conversation conversation
:return:
"""
app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config)
# get from source
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
from_source = "api"
end_user_id = application_generate_entity.user_id
else:
from_source = "console"
account_id = application_generate_entity.user_id
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
app_model_config_id = None
override_model_configs = None
model_provider = None
model_id = None
else:
app_model_config_id = app_config.app_model_config_id
model_provider = application_generate_entity.model_conf.provider
model_id = application_generate_entity.model_conf.model
override_model_configs = None
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
AppMode.AGENT_CHAT,
AppMode.CHAT,
AppMode.COMPLETION,
}:
override_model_configs = app_config.app_model_config_dict
# get conversation introduction
introduction = self._get_conversation_introduction(application_generate_entity)
# get conversation name
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query
if not conversation:
conversation = Conversation(
app_id=app_config.app_id,
app_model_config_id=app_model_config_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value,
name=conversation_name,
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
)
db.session.add(conversation)
db.session.commit()
db.session.refresh(conversation)
else:
conversation.updated_at = naive_utc_now()
db.session.commit()
message = Message(
app_id=app_config.app_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(message)
db.session.commit()
db.session.refresh(message)
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)
db.session.add(message_file)
db.session.commit()
return conversation, message
def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:
"""
Get conversation introduction
:param application_generate_entity: application generate entity
:return: conversation introduction
"""
app_config = application_generate_entity.app_config
introduction = app_config.additional_features.opening_statement
if introduction:
try:
inputs = application_generate_entity.inputs
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
return introduction or ""
def _get_conversation(self, conversation_id: str) -> Conversation:
"""
Get conversation by conversation id
:param conversation_id: conversation id
:return: conversation
"""
with Session(db.engine, expire_on_commit=False) as session:
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
return conversation
def _get_message(self, message_id: str) -> Message:
"""
Get message by message id
:param message_id: message id
:return: message
"""
with Session(db.engine, expire_on_commit=False) as session:
message = session.scalar(select(Message).where(Message.id == message_id))
if message is None:
raise MessageNotExistsError("Message not exists")
return message

View File

@@ -0,0 +1,47 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
)
class MessageBasedAppQueueManager(AppQueueManager):
def __init__(
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
):
super().__init__(task_id, user_id, invoke_from)
self._conversation_id = str(conversation_id)
self._app_mode = app_mode
self._message_id = str(message_id)
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = MessageQueueMessage(
task_id=self._task_id,
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event,
)
self._q.put(message)
if isinstance(
event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()

View File

@@ -0,0 +1,95 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
else:
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk

View File

@@ -0,0 +1,66 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow
class PipelineConfig(WorkflowUIBasedAppConfig):
"""
Pipeline Config Entity.
"""
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
pass
class PipelineConfigManager(BaseAppConfigManager):
@classmethod
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig:
pipeline_config = PipelineConfig(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
workflow=workflow, start_node_id=start_node_id
),
)
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for pipeline config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: only validate the structure of the config
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@@ -0,0 +1,824 @@
import contextvars
import datetime
import json
import logging
import secrets
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
)
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories.factory import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
class PipelineGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: str | None = None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# Add null check for dataset
with Session(db.engine, expire_on_commit=False) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
)
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents: list[Document] = []
if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
from services.dataset_service import DocumentService
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=dataset.id,
built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=dataset.chunk_structure,
)
db.session.add(document)
documents.append(document)
db.session.commit()
# run in child thread
rag_pipeline_invoke_entities = []
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = args.get("original_document_id") or None
if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
document_id = document_id or documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,
datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id,
input_data=dict(inputs),
pipeline_id=pipeline.id,
created_by=user.id,
)
db.session.add(document_pipeline_execution_log)
db.session.commit()
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=dataset.id,
original_document_id=args.get("original_document_id"),
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=pipeline_config.rag_pipeline_variables,
tenant_id=pipeline.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=[],
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
workflow_execution_id=workflow_run_id,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
if invoke_from == InvokeFrom.DEBUGGER or is_retry:
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
context=contextvars.copy_context(),
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
rag_pipeline_invoke_entities.append(
RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
user_id=user.id,
tenant_id=pipeline.tenant_id,
workflow_id=workflow.id,
streaming=streaming,
workflow_execution_id=workflow_run_id,
workflow_thread_pool_id=workflow_thread_pool_id,
application_generate_entity=application_generate_entity.model_dump(),
)
)
if rag_pipeline_invoke_entities:
RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
# return batch, dataset, documents
return {
"batch": batch,
"dataset": PipelineDataset(
id=dataset.id,
name=dataset.name,
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [
PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump()
for document in documents
],
}
def _generate(
self,
*,
flask_app: Flask,
context: contextvars.Context,
pipeline: Pipeline,
workflow_id: str,
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: str | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=AppMode.RAG_PIPELINE,
)
context = contextvars.copy_context()
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"queue_manager": queue_manager,
"application_generate_entity": application_generate_entity,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
user,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=streaming,
draft_var_saver_factory=draft_var_saver_factory,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
dataset_id=dataset.id,
batch=args.get("batch", ""),
document_id=args.get("document_id"),
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
workflow_execution_id=str(uuid.uuid4()),
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def single_loop_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
batch=args.get("batch", ""),
document_id=args.get("document_id"),
dataset_id=dataset.id,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
try:
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(
select(EndUser).where(EndUser.id == application_generate_entity.user_id)
)
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
# workflow app
runner = PipelineRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
runner.run()
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _handle_response(
self,
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
draft_var_saver_factory=draft_var_saver_factory,
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(
"Fails to process generate task pipeline, task_id: %r",
application_generate_entity.task_id,
)
raise e
def _build_document(
self,
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: str,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Union[Account, EndUser],
batch: str,
document_form: str,
):
if datasource_type == "local_file":
name = datasource_info.get("name", "untitled")
elif datasource_type == "online_document":
name = datasource_info.get("page", {}).get("page_name", "untitled")
elif datasource_type == "website_crawl":
name = datasource_info.get("title", "untitled")
elif datasource_type == "online_drive":
name = datasource_info.get("name", "untitled")
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type=datasource_type,
data_source_info=json.dumps(datasource_info),
batch=batch,
name=name,
created_from=created_from,
created_by=account.id,
doc_form=document_form,
)
doc_metadata = {}
if built_in_field_enabled:
doc_metadata = {
BuiltInField.document_name: name,
BuiltInField.uploader: account.name,
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.source: datasource_type,
}
if doc_metadata:
document.doc_metadata = doc_metadata
return document
def _format_datasource_info_list(
self,
datasource_type: str,
datasource_info_list: list[Mapping[str, Any]],
pipeline: Pipeline,
workflow: Workflow,
start_node_id: str,
user: Union[Account, EndUser],
) -> list[Mapping[str, Any]]:
"""
Format datasource info list.
"""
if datasource_type == "online_drive":
all_files: list[Mapping[str, Any]] = []
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == start_node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
credential_id=datasource_node_data.get("credential_id"),
)
if credentials:
datasource_runtime.runtime.credentials = credentials
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
for datasource_info in datasource_info_list:
if datasource_info.get("id") and datasource_info.get("type") == "folder":
# get all files in the folder
self._get_files_in_folder(
datasource_runtime,
datasource_info.get("id", ""),
datasource_info.get("bucket", None),
user.id,
all_files,
datasource_info,
None,
)
else:
all_files.append(
{
"id": datasource_info.get("id", ""),
"name": datasource_info.get("name", "untitled"),
"bucket": datasource_info.get("bucket", None),
}
)
return all_files
else:
return datasource_info_list
def _get_files_in_folder(
self,
datasource_runtime: OnlineDriveDatasourcePlugin,
prefix: str,
bucket: str | None,
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: dict | None = None,
):
"""
Get files in a folder.
"""
result_generator = datasource_runtime.online_drive_browse_files(
user_id=user_id,
request=OnlineDriveBrowseFilesRequest(
bucket=bucket,
prefix=prefix,
max_keys=20,
next_page_parameters=next_page_parameters,
),
provider_type=datasource_runtime.datasource_provider_type(),
)
is_truncated = False
for result in result_generator:
for files in result.result:
for file in files.files:
if file.type == "folder":
self._get_files_in_folder(
datasource_runtime,
file.id,
bucket,
user_id,
all_files,
datasource_info,
None,
)
else:
all_files.append(
{
"id": file.id,
"name": file.name,
"bucket": bucket,
}
)
is_truncated = files.is_truncated
next_page_parameters = files.next_page_parameters
if is_truncated:
self._get_files_in_folder(
datasource_runtime, prefix, bucket, user_id, all_files, datasource_info, next_page_parameters
)

View File

@@ -0,0 +1,45 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
class PipelineQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
self._q.put(message)
if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()

View File

@@ -0,0 +1,287 @@
import logging
import time
from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class PipelineRunner(WorkflowBasedAppRunner):
"""
Pipeline Application Runner
"""
def __init__(
self,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None:
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = SystemVariable(
files=files,
user_id=user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
document_id=self.application_generate_entity.document_id,
original_document_id=self.application_generate_entity.original_document_id,
batch=self.application_generate_entity.batch,
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=self.application_generate_entity.invoke_from.value,
)
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable.model_validate(v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")
) and rag_pipeline_variable.variable in inputs:
rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=inputs[rag_pipeline_variable.variable],
)
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_rag_pipeline_graph(
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
variable_pool=variable_pool,
)
self._queue_manager.graph_runtime_state = graph_runtime_state
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()
for event in generator:
self._update_document_status(
event, self.application_generate_entity.document_id, self.application_generate_entity.dataset_id
)
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
.first()
)
# return workflow
return workflow
def _init_rag_pipeline_graph(
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
) -> Graph:
"""
Init pipeline graph
"""
graph_config = workflow.graph_dict
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# nodes = graph_config.get("nodes", [])
# edges = graph_config.get("edges", [])
# real_run_nodes = []
# real_edges = []
# exclude_node_ids = []
# for node in nodes:
# node_id = node.get("id")
# node_type = node.get("data", {}).get("type", "")
# if node_type == "datasource":
# if start_node_id != node_id:
# exclude_node_ids.append(node_id)
# continue
# real_run_nodes.append(node)
# for edge in edges:
# if edge.get("source") in exclude_node_ids:
# continue
# real_edges.append(edge)
# graph_config = dict(graph_config)
# graph_config["nodes"] = real_run_nodes
# graph_config["edges"] = real_edges
# init graph
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
if not graph:
raise ValueError("graph not found in workflow")
return graph
def _update_document_status(self, event: GraphEngineEvent, document_id: str | None, dataset_id: str | None) -> None:
"""
Update document status
"""
if isinstance(event, GraphRunFailedEvent):
if document_id and dataset_id:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
document.error = event.error or "Unknown error"
db.session.add(document)
db.session.commit()

View File

@@ -0,0 +1,67 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.model import App, AppMode
from models.workflow import Workflow
class WorkflowAppConfig(WorkflowUIBasedAppConfig):
"""
Workflow App Config Entity.
"""
pass
class WorkflowAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = WorkflowAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
additional_features=cls.convert_features(features_dict, app_mode),
)
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
"""
Validate for workflow app model config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: only validate the structure of the config
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@@ -0,0 +1,574 @@
import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@staticmethod
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Generator[Mapping[str, Any] | str, None, None]: ...
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
)
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow,
)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id,
user_id=user.id if isinstance(user, Account) else user.session_id,
)
inputs: Mapping[str, Any] = args["inputs"]
extras = {
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
inputs=inputs,
files=list(system_files),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
trace_manager=trace_manager,
workflow_execution_id=workflow_run_id,
extras=extras,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if triggered_from is not None:
# Use explicitly provided triggered_from (for async triggers)
workflow_triggered_from = triggered_from
elif invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
)
def resume(self, *, workflow_run_id: str) -> None:
"""
@TBD
"""
pass
def _generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=app_model.mode,
)
# new thread with request context and contextvars
context = contextvars.copy_context()
# release database connection, because the following new thread operations may take a long time
db.session.close()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": context,
"variable_loader": variable_loader,
"root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": graph_engine_layers,
},
)
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
draft_var_saver_factory=draft_var_saver_factory,
stream=streaming,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
)
def single_loop_generate(
self,
app_model: App,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create repositories
#
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
)
try:
runner.run()
except GenerateTaskStoppedError as e:
logger.warning("Task stopped: %s", str(e))
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
def _handle_response(
self,
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
draft_var_saver_factory=draft_var_saver_factory,
stream=stream,
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(
"Fails to process generate task pipeline, task_id: %s", application_generate_entity.task_id
)
raise e

View File

@@ -0,0 +1,45 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str):
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
self._q.put(message)
if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()

View File

@@ -0,0 +1,153 @@
import logging
import time
from collections.abc import Sequence
from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
Workflow Application Runner
"""
def __init__(
self,
*,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
root_node_id: str | None = None,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
graph_engine_layers=graph_engine_layers,
)
self.application_generate_entity = application_generate_entity
self._workflow = workflow
self._sys_user_id = system_user_id
self._root_node_id = root_node_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def run(self):
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs
# Create a variable pool.
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
root_node_id=self._root_node_id,
)
# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
self._queue_manager.graph_runtime_state = graph_runtime_state
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)
generator = workflow_entry.run()
for event in generator:
self._handle_event(workflow_entry, event)

View File

@@ -0,0 +1,95 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return blocking_response.model_dump()
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@@ -0,0 +1,685 @@
import logging
import time
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Union
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom
logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
draft_var_saver_factory: DraftVariableSaverFactory,
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
)
if isinstance(user, EndUser):
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
else:
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_execution_id = ""
self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
self._workflow = workflow
self._workflow_system_variables = SystemVariable(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_execution_id,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=self._workflow_system_variables,
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
"""
To blocking response.
:return:
"""
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.id,
workflow_id=stream_response.data.workflow_id,
status=stream_response.data.status,
outputs=stream_response.data.outputs,
error=stream_response.data.error,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at),
),
)
return response
else:
continue
raise ValueError("queue listening stopped unexpectedly.")
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:
"""
To stream response.
:return:
"""
workflow_run_id = None
for stream_response in generator:
if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow_features_dict
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
)
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception:
logger.exception("Fails to get audio trunk, task_id: %s", task_id)
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@contextmanager
def _database_session(self):
"""Context manager for database sessions."""
with Session(db.engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_execution_id:
raise ValueError("workflow run not initialized.")
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline.ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
err = self._base_task_pipeline.handle_error(event=event)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
runtime_state = self._resolve_graph_runtime_state()
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow.id,
)
yield start_resp
def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle node retry events."""
self._ensure_workflow_initialized()
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
)
if response:
yield response
def _handle_node_started_event(
self, event: QueueNodeStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node started events."""
self._ensure_workflow_initialized()
node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
)
if node_start_response:
yield node_start_response
def _handle_node_succeeded_event(
self, event: QueueNodeSucceededEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
)
self._save_output_for_event(event, event.node_execution_id)
if node_success_response:
yield node_success_response
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, event.node_execution_id)
if node_failed_response:
yield node_failed_response
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle iteration start events."""
self._ensure_workflow_initialized()
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_start_resp
def _handle_iteration_next_event(
self, event: QueueIterationNextEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle iteration next events."""
self._ensure_workflow_initialized()
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_next_resp
def _handle_iteration_completed_event(
self, event: QueueIterationCompletedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle iteration completed events."""
self._ensure_workflow_initialized()
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_finish_resp
def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle loop start events."""
self._ensure_workflow_initialized()
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_start_resp
def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle loop next events."""
self._ensure_workflow_initialized()
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_next_resp
def _handle_loop_completed_event(
self, event: QueueLoopCompletedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle loop completed events."""
self._ensure_workflow_initialized()
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_finish_resp
def _handle_workflow_succeeded_event(
self,
event: QueueWorkflowSucceededEvent,
*,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.SUCCEEDED,
graph_runtime_state=validated_state,
)
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
def _handle_workflow_partial_success_event(
self,
event: QueueWorkflowPartialSuccessEvent,
*,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count,
)
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
def _handle_workflow_failed_and_stop_events(
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
*,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed and stop events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
if isinstance(event, QueueWorkflowFailedEvent):
status = WorkflowExecutionStatus.FAILED
error = event.error
exceptions_count = event.exceptions_count
else:
status = WorkflowExecutionStatus.STOPPED
error = event.get_stop_reason()
exceptions_count = 0
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=status,
graph_runtime_state=validated_state,
error=error,
exceptions_count=exceptions_count,
)
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
def _handle_text_chunk_event(
self,
event: QueueTextChunkEvent,
*,
tts_publisher: AppGeneratorTTSPublisher | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
delta_text = event.text
if delta_text is None:
return
# only publish tts message at text chunk streaming
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
yield self._workflow_response_converter.handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
def _get_event_handlers(self) -> dict[type, Callable]:
"""Get mapping of event types to their handlers using fluent pattern."""
return {
# Basic events
QueuePingEvent: self._handle_ping_event,
QueueErrorEvent: self._handle_error_event,
QueueTextChunkEvent: self._handle_text_chunk_event,
# Workflow events
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
QueueIterationCompletedEvent: self._handle_iteration_completed_event,
# Loop events
QueueLoopStartEvent: self._handle_loop_start_event,
QueueLoopNextEvent: self._handle_loop_next_event,
QueueLoopCompletedEvent: self._handle_loop_completed_event,
# Agent events
QueueAgentLogEvent: self._handle_agent_log_event,
}
def _dispatch_event(
self,
event: AppQueueEvent,
*,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
) -> Generator[StreamResponse, None, None]:
"""Dispatch events using elegant pattern matching."""
handlers = self._get_event_handlers()
event_type = type(event)
# Direct handler lookup
if handler := handlers.get(event_type):
yield from handler(
event,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle node failure events with isinstance check
if isinstance(
event,
(
QueueNodeFailedEvent,
QueueNodeExceptionEvent,
),
):
yield from self._handle_node_failed_events(
event,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle workflow failed and stop events with isinstance check
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(
event,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# For unhandled events, we continue (original behavior)
return
def _process_stream_response(
self,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 44-if-statement version.
"""
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
match event:
case QueueWorkflowStartedEvent():
self._resolve_graph_runtime_state()
yield from self._handle_workflow_started_event(event)
case QueueTextChunkEvent():
yield from self._handle_text_chunk_event(
event, tts_publisher=tts_publisher, queue_message=queue_message
)
case QueueErrorEvent():
yield from self._handle_error_event(event)
break
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
case QueueStopEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
# Handle all other events through elegant dispatch
case _:
if responses := list(
self._dispatch_event(
event,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
):
yield from responses
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
elif invoke_from == InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
elif invoke_from == InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
else:
# not save log for debugging
return
if not workflow_run_id:
return
workflow_app_log = WorkflowAppLog(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
workflow_id=self._workflow.id,
workflow_run_id=workflow_run_id,
created_from=created_from.value,
created_by_role=self._created_by_role,
created_by=self._user_id,
)
session.add(workflow_app_log)
session.commit()
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: list[str] | None = None
) -> TextChunkStreamResponse:
"""
Handle completed event.
:param text: text
:return:
"""
response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
)
return response
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
with Session(db.engine) as session, session.begin():
saver = self._draft_var_saver_factory(
session=session,
app_id=self._application_generate_entity.app_config.app_id,
node_id=event.node_id,
node_type=event.node_type,
node_execution_id=node_execution_id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
)
saver.save(event.process_data, event.outputs)

View File

@@ -0,0 +1,572 @@
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom
from models.workflow import Workflow
class WorkflowBasedAppRunner:
def __init__(
self,
*,
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
self._queue_manager = queue_manager
self._variable_loader = variable_loader
self._app_id = app_id
self._graph_engine_layers = graph_engine_layers
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
root_node_id: str | None = None,
) -> Graph:
"""
Init graph
"""
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=tenant_id or "",
app_id=self._app_id,
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
# Use the provided graph_runtime_state for consistent state management
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
if not graph:
raise ValueError("graph not found in workflow")
return graph
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
(either single iteration or single loop).
Args:
workflow: The workflow instance
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
Returns:
A tuple containing (graph, variable_pool, graph_runtime_state)
Raises:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
start_at=time.time(),
)
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]:
"""
Get graph and variable pool for single node execution (iteration or loop).
Args:
workflow: The workflow instance
node_id: The node ID to execute
user_inputs: User inputs for the node
graph_runtime_state: The graph runtime state
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
node_type_label: Label for error messages ('iteration' or 'loop')
Returns:
A tuple containing (graph, variable_pool)
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
target_node_config = node
break
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
# Get node class
node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = target_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=target_node_config
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
variable_loader=self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
return graph, variable_pool
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
def _get_graph_and_variable_pool_of_single_loop(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single loop
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
)
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
:param workflow_entry: workflow entry
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(QueueWorkflowStartedEvent())
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self._publish_event(
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
inputs=inputs,
process_data=process_data,
outputs=outputs,
error=event.error,
execution_metadata=execution_metadata,
retry_index=event.retry_index,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk,
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunAgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.message_id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
metadata=event.metadata,
node_id=event.node_id,
)
)
elif isinstance(event, NodeRunIterationStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
metadata=event.metadata,
)
)
elif isinstance(event, NodeRunIterationNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
)
)
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
)
)
elif isinstance(event, NodeRunLoopStartedEvent):
self._publish_event(
QueueLoopStartEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
metadata=event.metadata,
)
)
elif isinstance(event, NodeRunLoopNextEvent):
self._publish_event(
QueueLoopNextEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_loop_output,
)
)
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
self._publish_event(
QueueLoopCompletedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
)
)
def _publish_event(self, event: AppQueueEvent):
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)