dify
dify/api/core/prompt/__init__.py (new file, 0 lines)
dify/api/core/prompt/advanced_prompt_transform.py (new file, 302 lines)
@@ -0,0 +1,302 @@
from collections.abc import Mapping, Sequence
from typing import cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageRole,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.runtime import VariablePool


class AdvancedPromptTransform(PromptTransform):
    """
    Advanced Prompt Transform for Workflow LLM Node.
    """

    def __init__(
        self,
        with_variable_tmpl: bool = False,
        image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
    ):
        self.with_variable_tmpl = with_variable_tmpl
        self.image_detail_config = image_detail_config

    def get_prompt(
        self,
        *,
        prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
        inputs: Mapping[str, str],
        query: str,
        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
        prompt_messages = []

        if isinstance(prompt_template, CompletionModelPromptTemplate):
            prompt_messages = self._get_completion_model_prompt_messages(
                prompt_template=prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory_config=memory_config,
                memory=memory,
                model_config=model_config,
                image_detail_config=image_detail_config,
            )
        elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
            prompt_messages = self._get_chat_model_prompt_messages(
                prompt_template=prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory_config=memory_config,
                memory=memory,
                model_config=model_config,
                image_detail_config=image_detail_config,
            )

        return prompt_messages

    def _get_completion_model_prompt_messages(
        self,
        prompt_template: CompletionModelPromptTemplate,
        inputs: Mapping[str, str],
        query: str | None,
        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
        """
        Get completion model prompt messages.
        """
        raw_prompt = prompt_template.text

        prompt_messages: list[PromptMessage] = []

        if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
            parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
            prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}

            prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)

            if memory and memory_config and memory_config.role_prefix:
                role_prefix = memory_config.role_prefix
                prompt_inputs = self._set_histories_variable(
                    memory=memory,
                    memory_config=memory_config,
                    raw_prompt=raw_prompt,
                    role_prefix=role_prefix,
                    parser=parser,
                    prompt_inputs=prompt_inputs,
                    model_config=model_config,
                )

            if query:
                prompt_inputs = self._set_query_variable(query, parser, prompt_inputs)

            prompt = parser.format(prompt_inputs)
        else:
            prompt = raw_prompt
            prompt_inputs = inputs

            prompt = Jinja2Formatter.format(prompt, prompt_inputs)

        if files:
            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            for file in files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
                )
            prompt_message_contents.append(TextPromptMessageContent(data=prompt))

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=prompt))

        return prompt_messages

    def _get_chat_model_prompt_messages(
        self,
        prompt_template: list[ChatModelMessage],
        inputs: Mapping[str, str],
        query: str | None,
        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
        """
        Get chat model prompt messages.
        """
        prompt_messages: list[PromptMessage] = []
        for prompt_item in prompt_template:
            raw_prompt = prompt_item.text

            if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
                if self.with_variable_tmpl:
                    vp = VariablePool.empty()
                    for k, v in inputs.items():
                        if k.startswith("#"):
                            vp.add(k[1:-1].split("."), v)
                    raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
                    prompt = vp.convert_template(raw_prompt).text
                else:
                    parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
                    prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
                    prompt_inputs = self._set_context_variable(
                        context=context, parser=parser, prompt_inputs=prompt_inputs
                    )
                    prompt = parser.format(prompt_inputs)
            elif prompt_item.edition_type == "jinja2":
                prompt = raw_prompt
                prompt_inputs = inputs
                prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs)
            else:
                raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")

            if prompt_item.role == PromptMessageRole.USER:
                prompt_messages.append(UserPromptMessage(content=prompt))
            elif prompt_item.role == PromptMessageRole.SYSTEM and prompt:
                prompt_messages.append(SystemPromptMessage(content=prompt))
            elif prompt_item.role == PromptMessageRole.ASSISTANT:
                prompt_messages.append(AssistantPromptMessage(content=prompt))

        if query and memory_config and memory_config.query_prompt_template:
            parser = PromptTemplateParser(
                template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
            )
            prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
            prompt_inputs["#sys.query#"] = query

            prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)

            query = parser.format(prompt_inputs)

        prompt_message_contents: list[PromptMessageContentUnionTypes] = []
        if memory and memory_config:
            prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)

            if files and query is not None:
                for file in files:
                    prompt_message_contents.append(
                        file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
                    )
                prompt_message_contents.append(TextPromptMessageContent(data=query))

                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
            else:
                prompt_messages.append(UserPromptMessage(content=query))
        elif files:
            if not query:
                # get last message
                last_message = prompt_messages[-1] if prompt_messages else None
                if last_message and last_message.role == PromptMessageRole.USER:
                    # get last user message content and add files
                    for file in files:
                        prompt_message_contents.append(
                            file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
                        )
                    prompt_message_contents.append(TextPromptMessageContent(data=cast(str, last_message.content)))

                    last_message.content = prompt_message_contents
                else:
                    for file in files:
                        prompt_message_contents.append(
                            file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
                        )
                    prompt_message_contents.append(TextPromptMessageContent(data=""))

                    prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
            else:
                for file in files:
                    prompt_message_contents.append(
                        file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
                    )
                prompt_message_contents.append(TextPromptMessageContent(data=query))

                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        elif query:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _set_context_variable(
        self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
    ) -> Mapping[str, str]:
        prompt_inputs = dict(prompt_inputs)
        if "#context#" in parser.variable_keys:
            if context:
                prompt_inputs["#context#"] = context
            else:
                prompt_inputs["#context#"] = ""

        return prompt_inputs

    def _set_query_variable(
        self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
    ) -> Mapping[str, str]:
        prompt_inputs = dict(prompt_inputs)
        if "#query#" in parser.variable_keys:
            if query:
                prompt_inputs["#query#"] = query
            else:
                prompt_inputs["#query#"] = ""

        return prompt_inputs

    def _set_histories_variable(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        raw_prompt: str,
        role_prefix: MemoryConfig.RolePrefix,
        parser: PromptTemplateParser,
        prompt_inputs: Mapping[str, str],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> Mapping[str, str]:
        prompt_inputs = dict(prompt_inputs)
        if "#histories#" in parser.variable_keys:
            if memory:
                inputs = {"#histories#": "", **prompt_inputs}
                parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
                prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
                tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))

                rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)

                histories = self._get_history_messages_from_memory(
                    memory=memory,
                    memory_config=memory_config,
                    max_token_limit=rest_tokens,
                    human_prefix=role_prefix.user,
                    ai_prefix=role_prefix.assistant,
                )
                prompt_inputs["#histories#"] = histories
            else:
                prompt_inputs["#histories#"] = ""

        return prompt_inputs
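For orientation, a standalone sketch (not part of the commit) of the "basic" vs "jinja2" edition_type branching used above. jinja2.Template stands in for the project's Jinja2Formatter, which normally renders through the code-executor helper, so this is illustrative only:

import re
from jinja2 import Template

# Same variable syntax that PromptTemplateParser accepts for "basic" templates.
BASIC_REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")

def render_message(text: str, inputs: dict, edition_type: str | None = None) -> str:
    if edition_type in (None, "basic"):
        # "basic": substitute {{var}} / {{#context#}} placeholders, leave unknown keys untouched
        return BASIC_REGEX.sub(lambda m: str(inputs.get(m.group(1), m.group(0))), text)
    if edition_type == "jinja2":
        # "jinja2": full Jinja2 rendering (stand-in for Jinja2Formatter.format)
        return Template(text).render(inputs)
    raise ValueError(f"Invalid edition type: {edition_type}")

print(render_message("Answer using {{#context#}}.", {"#context#": "retrieved docs"}))
print(render_message("{% for d in docs %}{{ d }}; {% endfor %}", {"docs": ["a", "b"]}, "jinja2"))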
dify/api/core/prompt/agent_history_prompt_transform.py (new file, 80 lines)
@@ -0,0 +1,80 @@
from typing import cast

from core.app.entities.app_invoke_entities import (
    ModelConfigWithCredentialsEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform


class AgentHistoryPromptTransform(PromptTransform):
    """
    History Prompt Transform for Agent App
    """

    def __init__(
        self,
        model_config: ModelConfigWithCredentialsEntity,
        prompt_messages: list[PromptMessage],
        history_messages: list[PromptMessage],
        memory: TokenBufferMemory | None = None,
    ):
        self.model_config = model_config
        self.prompt_messages = prompt_messages
        self.history_messages = history_messages
        self.memory = memory

    def get_prompt(self) -> list[PromptMessage]:
        prompt_messages: list[PromptMessage] = []
        num_system = 0
        for prompt_message in self.history_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                prompt_messages.append(prompt_message)
                num_system += 1

        if not self.memory:
            return prompt_messages

        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)

        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        curr_message_tokens = model_type_instance.get_num_tokens(
            self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
        )
        if curr_message_tokens <= max_token_limit:
            return self.history_messages

        # number of prompts appended for the current message group
        num_prompt = 0
        # append prompt messages in descending order (newest first)
        for prompt_message in self.history_messages[::-1]:
            if isinstance(prompt_message, SystemPromptMessage):
                continue
            prompt_messages.append(prompt_message)
            num_prompt += 1
            # a message group starts with a UserPromptMessage
            if isinstance(prompt_message, UserPromptMessage):
                curr_message_tokens = model_type_instance.get_num_tokens(
                    self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
                )
                # if the current message group overflows the token limit, drop all its prompts and stop
                if curr_message_tokens > max_token_limit:
                    prompt_messages = prompt_messages[:-num_prompt]
                    break
                num_prompt = 0
        # return prompt messages in ascending order
        message_prompts = prompt_messages[num_system:]
        message_prompts.reverse()

        # merge system prompts and message prompts
        prompt_messages = prompt_messages[:num_system]
        prompt_messages.extend(message_prompts)
        return prompt_messages
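The loop above walks history newest-to-oldest and drops a whole user-led turn once the token budget overflows. A self-contained sketch of the same idea (not part of the commit), with a naive whitespace counter standing in for the model's get_num_tokens:

def trim_history(history: list[tuple[str, str]], max_tokens: int) -> list[tuple[str, str]]:
    """history is a list of (role, text) pairs, oldest first; returns the newest turns that fit."""
    def count_tokens(msgs):
        # naive stand-in for LargeLanguageModel.get_num_tokens
        return sum(len(text.split()) for _, text in msgs)

    kept: list[tuple[str, str]] = []
    group = 0  # messages appended since the last completed user turn
    for role, text in reversed(history):
        kept.append((role, text))
        group += 1
        if role == "user":  # a turn starts with a user message, so check the budget here
            if count_tokens(kept) > max_tokens:
                kept = kept[:-group]  # drop the whole overflowing turn
                break
            group = 0
    kept.reverse()  # back to oldest-first order
    return kept

history = [
    ("user", "hi"),
    ("assistant", "hello there"),
    ("user", "summarize this long document please"),
    ("assistant", "sure, here is a summary ..."),
]
print(trim_history(history, max_tokens=12))  # keeps only the newest complete turn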
dify/api/core/prompt/entities/__init__.py (new file, 0 lines)
dify/api/core/prompt/entities/advanced_prompt_entities.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from typing import Literal

from pydantic import BaseModel

from core.model_runtime.entities.message_entities import PromptMessageRole


class ChatModelMessage(BaseModel):
    """
    Chat Message.
    """

    text: str
    role: PromptMessageRole
    edition_type: Literal["basic", "jinja2"] | None = None


class CompletionModelPromptTemplate(BaseModel):
    """
    Completion Model Prompt Template.
    """

    text: str
    edition_type: Literal["basic", "jinja2"] | None = None


class MemoryConfig(BaseModel):
    """
    Memory Config.
    """

    class RolePrefix(BaseModel):
        """
        Role Prefix.
        """

        user: str
        assistant: str

    class WindowConfig(BaseModel):
        """
        Window Config.
        """

        enabled: bool
        size: int | None = None

    role_prefix: RolePrefix | None = None
    window: WindowConfig
    query_prompt_template: str | None = None
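A short usage sketch (assuming the api package is on the import path; the concrete values are illustrative) showing how an advanced prompt template and its memory settings map onto these models:

from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, MemoryConfig

# Two-message chat template: a system message rendered with the "basic" syntax and a user slot.
template = [
    ChatModelMessage(role=PromptMessageRole.SYSTEM, text="Answer using {{#context#}}.", edition_type="basic"),
    ChatModelMessage(role=PromptMessageRole.USER, text="{{#sys.query#}}"),
]

# Memory with role prefixes and a 10-message window.
memory_config = MemoryConfig(
    role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
    window=MemoryConfig.WindowConfig(enabled=True, size=10),
    query_prompt_template="{{#sys.query#}}",
)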
dify/api/core/prompt/prompt_templates/__init__.py (new file, 0 lines)
dify/api/core/prompt/prompt_templates/advanced_prompt_templates.py (new file, 45 lines)
@@ -0,0 +1,45 @@
CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n"  # noqa: E501

BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"  # noqa: E501

CHAT_APP_COMPLETION_PROMPT_CONFIG = {
    "completion_prompt_config": {
        "prompt": {
            "text": "{{#pre_prompt#}}\nHere are the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "  # noqa: E501
        },
        "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
    },
    "stop": ["Human:"],
}

CHAT_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}}

COMPLETION_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}}

COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
    "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
    "stop": ["Human:"],
}

BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
    "completion_prompt_config": {
        "prompt": {
            "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"  # noqa: E501
        },
        "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
    },
    "stop": ["用户:"],
}

BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
    "chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}
}

BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
    "chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}
}

BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
    "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
    "stop": ["用户:"],
}
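The {{#pre_prompt#}}, {{#histories#}} and {{#query#}} placeholders in these configs are filled by the prompt transforms before the text reaches the model. A rough standalone illustration of that substitution (plain str.replace standing in for the real template machinery):

completion_text = (
    "{{#pre_prompt#}}\nHere are the chat histories between human and assistant, "
    "inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\n"
    "Human: {{#query#}}\n\nAssistant: "
)
filled = (
    completion_text.replace("{{#pre_prompt#}}", "You are a helpful support bot.")
    .replace("{{#histories#}}", "Human: hi\nAssistant: hello")
    .replace("{{#query#}}", "Where is my order?")
)
print(filled)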
dify/api/core/prompt/prompt_templates/baichuan_chat.json (new file, 13 lines)
@@ -0,0 +1,13 @@
{
    "human_prefix": "用户",
    "assistant_prefix": "助手",
    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n",
    "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n",
    "system_prompt_orders": [
        "context_prompt",
        "pre_prompt",
        "histories_prompt"
    ],
    "query_prompt": "\n\n用户:{{#query#}}",
    "stops": ["用户:"]
}
dify/api/core/prompt/prompt_templates/baichuan_completion.json (new file, 9 lines)
@@ -0,0 +1,9 @@
{
    "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n",
    "system_prompt_orders": [
        "context_prompt",
        "pre_prompt"
    ],
    "query_prompt": "{{#query#}}",
    "stops": null
}
dify/api/core/prompt/prompt_templates/common_chat.json (new file, 13 lines)
@@ -0,0 +1,13 @@
{
    "human_prefix": "Human",
    "assistant_prefix": "Assistant",
    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
    "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n",
    "system_prompt_orders": [
        "context_prompt",
        "pre_prompt",
        "histories_prompt"
    ],
    "query_prompt": "\n\nHuman: {{#query#}}\n\nAssistant: ",
    "stops": ["\nHuman:", "</histories>"]
}
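A standalone sketch (not part of this file) of how a rule file like this is assembled: the sections named in system_prompt_orders are concatenated when they apply, then query_prompt is appended. The rules dict is abbreviated inline here; the real code loads the JSON via _get_prompt_rule in simple_prompt_transform.py below:

rules = {
    "context_prompt": "Use the following context ...\n<context>\n{{#context#}}\n</context>\n\n",
    "histories_prompt": "<histories>\n{{#histories#}}\n</histories>\n\n",
    "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    "query_prompt": "\n\nHuman: {{#query#}}\n\nAssistant: ",
}

def assemble(rules: dict, pre_prompt: str, has_context: bool, with_memory: bool, query_in_prompt: bool) -> str:
    prompt = ""
    for order in rules["system_prompt_orders"]:
        if order == "context_prompt" and has_context:
            prompt += rules["context_prompt"]
        elif order == "pre_prompt" and pre_prompt:
            prompt += pre_prompt + "\n"
        elif order == "histories_prompt" and with_memory:
            prompt += rules["histories_prompt"]
    if query_in_prompt:
        prompt += rules.get("query_prompt", "{{#query#}}")
    return prompt

print(assemble(rules, pre_prompt="You are a helpful assistant.", has_context=True, with_memory=True, query_in_prompt=True))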
dify/api/core/prompt/prompt_templates/common_completion.json (new file, 9 lines)
@@ -0,0 +1,9 @@
{
    "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
    "system_prompt_orders": [
        "context_prompt",
        "pre_prompt"
    ],
    "query_prompt": "{{#query#}}",
    "stops": null
}
dify/api/core/prompt/prompt_transform.py (new file, 90 lines)
@@ -0,0 +1,90 @@
from typing import Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.prompt.entities.advanced_prompt_entities import MemoryConfig


class PromptTransform:
    def _append_chat_histories(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        prompt_messages: list[PromptMessage],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> list[PromptMessage]:
        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
        prompt_messages.extend(histories)

        return prompt_messages

    def _calculate_rest_token(
        self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
    ) -> int:
        rest_tokens = 2000

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )

            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_config.parameters.get(parameter_rule.name)
                        or model_config.parameters.get(parameter_rule.use_template or "")
                    ) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)

        return rest_tokens

    def _get_history_messages_from_memory(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        max_token_limit: int,
        human_prefix: str | None = None,
        ai_prefix: str | None = None,
    ) -> str:
        """Get memory messages."""
        kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}

        if human_prefix:
            kwargs["human_prefix"] = human_prefix

        if ai_prefix:
            kwargs["ai_prefix"] = ai_prefix

        if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
            kwargs["message_limit"] = memory_config.window.size

        return memory.get_history_prompt_text(**kwargs)

    def _get_history_messages_list_from_memory(
        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
    ) -> list[PromptMessage]:
        """Get memory messages."""
        return list(
            memory.get_history_prompt_messages(
                max_token_limit=max_token_limit,
                message_limit=memory_config.window.size
                if (
                    memory_config.window.enabled
                    and memory_config.window.size is not None
                    and memory_config.window.size > 0
                )
                else None,
            )
        )
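A worked example of the _calculate_rest_token budget above, with illustrative numbers: an 8192-token context window, max_tokens of 1024, and a current prompt measuring 3200 tokens leave 8192 - 1024 - 3200 = 3968 tokens for history.

def calculate_rest_tokens(context_size: int, max_tokens: int, current_prompt_tokens: int) -> int:
    # Mirrors the arithmetic in PromptTransform._calculate_rest_token; the numbers below are examples.
    return max(context_size - max_tokens - current_prompt_tokens, 0)

assert calculate_rest_tokens(8192, 1024, 3200) == 3968
assert calculate_rest_tokens(4096, 1024, 4000) == 0  # never negative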
dify/api/core/prompt/simple_prompt_transform.py (new file, 347 lines)
@@ -0,0 +1,347 @@
import json
import os
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, cast

from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentUnionTypes,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode

if TYPE_CHECKING:
    from core.file.models import File


class ModelMode(StrEnum):
    COMPLETION = auto()
    CHAT = auto()


prompt_file_contents: dict[str, Any] = {}


class SimplePromptTransform(PromptTransform):
    """
    Simple Prompt Transform for Chatbot App Basic Mode.
    """

    def get_prompt(
        self,
        app_mode: AppMode,
        prompt_template_entity: PromptTemplateEntity,
        inputs: Mapping[str, str],
        query: str,
        files: Sequence["File"],
        context: str | None,
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> tuple[list[PromptMessage], list[str] | None]:
        inputs = {key: str(value) for key, value in inputs.items()}

        model_mode = ModelMode(model_config.mode)
        if model_mode == ModelMode.CHAT:
            prompt_messages, stops = self._get_chat_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template or "",
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
                image_detail_config=image_detail_config,
            )
        else:
            prompt_messages, stops = self._get_completion_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template or "",
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
                image_detail_config=image_detail_config,
            )

        return prompt_messages, stops

    def _get_prompt_str_and_rules(
        self,
        app_mode: AppMode,
        model_config: ModelConfigWithCredentialsEntity,
        pre_prompt: str,
        inputs: dict,
        query: str | None = None,
        context: str | None = None,
        histories: str | None = None,
    ) -> tuple[str, dict]:
        # get prompt template
        prompt_template_config = self.get_prompt_template(
            app_mode=app_mode,
            provider=model_config.provider,
            model=model_config.model,
            pre_prompt=pre_prompt,
            has_context=context is not None,
            query_in_prompt=query is not None,
            with_memory_prompt=histories is not None,
        )

        custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
        special_variable_keys_obj = prompt_template_config["special_variable_keys"]

        # Type check for custom_variable_keys
        if not isinstance(custom_variable_keys_obj, list):
            raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
        custom_variable_keys = cast(list[str], custom_variable_keys_obj)

        # Type check for special_variable_keys
        if not isinstance(special_variable_keys_obj, list):
            raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
        special_variable_keys = cast(list[str], special_variable_keys_obj)

        variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}

        for v in special_variable_keys:
            # support #context#, #query# and #histories#
            if v == "#context#":
                variables["#context#"] = context or ""
            elif v == "#query#":
                variables["#query#"] = query or ""
            elif v == "#histories#":
                variables["#histories#"] = histories or ""

        prompt_template = prompt_template_config["prompt_template"]
        if not isinstance(prompt_template, PromptTemplateParser):
            raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template)}")

        prompt = prompt_template.format(variables)

        prompt_rules = prompt_template_config["prompt_rules"]
        if not isinstance(prompt_rules, dict):
            raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")

        return prompt, prompt_rules

    def get_prompt_template(
        self,
        app_mode: AppMode,
        provider: str,
        model: str,
        pre_prompt: str,
        has_context: bool,
        query_in_prompt: bool,
        with_memory_prompt: bool = False,
    ) -> dict[str, object]:
        prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)

        custom_variable_keys: list[str] = []
        special_variable_keys: list[str] = []

        prompt = ""
        for order in prompt_rules["system_prompt_orders"]:
            if order == "context_prompt" and has_context:
                prompt += prompt_rules["context_prompt"]
                special_variable_keys.append("#context#")
            elif order == "pre_prompt" and pre_prompt:
                prompt += pre_prompt + "\n"
                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
                custom_variable_keys = pre_prompt_template.variable_keys
            elif order == "histories_prompt" and with_memory_prompt:
                prompt += prompt_rules["histories_prompt"]
                special_variable_keys.append("#histories#")

        if query_in_prompt:
            prompt += prompt_rules.get("query_prompt", "{{#query#}}")
            special_variable_keys.append("#query#")

        return {
            "prompt_template": PromptTemplateParser(template=prompt),
            "custom_variable_keys": custom_variable_keys,
            "special_variable_keys": special_variable_keys,
            "prompt_rules": prompt_rules,
        }

    def _get_chat_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: str | None,
        files: Sequence["File"],
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> tuple[list[PromptMessage], list[str] | None]:
        prompt_messages: list[PromptMessage] = []

        # get prompt
        prompt, _ = self._get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=None,
            context=context,
        )

        if prompt and query:
            prompt_messages.append(SystemPromptMessage(content=prompt))

        if memory:
            prompt_messages = self._append_chat_histories(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                prompt_messages=prompt_messages,
                model_config=model_config,
            )

        if query:
            prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
        else:
            prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))

        return prompt_messages, None

    def _get_completion_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: str | None,
        files: Sequence["File"],
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> tuple[list[PromptMessage], list[str] | None]:
        # get prompt
        prompt, prompt_rules = self._get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=query,
            context=context,
        )

        if memory:
            tmp_human_message = UserPromptMessage(content=prompt)

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            histories = self._get_history_messages_from_memory(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                max_token_limit=rest_tokens,
                human_prefix=prompt_rules.get("human_prefix", "Human"),
                ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
            )

            # get prompt
            prompt, prompt_rules = self._get_prompt_str_and_rules(
                app_mode=app_mode,
                model_config=model_config,
                pre_prompt=pre_prompt,
                inputs=inputs,
                query=query,
                context=context,
                histories=histories,
            )

        stops = prompt_rules.get("stops")
        if stops is not None and len(stops) == 0:
            stops = None

        return [self._get_last_user_message(prompt, files, image_detail_config)], stops

    def _get_last_user_message(
        self,
        prompt: str,
        files: Sequence["File"],
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> UserPromptMessage:
        if files:
            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            for file in files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
                )
            prompt_message_contents.append(TextPromptMessageContent(data=prompt))

            prompt_message = UserPromptMessage(content=prompt_message_contents)
        else:
            prompt_message = UserPromptMessage(content=prompt)

        return prompt_message

    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str):
        """
        Get simple prompt rule.
        :param app_mode: app mode
        :param provider: model provider
        :param model: model name
        :return:
        """
        prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model)

        # Check if the prompt file is already loaded
        if prompt_file_name in prompt_file_contents:
            return cast(dict, prompt_file_contents[prompt_file_name])

        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
        json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json")

        # Open the JSON file and read its content
        with open(json_file_path, encoding="utf-8") as json_file:
            content = json.load(json_file)

        # Store the content of the prompt file
        prompt_file_contents[prompt_file_name] = content

        return cast(dict, content)

    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
        # baichuan
        is_baichuan = False
        if provider == "baichuan":
            is_baichuan = True
        else:
            baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
            if provider in baichuan_supported_providers and "baichuan" in model.lower():
                is_baichuan = True

        if is_baichuan:
            if app_mode == AppMode.COMPLETION:
                return "baichuan_completion"
            else:
                return "baichuan_chat"

        # common
        if app_mode == AppMode.COMPLETION:
            return "common_completion"
        else:
            return "common_chat"
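A condensed standalone restatement of the _prompt_file_name selection above, handy for checking which rule file a provider/model/mode combination loads (plain strings stand in for AppMode here):

def prompt_file_name(app_mode: str, provider: str, model: str) -> str:
    # Baichuan models get their own rule files, whether served natively or via a generic provider.
    is_baichuan = provider == "baichuan" or (
        provider in ("huggingface_hub", "openllm", "xinference") and "baichuan" in model.lower()
    )
    family = "baichuan" if is_baichuan else "common"
    kind = "completion" if app_mode == "completion" else "chat"
    return f"{family}_{kind}"

assert prompt_file_name("chat", "openai", "gpt-4o") == "common_chat"
assert prompt_file_name("completion", "xinference", "Baichuan2-13B") == "baichuan_completion"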
dify/api/core/prompt/utils/__init__.py (new file, 0 lines)
dify/api/core/prompt/utils/extract_thread_messages.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from collections.abc import Sequence

from constants import UUID_NIL
from models import Message


def extract_thread_messages(messages: Sequence[Message]):
    thread_messages: list[Message] = []
    next_message = None

    for message in messages:
        if not message.parent_message_id:
            # If the message is regenerated and does not have a parent message, it is the start of a new thread
            thread_messages.append(message)
            break

        if not next_message:
            thread_messages.append(message)
            next_message = message.parent_message_id
        else:
            if next_message in {message.id, UUID_NIL}:
                thread_messages.append(message)
                next_message = message.parent_message_id

    return thread_messages
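A self-contained walk-through of the thread extraction above, using plain dicts in place of the Message model; UUID_NIL is assumed here to be the conventional all-zero UUID marking the first message of a conversation:

UUID_NIL = "00000000-0000-0000-0000-000000000000"  # assumed value, for illustration

def extract_thread(messages: list[dict]) -> list[dict]:
    # messages are ordered newest-first, as in get_thread_messages_length below
    thread, next_id = [], None
    for m in messages:
        if not m["parent_message_id"]:
            thread.append(m)  # a regenerated message without a parent starts a new thread
            break
        if next_id is None:
            thread.append(m)
            next_id = m["parent_message_id"]
        elif next_id in {m["id"], UUID_NIL}:
            thread.append(m)
            next_id = m["parent_message_id"]
    return thread

msgs = [
    {"id": "m4", "parent_message_id": "m2"},   # newest: regenerated answer branching from m2
    {"id": "m3", "parent_message_id": "m2"},   # sibling branch that gets skipped
    {"id": "m2", "parent_message_id": "m1"},
    {"id": "m1", "parent_message_id": UUID_NIL},
]
print([m["id"] for m in extract_thread(msgs)])  # -> ['m4', 'm2', 'm1']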
dify/api/core/prompt/utils/get_thread_messages_length.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from sqlalchemy import select

from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message


def get_thread_messages_length(conversation_id: str) -> int:
    """
    Get the number of thread messages based on the parent message id.
    """
    # Fetch all messages related to the conversation
    stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc())

    messages = db.session.scalars(stmt).all()

    # Extract thread messages
    thread_messages = extract_thread_messages(messages)

    # Exclude the newly created message with an empty answer
    if thread_messages and not thread_messages[0].answer:
        thread_messages.pop(0)

    return len(thread_messages)
dify/api/core/prompt/utils/prompt_message_util.py (new file, 113 lines)
@@ -0,0 +1,113 @@
from collections.abc import Sequence
from typing import Any, cast

from core.model_runtime.entities import (
    AssistantPromptMessage,
    AudioPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageRole,
    TextPromptMessageContent,
)
from core.prompt.simple_prompt_transform import ModelMode


class PromptMessageUtil:
    @staticmethod
    def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]):
        """
        Prompt messages to prompt for saving.
        :param model_mode: model mode
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if model_mode == ModelMode.CHAT:
            tool_calls = []
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = "user"
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = "assistant"
                    if isinstance(prompt_message, AssistantPromptMessage):
                        tool_calls = [
                            {
                                "id": tool_call.id,
                                "type": "function",
                                "function": {
                                    "name": tool_call.function.name,
                                    "arguments": tool_call.function.arguments,
                                },
                            }
                            for tool_call in prompt_message.tool_calls
                        ]
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = "system"
                elif prompt_message.role == PromptMessageRole.TOOL:
                    role = "tool"
                else:
                    continue

                text = ""
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if isinstance(content, TextPromptMessageContent):
                            text += content.data
                        elif isinstance(content, ImagePromptMessageContent):
                            files.append(
                                {
                                    "type": "image",
                                    "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
                                    "detail": content.detail.value,
                                }
                            )
                        elif isinstance(content, AudioPromptMessageContent):
                            files.append(
                                {
                                    "type": "audio",
                                    "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
                                    "format": content.format,
                                }
                            )
                else:
                    text = cast(str, prompt_message.content)

                prompt = {"role": role, "text": text, "files": files}

                if tool_calls:
                    prompt["tool_calls"] = tool_calls

                prompts.append(prompt)
        else:
            prompt_message = prompt_messages[0]
            text = ""
            files = []
            if isinstance(prompt_message.content, list):
                for content in prompt_message.content:
                    if content.type == PromptMessageContentType.TEXT:
                        text += content.data
                    else:
                        content = cast(ImagePromptMessageContent, content)
                        files.append(
                            {
                                "type": "image",
                                "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
                                "detail": content.detail.value,
                            }
                        )
            else:
                text = cast(str, prompt_message.content)

            params: dict[str, Any] = {
                "role": "user",
                "text": text,
            }

            if files:
                params["files"] = files

            prompts.append(params)

        return prompts
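For reference, a hedged sketch of the list this utility returns for a chat-mode exchange; the dict shape follows the code above, while the message values and truncated image data are invented:

# Illustrative shape of prompt_messages_to_prompt_for_saving("chat", ...) output:
saved_prompts = [
    {"role": "system", "text": "You are a helpful assistant.", "files": []},
    {"role": "user", "text": "What is in this picture?", "files": [
        {"type": "image", "data": "iVBORw0KGg...[TRUNCATED]...AAAASUVORK", "detail": "low"},
    ]},
    {"role": "assistant", "text": "A cat on a sofa.", "files": [], "tool_calls": [
        {"id": "call_1", "type": "function", "function": {"name": "search", "arguments": "{}"}},
    ]},
]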
dify/api/core/prompt/utils/prompt_template_parser.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import re
from collections.abc import Mapping

REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")
WITH_VARIABLE_TMPL_REGEX = re.compile(
    r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#[a-zA-Z0-9_]{1,50}\.[a-zA-Z0-9_\.]{1,100}#|#histories#|#query#|#context#)\}\}"
)


class PromptTemplateParser:
    """
    Rules:

    1. Template variables must be enclosed in `{{}}`.
    2. The template variable Key can only contain letters, numbers and underscores, with a maximum length of
       30 characters, and can only start with a letter or an underscore.
    3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
    4. In addition to the above, 3 types of special template variable Keys are accepted:
       `{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
    """

    def __init__(self, template: str, with_variable_tmpl: bool = False):
        self.template = template
        self.with_variable_tmpl = with_variable_tmpl
        self.regex = WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX
        self.variable_keys = self.extract()

    def extract(self):
        # Regular expression to match the template rules
        return re.findall(self.regex, self.template)

    def format(self, inputs: Mapping[str, str], remove_template_variables: bool = True) -> str:
        def replacer(match):
            key = match.group(1)
            value = inputs.get(key, match.group(0))  # return original matched string if key not found

            if remove_template_variables and isinstance(value, str):
                return PromptTemplateParser.remove_template_variables(value, self.with_variable_tmpl)
            return value

        prompt = re.sub(self.regex, replacer, self.template)
        return re.sub(r"<\|.*?\|>", "", prompt)

    @classmethod
    def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False):
        return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r"{\1}", text)
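A standalone demonstration of the default rules (the pattern is copied from REGEX above; this simplified render skips the nested remove_template_variables step that format() applies to substituted values):

import re

REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")

def render(template: str, inputs: dict[str, str]) -> str:
    # unknown keys are left as-is, then any <|...|> sequences are stripped, as in format()
    out = REGEX.sub(lambda m: inputs.get(m.group(1), m.group(0)), template)
    return re.sub(r"<\|.*?\|>", "", out)

print(render("Hello {{name}}, context: {{#context#}} {{missing}}", {"name": "Ada", "#context#": "docs"}))
# -> Hello Ada, context: docs {{missing}}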