dify
This commit is contained in:
0
dify/api/core/app/apps/chat/__init__.py
Normal file
0
dify/api/core/app/apps/chat/__init__.py
Normal file
148
dify/api/core/app/apps/chat/app_config_manager.py
Normal file
148
dify/api/core/app/apps/chat/app_config_manager.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.prompt_template.manager import PromptTemplateConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.variables.manager import BasicVariablesConfigManager
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
|
||||
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
|
||||
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
|
||||
from core.app.app_config.features.suggested_questions_after_answer.manager import (
|
||||
SuggestedQuestionsAfterAnswerConfigManager,
|
||||
)
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation
|
||||
|
||||
|
||||
class ChatAppConfig(EasyUIBasedAppConfig):
    """Configuration entity for a chat-mode (chatbot) application.

    A pure marker subclass: all fields come from :class:`EasyUIBasedAppConfig`.
    """

    pass
|
||||
|
||||
|
||||
class ChatAppConfigManager(BaseAppConfigManager):
    """Builds and validates the effective configuration for chat-mode apps."""

    @classmethod
    def get_app_config(
        cls,
        app_model: App,
        app_model_config: AppModelConfig,
        conversation: Conversation | None = None,
        override_config_dict: dict | None = None,
    ) -> ChatAppConfig:
        """
        Convert app model config to chat app config.

        :param app_model: app model
        :param app_model_config: stored app model config record
        :param conversation: conversation being resumed, if any
        :param override_config_dict: ad-hoc app model config dict (debugger override)
        :return: populated chat app config entity
        :raises ValueError: if the config source resolves to ARGS but no
            override dict was supplied
        """
        # Config source precedence: explicit override (debugger) >
        # conversation-pinned config > app's latest saved config.
        if override_config_dict:
            config_from = EasyUIBasedAppModelConfigFrom.ARGS
        elif conversation:
            config_from = EasyUIBasedAppModelConfigFrom.CONVERSATION_SPECIFIC_CONFIG
        else:
            config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG

        if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
            app_model_config_dict = app_model_config.to_dict()
            # Work on a copy so the converters below cannot mutate the stored config.
            config_dict = app_model_config_dict.copy()
        else:
            if not override_config_dict:
                # Specific exception type (was a bare Exception) so callers can
                # handle this precisely; still caught by `except Exception`.
                raise ValueError("override_config_dict is required when config_from is ARGS")

            config_dict = override_config_dict

        app_mode = AppMode.value_of(app_model.mode)
        app_config = ChatAppConfig(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_mode=app_mode,
            app_model_config_from=config_from,
            app_model_config_id=app_model_config.id,
            app_model_config_dict=config_dict,
            model=ModelConfigManager.convert(config=config_dict),
            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
            dataset=DatasetConfigManager.convert(config=config_dict),
            additional_features=cls.convert_features(config_dict, app_mode),
        )

        # Variables and external data tools are attached after construction
        # because a single convert() call yields both.
        app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
            config=config_dict
        )

        return app_config

    @classmethod
    def config_validate(cls, tenant_id: str, config: dict) -> dict:
        """
        Validate a chat app model config and fill in defaults.

        Each feature manager validates its own section and reports which
        top-level keys it owns; the result is filtered down to those keys.

        :param tenant_id: tenant id
        :param config: app model config args
        :return: validated config containing only keys owned by a validator
        """
        app_mode = AppMode.CHAT

        related_config_keys = []

        # model
        config, current_related_config_keys = ModelConfigManager.validate_and_set_defaults(tenant_id, config)
        related_config_keys.extend(current_related_config_keys)

        # user_input_form
        config, current_related_config_keys = BasicVariablesConfigManager.validate_and_set_defaults(tenant_id, config)
        related_config_keys.extend(current_related_config_keys)

        # file upload validation
        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # prompt
        config, current_related_config_keys = PromptTemplateConfigManager.validate_and_set_defaults(app_mode, config)
        related_config_keys.extend(current_related_config_keys)

        # dataset_query_variable
        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
            tenant_id, app_mode, config
        )
        related_config_keys.extend(current_related_config_keys)

        # opening_statement
        config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # suggested_questions_after_answer
        config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
            config
        )
        related_config_keys.extend(current_related_config_keys)

        # speech_to_text
        config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # text_to_speech
        config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # return retriever resource
        config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # moderation validation
        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id, config
        )
        related_config_keys.extend(current_related_config_keys)

        related_config_keys = list(set(related_config_keys))

        # Filter out extra parameters not claimed by any validator above.
        filtered_config = {key: config.get(key) for key in related_config_keys}

        return filtered_config
|
||||
255
dify/api/core/app/apps/chat/app_generator.py
Normal file
255
dify/api/core/app/apps/chat/app_generator.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, copy_current_request_context, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.chat.app_runner import ChatAppRunner
|
||||
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account
|
||||
from models.model import App, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatAppGenerator(MessageBasedAppGenerator):
    """Entry point for generating chat-app responses.

    Builds the effective app config, initializes conversation/message records,
    then runs the app in a background thread while the caller consumes either a
    blocking result or a streamed generator via the queue manager.
    """

    @overload
    def generate(
        self,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[True],
    ) -> Generator[Mapping | str, None, None]: ...

    @overload
    def generate(
        self,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[False],
    ) -> Mapping[str, Any]: ...

    @overload
    def generate(
        self,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool,
    ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...

    def generate(
        self,
        app_model: App,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
    ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
        """
        Generate App response.

        :param app_model: App
        :param user: account or end user
        :param args: request args (query, inputs, conversation_id, files, ...)
        :param invoke_from: invoke from source
        :param streaming: is stream
        :raises ValueError: on missing/invalid query, or model_config override
            outside debugger mode
        """
        if not args.get("query"):
            raise ValueError("query is required")

        query = args["query"]
        if not isinstance(query, str):
            raise ValueError("query must be a string")

        # Strip NUL bytes — they are invalid in text columns (e.g. Postgres).
        query = query.replace("\x00", "")
        inputs = args["inputs"]

        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}

        # get conversation (only when the caller is resuming an existing one)
        conversation = None
        conversation_id = args.get("conversation_id")
        if conversation_id:
            conversation = ConversationService.get_conversation(
                app_model=app_model, conversation_id=conversation_id, user=user
            )
        # get app model config (conversation-pinned config wins if resuming)
        app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

        # validate override model config
        override_model_config_dict = None
        if args.get("model_config"):
            # Overriding the model config is restricted to the app debugger.
            if invoke_from != InvokeFrom.DEBUGGER:
                raise ValueError("Only in App debug mode can override model config")

            # validate config
            override_model_config_dict = ChatAppConfigManager.config_validate(
                tenant_id=app_model.tenant_id, config=args.get("model_config", {})
            )

            # always enable retriever resource in debugger mode
            override_model_config_dict["retriever_resource"] = {"enabled": True}

        # parse files
        # TODO(QuantumGhost): Move file parsing logic to the API controller layer
        # for better separation of concerns.
        #
        # For implementation reference, see the `_parse_file` function and
        # `DraftWorkflowNodeRunApi` class which handle this properly.
        files = args["files"] if args.get("files") else []
        # File settings come from the override config when present, else the stored one.
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
        if file_extra_config:
            file_objs = file_factory.build_from_mappings(
                mappings=files,
                tenant_id=app_model.tenant_id,
                config=file_extra_config,
            )
        else:
            # No file-upload feature enabled: ignore any submitted files.
            file_objs = []

        # convert to app config
        app_config = ChatAppConfigManager.get_app_config(
            app_model=app_model,
            app_model_config=app_model_config,
            conversation=conversation,
            override_config_dict=override_model_config_dict,
        )

        # get tracing instance
        trace_manager = TraceQueueManager(
            app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
        )

        # init application generate entity
        application_generate_entity = ChatAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            model_conf=ModelConfigConverter.convert(app_config),
            file_upload_config=file_extra_config,
            conversation_id=conversation.id if conversation else None,
            inputs=self._prepare_user_inputs(
                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
            ),
            query=query,
            files=list(file_objs),
            # Service API calls never thread messages; use the nil UUID sentinel.
            parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
            user_id=user.id,
            invoke_from=invoke_from,
            extras=extras,
            trace_manager=trace_manager,
            stream=streaming,
        )

        # init generate records (creates conversation/message rows before the
        # worker starts, so the worker can look them up by id)
        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

        # init queue manager
        queue_manager = MessageBasedAppQueueManager(
            task_id=application_generate_entity.task_id,
            user_id=application_generate_entity.user_id,
            invoke_from=application_generate_entity.invoke_from,
            conversation_id=conversation.id,
            app_mode=conversation.mode,
            message_id=message.id,
        )

        # new thread with request context (Flask request context does not
        # propagate to threads by default)
        @copy_current_request_context
        def worker_with_context():
            return self._generate_worker(
                flask_app=current_app._get_current_object(),  # type: ignore
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                conversation_id=conversation.id,
                message_id=message.id,
            )

        worker_thread = threading.Thread(target=worker_with_context)

        worker_thread.start()

        # return response or stream generator (consumes events published by the
        # worker through the queue manager)
        response = self._handle_response(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            conversation=conversation,
            message=message,
            user=user,
            stream=streaming,
        )

        return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: ChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation_id: str,
        message_id: str,
    ):
        """
        Generate worker in a new thread.

        Errors are never raised out of the thread; each is published to the
        queue manager so the consumer side can surface it.

        :param flask_app: Flask app
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation_id: conversation ID
        :param message_id: message ID
        :return:
        """
        with flask_app.app_context():
            try:
                # get conversation and message (fresh loads — this thread has
                # its own DB session)
                conversation = self._get_conversation(conversation_id)
                message = self._get_message(message_id)

                # chatbot app
                runner = ChatAppRunner()
                runner.run(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    conversation=conversation,
                    message=message,
                )
            except GenerateTaskStoppedError:
                # Deliberate stop requested by the consumer; not an error.
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except ValueError as e:
                if dify_config.DEBUG:
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            finally:
                # Always release the thread's DB session.
                db.session.close()
|
||||
223
dify/api/core/app/apps/chat/app_runner.py
Normal file
223
dify/api/core/app/apps/chat/app_runner.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfig
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
ChatAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatAppRunner(AppRunner):
    """
    Chat Application Runner

    Executes a single chat turn: builds prompt messages, applies moderation and
    annotation replies, optionally retrieves dataset context, then invokes the
    LLM and forwards the result to the queue manager.
    """

    def run(
        self,
        application_generate_entity: ChatAppGenerateEntity,
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
    ):
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        :raises ValueError: if the app record no longer exists
        """
        app_config = application_generate_entity.app_config
        app_config = cast(ChatAppConfig, app_config)
        stmt = select(App).where(App.id == app_config.app_id)
        app_record = db.session.scalar(stmt)
        if not app_record:
            raise ValueError("App not found")

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Image detail level comes from the file-upload config when set;
        # defaults to LOW below.
        image_detail_config = (
            application_generate_entity.file_upload_config.image_config.detail
            if (
                application_generate_entity.file_upload_config
                and application_generate_entity.file_upload_config.image_config
            )
            else None
        )
        image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
                model=application_generate_entity.model_conf.model,
            )

            memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional)
        # First pass — needed so moderation/annotation replies below can emit a
        # direct output with valid prompt messages; rebuilt later with context.
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory,
            image_detail_config=image_detail_config,
        )

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=app_config.tenant_id,
                app_generate_entity=application_generate_entity,
                inputs=inputs,
                query=query,
                message_id=message.id,
            )
        except ModerationError as e:
            # Input was rejected: emit the moderation message and stop this turn.
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream,
            )
            return

        if query:
            # annotation reply — a matching curated answer short-circuits the LLM
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
            )

            if annotation_reply:
                queue_manager.publish(
                    QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
                    PublishFrom.APPLICATION_MANAGER,
                )

                self.direct_output(
                    queue_manager=queue_manager,
                    app_generate_entity=application_generate_entity,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream,
                )
                return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query,
            )

        # get context from datasets
        context = None
        if app_config.dataset and app_config.dataset.dataset_ids:
            hit_callback = DatasetIndexToolCallbackHandler(
                queue_manager,
                app_record.id,
                message.id,
                application_generate_entity.user_id,
                application_generate_entity.invoke_from,
            )

            dataset_retrieval = DatasetRetrieval(application_generate_entity)
            context = dataset_retrieval.retrieve(
                app_id=app_record.id,
                user_id=application_generate_entity.user_id,
                tenant_id=app_record.tenant_id,
                model_config=application_generate_entity.model_conf,
                config=app_config.dataset,
                query=query,
                invoke_from=application_generate_entity.invoke_from,
                show_retrieve_source=(
                    app_config.additional_features.show_retrieve_source if app_config.additional_features else False
                ),
                hit_callback=hit_callback,
                memory=memory,
                message_id=message.id,
                inputs=inputs,
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        # memory(optional), external data, dataset context(optional)
        prompt_messages, stop = self.organize_prompt_messages(
            app_record=app_record,
            model_config=application_generate_entity.model_conf,
            prompt_template_entity=app_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context,
            memory=memory,
            image_detail_config=image_detail_config,
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages,
        )

        if hosting_moderation_result:
            # Hosted moderation already produced the output; stop here.
            return

        # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
        self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
            model=application_generate_entity.model_conf.model,
        )

        # Release the DB session before the (potentially long) LLM call so the
        # connection is not held while waiting on the provider.
        db.session.close()

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=application_generate_entity.model_conf.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
        )
|
||||
122
dify/api/core/app/apps/chat/generate_response_converter.py
Normal file
122
dify/api/core/app/apps/chat/generate_response_converter.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
    """Converts chat app generate results into API response payloads.

    "Full" variants keep complete metadata; "simple" variants trim it via
    ``_get_simple_metadata`` for the public service API.
    """

    _blocking_response_type = ChatbotAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
        """
        Convert blocking full response.
        :param blocking_response: blocking response
        :return: flat dict payload with full metadata
        """
        response = {
            "event": "message",
            "task_id": blocking_response.task_id,
            "id": blocking_response.data.id,
            "message_id": blocking_response.data.message_id,
            "conversation_id": blocking_response.data.conversation_id,
            "mode": blocking_response.data.mode,
            "answer": blocking_response.data.answer,
            "metadata": blocking_response.data.metadata,
            "created_at": blocking_response.data.created_at,
        }

        return response

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):  # type: ignore[override]
        """
        Convert blocking simple response: full response with trimmed metadata.
        :param blocking_response: blocking response
        :return: flat dict payload with simplified metadata
        """
        response = cls.convert_blocking_full_response(blocking_response)

        metadata = response.get("metadata", {})
        if isinstance(metadata, dict):
            response["metadata"] = cls._get_simple_metadata(metadata)
        else:
            # Defensive: non-dict metadata cannot be simplified; drop it.
            response["metadata"] = {}

        return response

    @classmethod
    def convert_stream_full_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[dict | str, None, None]:
        """
        Convert stream full response.
        :param stream_response: stream response
        :return: generator of event dicts (or the literal "ping" keep-alive)
        """
        for chunk in stream_response:
            chunk = cast(ChatbotAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            # Keep-alive pings are forwarded as a bare string.
            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json"))
            yield response_chunk

    @classmethod
    def convert_stream_simple_response(
        cls, stream_response: Generator[AppStreamResponse, None, None]
    ) -> Generator[dict | str, None, None]:
        """
        Convert stream simple response (trimmed metadata on message-end events).
        :param stream_response: stream response
        :return: generator of event dicts (or the literal "ping" keep-alive)
        """
        for chunk in stream_response:
            chunk = cast(ChatbotAppStreamResponse, chunk)
            sub_stream_response = chunk.stream_response

            # Keep-alive pings are forwarded as a bare string.
            if isinstance(sub_stream_response, PingStreamResponse):
                yield "ping"
                continue

            response_chunk = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
            }

            # Bug fix: the original used two independent `if` statements here,
            # so a MessageEndStreamResponse also hit the trailing `else` and its
            # simplified metadata was clobbered by the full model_dump() update.
            # An if/elif/else chain makes the three cases mutually exclusive.
            if isinstance(sub_stream_response, MessageEndStreamResponse):
                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
            elif isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json"))

            yield response_chunk
|
||||
Reference in New Issue
Block a user