dify
This commit is contained in:
0
dify/api/core/prompt/utils/__init__.py
Normal file
0
dify/api/core/prompt/utils/__init__.py
Normal file
25
dify/api/core/prompt/utils/extract_thread_messages.py
Normal file
25
dify/api/core/prompt/utils/extract_thread_messages.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from constants import UUID_NIL
|
||||
from models import Message
|
||||
|
||||
|
||||
def extract_thread_messages(messages: Sequence[Message]):
|
||||
thread_messages: list[Message] = []
|
||||
next_message = None
|
||||
|
||||
for message in messages:
|
||||
if not message.parent_message_id:
|
||||
# If the message is regenerated and does not have a parent message, it is the start of a new thread
|
||||
thread_messages.append(message)
|
||||
break
|
||||
|
||||
if not next_message:
|
||||
thread_messages.append(message)
|
||||
next_message = message.parent_message_id
|
||||
else:
|
||||
if next_message in {message.id, UUID_NIL}:
|
||||
thread_messages.append(message)
|
||||
next_message = message.parent_message_id
|
||||
|
||||
return thread_messages
|
||||
24
dify/api/core/prompt/utils/get_thread_messages_length.py
Normal file
24
dify/api/core/prompt/utils/get_thread_messages_length.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
|
||||
|
||||
def get_thread_messages_length(conversation_id: str) -> int:
|
||||
"""
|
||||
Get the number of thread messages based on the parent message id.
|
||||
"""
|
||||
# Fetch all messages related to the conversation
|
||||
stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc())
|
||||
|
||||
messages = db.session.scalars(stmt).all()
|
||||
|
||||
# Extract thread messages
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# Exclude the newly created message with an empty answer
|
||||
if thread_messages and not thread_messages[0].answer:
|
||||
thread_messages.pop(0)
|
||||
|
||||
return len(thread_messages)
|
||||
113
dify/api/core/prompt/utils/prompt_message_util.py
Normal file
113
dify/api/core/prompt/utils/prompt_message_util.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
AudioPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
|
||||
|
||||
class PromptMessageUtil:
|
||||
@staticmethod
|
||||
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]):
|
||||
"""
|
||||
Prompt messages to prompt for saving.
|
||||
:param model_mode: model mode
|
||||
:param prompt_messages: prompt messages
|
||||
:return:
|
||||
"""
|
||||
prompts = []
|
||||
if model_mode == ModelMode.CHAT:
|
||||
tool_calls = []
|
||||
for prompt_message in prompt_messages:
|
||||
if prompt_message.role == PromptMessageRole.USER:
|
||||
role = "user"
|
||||
elif prompt_message.role == PromptMessageRole.ASSISTANT:
|
||||
role = "assistant"
|
||||
if isinstance(prompt_message, AssistantPromptMessage):
|
||||
tool_calls = [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
}
|
||||
for tool_call in prompt_message.tool_calls
|
||||
]
|
||||
elif prompt_message.role == PromptMessageRole.SYSTEM:
|
||||
role = "system"
|
||||
elif prompt_message.role == PromptMessageRole.TOOL:
|
||||
role = "tool"
|
||||
else:
|
||||
continue
|
||||
|
||||
text = ""
|
||||
files = []
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
files.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
|
||||
"detail": content.detail.value,
|
||||
}
|
||||
)
|
||||
elif isinstance(content, AudioPromptMessageContent):
|
||||
files.append(
|
||||
{
|
||||
"type": "audio",
|
||||
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
|
||||
"format": content.format,
|
||||
}
|
||||
)
|
||||
else:
|
||||
text = cast(str, prompt_message.content)
|
||||
|
||||
prompt = {"role": role, "text": text, "files": files}
|
||||
|
||||
if tool_calls:
|
||||
prompt["tool_calls"] = tool_calls
|
||||
|
||||
prompts.append(prompt)
|
||||
else:
|
||||
prompt_message = prompt_messages[0]
|
||||
text = ""
|
||||
files = []
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
if content.type == PromptMessageContentType.TEXT:
|
||||
text += content.data
|
||||
else:
|
||||
content = cast(ImagePromptMessageContent, content)
|
||||
files.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
|
||||
"detail": content.detail.value,
|
||||
}
|
||||
)
|
||||
else:
|
||||
text = cast(str, prompt_message.content)
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"role": "user",
|
||||
"text": text,
|
||||
}
|
||||
|
||||
if files:
|
||||
params["files"] = files
|
||||
|
||||
prompts.append(params)
|
||||
|
||||
return prompts
|
||||
46
dify/api/core/prompt/utils/prompt_template_parser.py
Normal file
46
dify/api/core/prompt/utils/prompt_template_parser.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
|
||||
REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")
|
||||
WITH_VARIABLE_TMPL_REGEX = re.compile(
|
||||
r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#[a-zA-Z0-9_]{1,50}\.[a-zA-Z0-9_\.]{1,100}#|#histories#|#query#|#context#)\}\}"
|
||||
)
|
||||
|
||||
|
||||
class PromptTemplateParser:
|
||||
"""
|
||||
Rules:
|
||||
|
||||
1. Template variables must be enclosed in `{{}}`.
|
||||
2. The template variable Key can only be: letters + numbers + underscore, with a maximum length of 16 characters,
|
||||
and can only start with letters and underscores.
|
||||
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
|
||||
4. In addition to the above, 3 types of special template variable Keys are accepted:
|
||||
`{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
|
||||
"""
|
||||
|
||||
def __init__(self, template: str, with_variable_tmpl: bool = False):
|
||||
self.template = template
|
||||
self.with_variable_tmpl = with_variable_tmpl
|
||||
self.regex = WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX
|
||||
self.variable_keys = self.extract()
|
||||
|
||||
def extract(self):
|
||||
# Regular expression to match the template rules
|
||||
return re.findall(self.regex, self.template)
|
||||
|
||||
def format(self, inputs: Mapping[str, str], remove_template_variables: bool = True) -> str:
|
||||
def replacer(match):
|
||||
key = match.group(1)
|
||||
value = inputs.get(key, match.group(0)) # return original matched string if key not found
|
||||
|
||||
if remove_template_variables and isinstance(value, str):
|
||||
return PromptTemplateParser.remove_template_variables(value, self.with_variable_tmpl)
|
||||
return value
|
||||
|
||||
prompt = re.sub(self.regex, replacer, self.template)
|
||||
return re.sub(r"<\|.*?\|>", "", prompt)
|
||||
|
||||
@classmethod
|
||||
def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False):
|
||||
return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r"{\1}", text)
|
||||
Reference in New Issue
Block a user