dify
This commit is contained in:
3
dify/api/core/workflow/nodes/__init__.py
Normal file
3
dify/api/core/workflow/nodes/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from core.workflow.enums import NodeType
|
||||
|
||||
__all__ = ["NodeType"]
|
||||
3
dify/api/core/workflow/nodes/agent/__init__.py
Normal file
3
dify/api/core/workflow/nodes/agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .agent_node import AgentNode
|
||||
|
||||
__all__ = ["AgentNode"]
|
||||
756
dify/api/core/workflow/nodes/agent/agent_node.py
Normal file
756
dify/api/core/workflow/nodes/agent/agent_node.py
Normal file
@@ -0,0 +1,756 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayFileSegment, StringSegment
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models import ToolFile
|
||||
from models.model import Conversation
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exc import (
|
||||
AgentInputTypeError,
|
||||
AgentInvocationError,
|
||||
AgentMessageTransformError,
|
||||
AgentNodeError,
|
||||
AgentVariableNotFoundError,
|
||||
AgentVariableTypeError,
|
||||
ToolFileNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
|
||||
class AgentNode(Node):
|
||||
"""
|
||||
Agent Node
|
||||
"""
|
||||
|
||||
node_type = NodeType.AGENT
|
||||
_node_data: AgentNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AgentNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=self.tenant_id,
|
||||
agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self._node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
agent_parameters = strategy.get_parameters()
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
strategy=strategy,
|
||||
)
|
||||
parameters_for_log = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
credentials = self._generate_credentials(parameters=parameters)
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": self._node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_agent_parameters(
|
||||
self,
|
||||
*,
|
||||
agent_parameters: Sequence[AgentStrategyParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: "PluginAgentStrategy",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
agent_parameters (Sequence[AgentParameter]): The list of agent parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (AgentNodeData): The data associated with the agent node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
if agent_input.type == "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
elif agent_input.type in {"mixed", "constant"}:
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
else:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use the node_data.version field for judgment
|
||||
# But for backward compatibility with historical data
|
||||
# this version field judgment is still preserved here.
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if tool_runtime_params.name in manual_input_params
|
||||
else tool_runtime_params.form
|
||||
)
|
||||
manual_input_value = {}
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance, model_schema = self._fetch_model(value)
|
||||
# memory config
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
memory = self._fetch_memory(model_instance)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
]
|
||||
value["history_prompt_messages"] = history_prompt_messages
|
||||
if model_schema:
|
||||
# remove structured output feature to support old version agent plugin
|
||||
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
|
||||
value["entity"] = model_schema.model_dump(mode="json")
|
||||
else:
|
||||
value["entity"] = None
|
||||
result[parameter_name] = value
|
||||
|
||||
return result
|
||||
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> "InvokeCredentials":
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
credentials = InvokeCredentials()
|
||||
|
||||
# generate credentials for tools selector
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
if tool.get("credential_id"):
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
except ValidationError:
|
||||
continue
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AgentNodeData.model_validate(node_data)
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
if input.type in ["mixed", "constant"]:
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def agent_strategy_icon(self) -> str | None:
|
||||
"""
|
||||
Get agent strategy icon
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(self.tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
return memory
|
||||
|
||||
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM, model=model_name
|
||||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_instance, model_schema
|
||||
|
||||
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(
|
||||
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
:param tool: tool
|
||||
:return: filtered tool dict
|
||||
"""
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
else:
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise AgentVariableTypeError(
|
||||
"When 'stream' is True, 'variable_value' must be a string.",
|
||||
variable_name=variable_name,
|
||||
expected_type="str",
|
||||
actual_type=type(variable_value).__name__,
|
||||
)
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
# Validate that meta contains a 'file' key
|
||||
if "file" not in message.meta:
|
||||
raise AgentNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
# Validate that the file is an instance of File
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
# check if the agent log is already in the list
|
||||
for log in agent_logs:
|
||||
if log.message_id == agent_log.message_id:
|
||||
# update the log
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
45
dify/api/core/workflow/nodes/agent/entities.py
Normal file
45
dify/api/core/workflow/nodes/agent/entities.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
agent_strategy_provider_name: str # redundancy
|
||||
agent_strategy_name: str
|
||||
agent_strategy_label: str # redundancy
|
||||
memory: MemoryConfig | None = None
|
||||
# The version of the tool parameter.
|
||||
# If this value is None, it indicates this is a previous version
|
||||
# and requires using the legacy parameter parsing rules.
|
||||
tool_node_version: str | None = None
|
||||
|
||||
class AgentInput(BaseModel):
|
||||
value: Union[list[str], list[ToolSelector], Any]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
class ParamsAutoGenerated(IntEnum):
|
||||
CLOSE = 0
|
||||
OPEN = 1
|
||||
|
||||
|
||||
class AgentOldVersionModelFeatures(StrEnum):
|
||||
"""
|
||||
Enum class for old SDK version llm feature.
|
||||
"""
|
||||
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
121
dify/api/core/workflow/nodes/agent/exc.py
Normal file
121
dify/api/core/workflow/nodes/agent/exc.py
Normal file
@@ -0,0 +1,121 @@
|
||||
class AgentNodeError(Exception):
|
||||
"""Base exception for all agent node errors."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class AgentStrategyError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the agent strategy."""
|
||||
|
||||
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
|
||||
self.strategy_name = strategy_name
|
||||
self.provider_name = provider_name
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentStrategyNotFoundError(AgentStrategyError):
|
||||
"""Exception raised when the specified agent strategy is not found."""
|
||||
|
||||
def __init__(self, strategy_name: str, provider_name: str | None = None):
|
||||
super().__init__(
|
||||
f"Agent strategy '{strategy_name}' not found"
|
||||
+ (f" for provider '{provider_name}'" if provider_name else ""),
|
||||
strategy_name,
|
||||
provider_name,
|
||||
)
|
||||
|
||||
|
||||
class AgentInvocationError(AgentNodeError):
|
||||
"""Exception raised when there's an error invoking the agent."""
|
||||
|
||||
def __init__(self, message: str, original_error: Exception | None = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentParameterError(AgentNodeError):
|
||||
"""Exception raised when there's an error with agent parameters."""
|
||||
|
||||
def __init__(self, message: str, parameter_name: str | None = None):
|
||||
self.parameter_name = parameter_name
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentVariableError(AgentNodeError):
|
||||
"""Exception raised when there's an error with variables in the agent node."""
|
||||
|
||||
def __init__(self, message: str, variable_name: str | None = None):
|
||||
self.variable_name = variable_name
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentVariableNotFoundError(AgentVariableError):
|
||||
"""Exception raised when a variable is not found in the variable pool."""
|
||||
|
||||
def __init__(self, variable_name: str):
|
||||
super().__init__(f"Variable '{variable_name}' does not exist", variable_name)
|
||||
|
||||
|
||||
class AgentInputTypeError(AgentNodeError):
|
||||
"""Exception raised when an unknown agent input type is encountered."""
|
||||
|
||||
def __init__(self, input_type: str):
|
||||
super().__init__(f"Unknown agent input type '{input_type}'")
|
||||
|
||||
|
||||
class ToolFileError(AgentNodeError):
|
||||
"""Exception raised when there's an error with a tool file."""
|
||||
|
||||
def __init__(self, message: str, file_id: str | None = None):
|
||||
self.file_id = file_id
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ToolFileNotFoundError(ToolFileError):
|
||||
"""Exception raised when a tool file is not found."""
|
||||
|
||||
def __init__(self, file_id: str):
|
||||
super().__init__(f"Tool file '{file_id}' does not exist", file_id)
|
||||
|
||||
|
||||
class AgentMessageTransformError(AgentNodeError):
|
||||
"""Exception raised when there's an error transforming agent messages."""
|
||||
|
||||
def __init__(self, message: str, original_error: Exception | None = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentModelError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the model used by the agent."""
|
||||
|
||||
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentMemoryError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the agent's memory."""
|
||||
|
||||
def __init__(self, message: str, conversation_id: str | None = None):
|
||||
self.conversation_id = conversation_id
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentVariableTypeError(AgentNodeError):
|
||||
"""Exception raised when a variable has an unexpected type."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
variable_name: str | None = None,
|
||||
expected_type: str | None = None,
|
||||
actual_type: str | None = None,
|
||||
):
|
||||
self.variable_name = variable_name
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
super().__init__(message)
|
||||
0
dify/api/core/workflow/nodes/answer/__init__.py
Normal file
0
dify/api/core/workflow/nodes/answer/__init__.py
Normal file
96
dify/api/core/workflow/nodes/answer/answer_node.py
Normal file
96
dify/api/core/workflow/nodes/answer/answer_node.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.variables import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.answer.entities import AnswerNodeData
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class AnswerNode(Node):
|
||||
node_type = NodeType.ANSWER
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
_node_data: AnswerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AnswerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
|
||||
files = self._extract_files_from_segments(segments.value)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)},
|
||||
)
|
||||
|
||||
def _extract_files_from_segments(self, segments: Sequence[Segment]):
|
||||
"""Extract all files from segments containing FileSegment or ArrayFileSegment instances.
|
||||
|
||||
FileSegment contains a single file, while ArrayFileSegment contains multiple files.
|
||||
This method flattens all files into a single list.
|
||||
"""
|
||||
files = []
|
||||
for segment in segments:
|
||||
if isinstance(segment, FileSegment):
|
||||
# Single file - wrap in list for consistency
|
||||
files.append(segment.value)
|
||||
elif isinstance(segment, ArrayFileSegment):
|
||||
# Multiple files - extend the list
|
||||
files.extend(segment.value)
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AnswerNodeData.model_validate(node_data)
|
||||
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def get_streaming_template(self) -> Template:
|
||||
"""
|
||||
Get the template for streaming.
|
||||
|
||||
Returns:
|
||||
Template instance for this Answer node
|
||||
"""
|
||||
return Template.from_answer_template(self._node_data.answer)
|
||||
65
dify/api/core/workflow/nodes/answer/entities.py
Normal file
65
dify/api/core/workflow/nodes/answer/entities.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class AnswerNodeData(BaseNodeData):
|
||||
"""
|
||||
Answer Node Data.
|
||||
"""
|
||||
|
||||
answer: str = Field(..., description="answer template string")
|
||||
|
||||
|
||||
class GenerateRouteChunk(BaseModel):
|
||||
"""
|
||||
Generate Route Chunk.
|
||||
"""
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
VAR = auto()
|
||||
TEXT = auto()
|
||||
|
||||
type: ChunkType = Field(..., description="generate route chunk type")
|
||||
|
||||
|
||||
class VarGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Var Generate Route Chunk.
|
||||
"""
|
||||
|
||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
|
||||
"""generate route chunk type"""
|
||||
value_selector: Sequence[str] = Field(..., description="value selector")
|
||||
|
||||
|
||||
class TextGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Text Generate Route Chunk.
|
||||
"""
|
||||
|
||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
|
||||
"""generate route chunk type"""
|
||||
text: str = Field(..., description="text")
|
||||
|
||||
|
||||
class AnswerNodeDoubleLink(BaseModel):
|
||||
node_id: str = Field(..., description="node id")
|
||||
source_node_ids: list[str] = Field(..., description="source node ids")
|
||||
target_node_ids: list[str] = Field(..., description="target node ids")
|
||||
|
||||
|
||||
class AnswerStreamGenerateRoute(BaseModel):
|
||||
"""
|
||||
AnswerStreamGenerateRoute entity
|
||||
"""
|
||||
|
||||
answer_dependencies: dict[str, list[str]] = Field(
|
||||
..., description="answer dependencies (answer node id -> dependent answer node ids)"
|
||||
)
|
||||
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
|
||||
..., description="answer generate route (answer node id -> generate route chunks)"
|
||||
)
|
||||
11
dify/api/core/workflow/nodes/base/__init__.py
Normal file
11
dify/api/core/workflow/nodes/base/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||
|
||||
__all__ = [
|
||||
"BaseIterationNodeData",
|
||||
"BaseIterationState",
|
||||
"BaseLoopNodeData",
|
||||
"BaseLoopState",
|
||||
"BaseNodeData",
|
||||
"LLMUsageTrackingMixin",
|
||||
]
|
||||
171
dify/api/core/workflow/nodes/base/entities.py
Normal file
171
dify/api/core/workflow/nodes/base/entities.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.workflow.enums import ErrorStrategy
|
||||
|
||||
from .exc import DefaultValueTypeError
|
||||
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
"""
|
||||
Variable Selector.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_selector: Sequence[str]
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> "DefaultValue":
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
iteration_node_id: str
|
||||
index: int
|
||||
inputs: dict
|
||||
|
||||
class MetaData(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
||||
|
||||
|
||||
class BaseLoopNodeData(BaseNodeData):
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
class BaseLoopState(BaseModel):
|
||||
loop_node_id: str
|
||||
index: int
|
||||
inputs: dict
|
||||
|
||||
class MetaData(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
||||
10
dify/api/core/workflow/nodes/base/exc.py
Normal file
10
dify/api/core/workflow/nodes/base/exc.py
Normal file
@@ -0,0 +1,10 @@
|
||||
class BaseNodeError(ValueError):
|
||||
"""Base class for node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DefaultValueTypeError(BaseNodeError):
|
||||
"""Raised when the default value type is invalid."""
|
||||
|
||||
pass
|
||||
538
dify/api/core/workflow/nodes/base/node.py
Normal file
538
dify/api/core/workflow/nodes/base/node.py
Normal file
@@ -0,0 +1,538 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from typing import Any, ClassVar
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
IterationFailedEvent,
|
||||
IterationNextEvent,
|
||||
IterationStartedEvent,
|
||||
IterationSucceededEvent,
|
||||
LoopFailedEvent,
|
||||
LoopNextEvent,
|
||||
LoopStartedEvent,
|
||||
LoopSucceededEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Node:
|
||||
node_type: ClassVar["NodeType"]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
self.workflow_id = graph_init_params.workflow_id
|
||||
self.graph_config = graph_init_params.graph_config
|
||||
self.user_id = graph_init_params.user_id
|
||||
self.user_from = UserFrom(graph_init_params.user_from)
|
||||
self.invoke_from = InvokeFrom(graph_init_params.invoke_from)
|
||||
self.workflow_call_depth = graph_init_params.call_depth
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
@abstractmethod
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
# Generate a single node execution ID to use for all events
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
# Create and push start event with required fields
|
||||
start_event = NodeRunStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.title,
|
||||
in_iteration_id=None,
|
||||
start_at=self._start_at,
|
||||
)
|
||||
|
||||
# === FIXME(-LAN-): Needs to refactor.
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
|
||||
if isinstance(self, ToolNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
if isinstance(self, DatasourceNode):
|
||||
plugin_id = getattr(self.get_base_node_data(), "plugin_id", "")
|
||||
provider_name = getattr(self.get_base_node_data(), "provider_name", "")
|
||||
|
||||
start_event.provider_id = f"{plugin_id}/{provider_name}"
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
|
||||
if isinstance(self, TriggerEventNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData
|
||||
|
||||
if isinstance(self, AgentNode):
|
||||
start_event.agent_strategy = AgentNodeStrategyInit(
|
||||
name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
|
||||
icon=self.agent_strategy_icon,
|
||||
)
|
||||
|
||||
# ===
|
||||
yield start_event
|
||||
|
||||
try:
|
||||
result = self._run()
|
||||
|
||||
# Handle NodeRunResult
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield self._convert_node_run_result_to_graph_node_event(result)
|
||||
return
|
||||
|
||||
# Handle event stream
|
||||
for event in result:
|
||||
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
|
||||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
yield self._dispatch(event)
|
||||
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
event.id = self._node_execution_id
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="WorkflowNodeError",
|
||||
)
|
||||
yield NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""Extracts references variable selectors from node configuration.
|
||||
|
||||
The `config` parameter represents the configuration for a specific node type and corresponds
|
||||
to the `data` field in the node definition object.
|
||||
|
||||
The returned mapping has the following structure:
|
||||
|
||||
{'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}
|
||||
|
||||
For loop and iteration nodes, the mapping may look like this:
|
||||
|
||||
{
|
||||
"1748332301644.input_selector": ["1748332363630", "result"],
|
||||
"1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"],
|
||||
}
|
||||
|
||||
where `1748332301644` is the ID of the loop / iteration node,
|
||||
and `1748332325079` is the ID of the node inside the loop or iteration node.
|
||||
|
||||
Here, the key consists of two parts: the current node ID (provided as the `node_id`
|
||||
parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
|
||||
enclosed in `#` symbols. These two parts are separated by a dot (`.`).
|
||||
|
||||
The value is a list of string representing the variable selector, where the first element is the node ID
|
||||
of the referenced variable, and the second element is the variable name within that node.
|
||||
|
||||
The meaning of the above response is:
|
||||
|
||||
The node with ID `1747829548239` references the variable `result` from the node with
|
||||
ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a
|
||||
reference to the `result` output variable of node `1747829667553`.
|
||||
|
||||
:param graph_config: graph config
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
# Pass raw dict data instead of creating NodeData instance
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
|
||||
)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {}
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this node blocks the output of specific variables.
|
||||
|
||||
This method is used to determine if a node must complete execution before
|
||||
the specified variables can be used in streaming output.
|
||||
|
||||
:param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str'))
|
||||
:return: True if this node blocks output of any of the specified variables, False otherwise
|
||||
"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def version(cls) -> str:
|
||||
"""`node_version` returns the version of current node type."""
|
||||
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
|
||||
#
|
||||
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
|
||||
# in `api/core/workflow/nodes/__init__.py`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return False
|
||||
|
||||
# Abstract methods that subclasses must implement to provide access
|
||||
# to BaseNodeData properties in a type-safe way
|
||||
|
||||
@abstractmethod
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_title(self) -> str:
|
||||
"""Get the node title."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
"""Get the BaseNodeData object for this node."""
|
||||
...
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
def error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
return self._get_error_strategy()
|
||||
|
||||
@property
|
||||
def retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
return self._get_retry_config()
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
"""Get the node title."""
|
||||
return self._get_title()
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
return self._get_description()
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
return self._get_default_value_dict()
|
||||
|
||||
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
||||
match result.status:
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
error=result.error,
|
||||
)
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
)
|
||||
case _:
|
||||
raise Exception(f"result status {result.status} not supported")
|
||||
|
||||
@singledispatchmethod
|
||||
def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase:
|
||||
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
match event.node_run_result.status:
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=event.node_run_result,
|
||||
)
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=event.node_run_result,
|
||||
error=event.node_run_result.error,
|
||||
)
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f"Node {self._node_id} does not support status {event.node_run_result.status}"
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
|
||||
return NodeRunPauseRequestedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
|
||||
reason=event.reason,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
||||
return NodeRunAgentLogEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
message_id=event.message_id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
return NodeRunLoopStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
|
||||
return NodeRunLoopNextEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.pre_loop_output,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
|
||||
return NodeRunLoopSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
|
||||
return NodeRunLoopFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
|
||||
return NodeRunIterationStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
|
||||
return NodeRunIterationNextEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.pre_iteration_output,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
|
||||
return NodeRunIterationSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
|
||||
return NodeRunIterationFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||
return NodeRunRetrieverResourceEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
retriever_resources=event.retriever_resources,
|
||||
context=event.context,
|
||||
node_version=self.version(),
|
||||
)
|
||||
148
dify/api/core/workflow/nodes/base/template.py
Normal file
148
dify/api/core/workflow/nodes/base/template.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Template structures for Response nodes (Answer and End).
|
||||
|
||||
This module provides a unified template structure for both Answer and End nodes,
|
||||
similar to SegmentGroup but focused on template representation without values.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union
|
||||
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TemplateSegment(ABC):
|
||||
"""Base class for template segments."""
|
||||
|
||||
@abstractmethod
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the segment."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextSegment(TemplateSegment):
|
||||
"""A text segment in a template."""
|
||||
|
||||
text: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.text
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VariableSegment(TemplateSegment):
|
||||
"""A variable reference segment in a template."""
|
||||
|
||||
selector: Sequence[str]
|
||||
variable_name: str | None = None # Optional variable name for End nodes
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{{#" + ".".join(self.selector) + "#}}"
|
||||
|
||||
|
||||
# Type alias for segments
|
||||
TemplateSegmentUnion = Union[TextSegment, VariableSegment]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Template:
|
||||
"""Unified template structure for Response nodes.
|
||||
|
||||
Similar to SegmentGroup, but represents the template structure
|
||||
without variable values - only marking variable selectors.
|
||||
"""
|
||||
|
||||
segments: list[TemplateSegmentUnion]
|
||||
|
||||
@classmethod
|
||||
def from_answer_template(cls, template_str: str) -> "Template":
|
||||
"""Create a Template from an Answer node template string.
|
||||
|
||||
Example:
|
||||
"Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])]
|
||||
|
||||
Args:
|
||||
template_str: The answer template string
|
||||
|
||||
Returns:
|
||||
Template instance
|
||||
"""
|
||||
parser = VariableTemplateParser(template_str)
|
||||
segments: list[TemplateSegmentUnion] = []
|
||||
|
||||
# Extract variable selectors to find all variables
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
var_map = {var.variable: var.value_selector for var in variable_selectors}
|
||||
|
||||
# Parse template to get ordered segments
|
||||
# We need to split the template by variable placeholders while preserving order
|
||||
import re
|
||||
|
||||
# Create a regex pattern that matches variable placeholders
|
||||
pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}"
|
||||
|
||||
# Split template while keeping the delimiters (variable placeholders)
|
||||
parts = re.split(pattern, template_str)
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
if not part:
|
||||
continue
|
||||
|
||||
# Check if this part is a variable reference (odd indices after split)
|
||||
if i % 2 == 1: # Odd indices are variable keys
|
||||
# Remove the # symbols from the variable key
|
||||
var_key = part
|
||||
if var_key in var_map:
|
||||
segments.append(VariableSegment(selector=list(var_map[var_key])))
|
||||
else:
|
||||
# This shouldn't happen with valid templates
|
||||
segments.append(TextSegment(text="{{" + part + "}}"))
|
||||
else:
|
||||
# Even indices are text segments
|
||||
segments.append(TextSegment(text=part))
|
||||
|
||||
return cls(segments=segments)
|
||||
|
||||
@classmethod
|
||||
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
|
||||
"""Create a Template from an End node outputs configuration.
|
||||
|
||||
End nodes are treated as templates of concatenated variables with newlines.
|
||||
|
||||
Example:
|
||||
[{"variable": "text", "value_selector": ["node1", "text"]},
|
||||
{"variable": "result", "value_selector": ["node2", "result"]}]
|
||||
->
|
||||
[VariableSegment(["node1", "text"]),
|
||||
TextSegment("\n"),
|
||||
VariableSegment(["node2", "result"])]
|
||||
|
||||
Args:
|
||||
outputs_config: List of output configurations with variable and value_selector
|
||||
|
||||
Returns:
|
||||
Template instance
|
||||
"""
|
||||
segments: list[TemplateSegmentUnion] = []
|
||||
|
||||
for i, output in enumerate(outputs_config):
|
||||
if i > 0:
|
||||
# Add newline separator between variables
|
||||
segments.append(TextSegment(text="\n"))
|
||||
|
||||
value_selector = output.get("value_selector", [])
|
||||
variable_name = output.get("variable", "")
|
||||
if value_selector:
|
||||
segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name))
|
||||
|
||||
if len(segments) > 0 and isinstance(segments[-1], TextSegment):
|
||||
segments = segments[:-1]
|
||||
|
||||
return cls(segments=segments)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the template."""
|
||||
return "".join(str(segment) for segment in self.segments)
|
||||
28
dify/api/core/workflow/nodes/base/usage_tracking_mixin.py
Normal file
28
dify/api/core/workflow/nodes/base/usage_tracking_mixin.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class LLMUsageTrackingMixin:
|
||||
"""Provides shared helpers for merging and recording LLM usage within workflow nodes."""
|
||||
|
||||
graph_runtime_state: GraphRuntimeState
|
||||
|
||||
@staticmethod
|
||||
def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
|
||||
"""Return a combined usage snapshot, preserving zero-value inputs."""
|
||||
if new_usage is None or new_usage.total_tokens <= 0:
|
||||
return current
|
||||
if current.total_tokens == 0:
|
||||
return new_usage
|
||||
return current.plus(new_usage)
|
||||
|
||||
def _accumulate_usage(self, usage: LLMUsage) -> None:
|
||||
"""Push usage into the graph runtime accumulator for downstream reporting."""
|
||||
if usage.total_tokens <= 0:
|
||||
return
|
||||
|
||||
current_usage = self.graph_runtime_state.llm_usage
|
||||
if current_usage.total_tokens == 0:
|
||||
self.graph_runtime_state.llm_usage = usage.model_copy()
|
||||
else:
|
||||
self.graph_runtime_state.llm_usage = current_usage.plus(usage)
|
||||
130
dify/api/core/workflow/nodes/base/variable_template_parser.py
Normal file
130
dify/api/core/workflow/nodes/base/variable_template_parser.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from .entities import VariableSelector
|
||||
|
||||
REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
|
||||
|
||||
SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
|
||||
|
||||
|
||||
def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]:
|
||||
parts = SELECTOR_PATTERN.split(template)
|
||||
selectors = []
|
||||
for part in filter(lambda x: x, parts):
|
||||
if "." in part and part[0] == "#" and part[-1] == "#":
|
||||
selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split(".")))
|
||||
return selectors
|
||||
|
||||
|
||||
class VariableTemplateParser:
|
||||
"""
|
||||
!NOTE: Consider to use the new `segments` module instead of this class.
|
||||
|
||||
A class for parsing and manipulating template variables in a string.
|
||||
|
||||
Rules:
|
||||
|
||||
1. Template variables must be enclosed in `{{}}`.
|
||||
2. The template variable Key can only be: #node_id.var1.var2#.
|
||||
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
|
||||
|
||||
Example usage:
|
||||
|
||||
template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}."
|
||||
parser = VariableTemplateParser(template)
|
||||
|
||||
# Extract template variable keys
|
||||
variable_keys = parser.extract()
|
||||
print(variable_keys)
|
||||
# Output: ['#node_id.query.name#', '#node_id.query.age#']
|
||||
|
||||
# Extract variable selectors
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
print(variable_selectors)
|
||||
# Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']),
|
||||
# VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])]
|
||||
|
||||
# Format the template string
|
||||
inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}}
|
||||
formatted_string = parser.format(inputs)
|
||||
print(formatted_string)
|
||||
# Output: "Hello, John! Your age is 25."
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.template = template
|
||||
self.variable_keys = self.extract()
|
||||
|
||||
def extract(self):
|
||||
"""
|
||||
Extracts all the template variable keys from the template string.
|
||||
|
||||
Returns:
|
||||
A list of template variable keys.
|
||||
"""
|
||||
# Regular expression to match the template rules
|
||||
matches = re.findall(REGEX, self.template)
|
||||
|
||||
first_group_matches = [match[0] for match in matches]
|
||||
|
||||
return list(set(first_group_matches))
|
||||
|
||||
def extract_variable_selectors(self) -> list[VariableSelector]:
|
||||
"""
|
||||
Extracts the variable selectors from the template variable keys.
|
||||
|
||||
Returns:
|
||||
A list of VariableSelector objects representing the variable selectors.
|
||||
"""
|
||||
variable_selectors = []
|
||||
for variable_key in self.variable_keys:
|
||||
remove_hash = variable_key.replace("#", "")
|
||||
split_result = remove_hash.split(".")
|
||||
if len(split_result) < 2:
|
||||
continue
|
||||
|
||||
variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result))
|
||||
|
||||
return variable_selectors
|
||||
|
||||
def format(self, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Formats the template string by replacing the template variables with their corresponding values.
|
||||
|
||||
Args:
|
||||
inputs: A dictionary containing the values for the template variables.
|
||||
|
||||
Returns:
|
||||
The formatted string with template variables replaced by their values.
|
||||
"""
|
||||
|
||||
def replacer(match):
|
||||
key = match.group(1)
|
||||
value = inputs.get(key, match.group(0)) # return original matched string if key not found
|
||||
|
||||
if value is None:
|
||||
value = ""
|
||||
# convert the value to string
|
||||
if isinstance(value, list | dict | bool | int | float):
|
||||
value = str(value)
|
||||
|
||||
# remove template variables if required
|
||||
return VariableTemplateParser.remove_template_variables(value)
|
||||
|
||||
prompt = re.sub(REGEX, replacer, self.template)
|
||||
return re.sub(r"<\|.*?\|>", "", prompt)
|
||||
|
||||
@classmethod
|
||||
def remove_template_variables(cls, text: str):
|
||||
"""
|
||||
Removes the template variables from the given text.
|
||||
|
||||
Args:
|
||||
text: The text from which to remove the template variables.
|
||||
|
||||
Returns:
|
||||
The text with template variables removed.
|
||||
"""
|
||||
return re.sub(REGEX, r"{\1}", text)
|
||||
3
dify/api/core/workflow/nodes/code/__init__.py
Normal file
3
dify/api/core/workflow/nodes/code/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .code_node import CodeNode
|
||||
|
||||
__all__ = ["CodeNode"]
|
||||
444
dify/api/core/workflow/nodes/code/code_node.py
Normal file
444
dify/api/core/workflow/nodes/code/code_node.py
Normal file
@@ -0,0 +1,444 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
DepthLimitError,
|
||||
OutputValidationError,
|
||||
)
|
||||
|
||||
|
||||
class CodeNode(Node):
|
||||
node_type = NodeType.CODE
|
||||
|
||||
_node_data: CodeNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = CodeNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
code_language = CodeLanguage.PYTHON3
|
||||
if filters:
|
||||
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get code language
|
||||
code_language = self._node_data.code_language
|
||||
code = self._node_data.code
|
||||
|
||||
# Get variables
|
||||
variables = {}
|
||||
for variable_selector in self._node_data.variables:
|
||||
variable_name = variable_selector.variable
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None
|
||||
else:
|
||||
variables[variable_name] = variable.to_object() if variable else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=code_language,
|
||||
code=code,
|
||||
inputs=variables,
|
||||
)
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
Check string
|
||||
:param value: value
|
||||
:param variable: variable
|
||||
:return:
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{variable}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
|
||||
)
|
||||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
return value
|
||||
|
||||
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
|
||||
"""
|
||||
Check number
|
||||
:param value: value
|
||||
:param variable: variable
|
||||
:return:
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` is out of range,"
|
||||
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
|
||||
)
|
||||
|
||||
if isinstance(value, float):
|
||||
decimal_value = Decimal(str(value)).normalize()
|
||||
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
|
||||
# raise error if precision is too high
|
||||
if precision > dify_config.CODE_MAX_PRECISION:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` has too high precision,"
|
||||
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
def _transform_result(
|
||||
self,
|
||||
result: Mapping[str, Any],
|
||||
output_schema: dict[str, CodeNodeData.Output] | None,
|
||||
prefix: str = "",
|
||||
depth: int = 1,
|
||||
):
|
||||
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
|
||||
# Note that `_transform_result` may produce lists containing `None` values,
|
||||
# which don't conform to the type requirements of `Array*Segment` classes.
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
transformed_result: dict[str, Any] = {}
|
||||
if output_schema is None:
|
||||
# validate output thought instance type
|
||||
for output_name, output_value in result.items():
|
||||
if isinstance(output_value, dict):
|
||||
self._transform_result(
|
||||
result=output_value,
|
||||
output_schema=None,
|
||||
prefix=f"{prefix}.{output_name}" if prefix else output_name,
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif isinstance(output_value, bool):
|
||||
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
|
||||
elif isinstance(output_value, int | float):
|
||||
self._check_number(
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
)
|
||||
elif isinstance(output_value, str):
|
||||
self._check_string(
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
)
|
||||
elif isinstance(output_value, list):
|
||||
first_element = output_value[0] if len(output_value) > 0 else None
|
||||
if first_element is not None:
|
||||
if isinstance(first_element, int | float) and all(
|
||||
value is None or isinstance(value, int | float) for value in output_value
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
self._check_number(
|
||||
value=value,
|
||||
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
)
|
||||
elif isinstance(first_element, str) and all(
|
||||
value is None or isinstance(value, str) for value in output_value
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
self._check_string(
|
||||
value=value,
|
||||
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
)
|
||||
elif (
|
||||
isinstance(first_element, dict)
|
||||
and all(value is None or isinstance(value, dict) for value in output_value)
|
||||
or isinstance(first_element, list)
|
||||
and all(value is None or isinstance(value, list) for value in output_value)
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
if value is not None:
|
||||
self._transform_result(
|
||||
result=value,
|
||||
output_schema=None,
|
||||
prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
depth=depth + 1,
|
||||
)
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}.{output_name} is not a valid array."
|
||||
f" make sure all elements are of the same type."
|
||||
)
|
||||
elif output_value is None:
|
||||
pass
|
||||
else:
|
||||
raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.")
|
||||
|
||||
return result
|
||||
|
||||
parameters_validated = {}
|
||||
for output_name, output_config in output_schema.items():
|
||||
dot = "." if prefix else ""
|
||||
if output_name not in result:
|
||||
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
|
||||
if output_config.type == SegmentType.OBJECT:
|
||||
# check if output is object
|
||||
if not isinstance(result.get(output_name), dict):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an object,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
transformed_result[output_name] = self._transform_result(
|
||||
result=result[output_name],
|
||||
output_schema=output_config.children,
|
||||
prefix=f"{prefix}.{output_name}",
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif output_config.type == SegmentType.NUMBER:
|
||||
# check if number available
|
||||
value = result.get(output_name)
|
||||
if value is not None and not isinstance(value, (int, float)):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not a number,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}")
|
||||
# If the output is a boolean and the output schema specifies a NUMBER type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
transformed_result[output_name] = self._convert_boolean_to_int(checked)
|
||||
|
||||
elif output_config.type == SegmentType.STRING:
|
||||
# check if string available
|
||||
value = result.get(output_name)
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} must be a string, got {type(value).__name__} instead"
|
||||
)
|
||||
transformed_result[output_name] = self._check_string(
|
||||
value=value,
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.BOOLEAN:
|
||||
transformed_result[output_name] = self._check_boolean(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.ARRAY_NUMBER:
|
||||
# check if array of number available
|
||||
value = result[output_name]
|
||||
if not isinstance(value, list):
|
||||
if value is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
for i, inner_value in enumerate(value):
|
||||
if not isinstance(inner_value, (int, float)):
|
||||
raise OutputValidationError(
|
||||
f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" a number."
|
||||
)
|
||||
_ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
transformed_result[output_name] = [
|
||||
# If the element is a boolean and the output schema specifies a `array[number]` type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
self._convert_boolean_to_int(v)
|
||||
for v in value
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_STRING:
|
||||
# check if array of string available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_OBJECT:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
for i, value in enumerate(result[output_name]):
|
||||
if not isinstance(value, dict):
|
||||
if value is None:
|
||||
pass
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name}[{i}] is not an object,"
|
||||
f" got {type(value)} instead at index {i}."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
None
|
||||
if value is None
|
||||
else self._transform_result(
|
||||
result=value,
|
||||
output_schema=output_config.children,
|
||||
prefix=f"{prefix}{dot}{output_name}[{i}]",
|
||||
depth=depth + 1,
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
|
||||
# check if array of object available
|
||||
value = result[output_name]
|
||||
if not isinstance(value, list):
|
||||
if value is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
for i, inner_value in enumerate(value):
|
||||
if inner_value is not None and not isinstance(inner_value, bool):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name}[{i}] is not a boolean,"
|
||||
f" got {type(inner_value)} instead."
|
||||
)
|
||||
_ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
transformed_result[output_name] = value
|
||||
|
||||
else:
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
parameters_validated[output_name] = True
|
||||
|
||||
# check if all output parameters are validated
|
||||
if len(parameters_validated) != len(result):
|
||||
raise CodeNodeError("Not all output parameters are validated.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = CodeNodeData.model_validate(node_data)
|
||||
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in typed_node_data.variables
|
||||
}
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
||||
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
|
||||
|
||||
This ensures compatibility with existing workflows that may use
|
||||
`True` and `False` as values for NUMBER type outputs.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
return value
|
||||
47
dify/api/core/workflow/nodes/code/entities.py
Normal file
47
dify/api/core/workflow/nodes/code/entities.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Annotated, Literal, Self
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(segment_type: SegmentType) -> SegmentType:
|
||||
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
|
||||
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
|
||||
return segment_type
|
||||
|
||||
|
||||
class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, Self] | None = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
version: str
|
||||
|
||||
variables: list[VariableSelector]
|
||||
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
|
||||
code: str
|
||||
outputs: dict[str, Output]
|
||||
dependencies: list[Dependency] | None = None
|
||||
16
dify/api/core/workflow/nodes/code/exc.py
Normal file
16
dify/api/core/workflow/nodes/code/exc.py
Normal file
@@ -0,0 +1,16 @@
|
||||
class CodeNodeError(ValueError):
|
||||
"""Base class for code node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OutputValidationError(CodeNodeError):
|
||||
"""Raised when there is an output validation error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DepthLimitError(CodeNodeError):
|
||||
"""Raised when the depth limit is reached."""
|
||||
|
||||
pass
|
||||
3
dify/api/core/workflow/nodes/datasource/__init__.py
Normal file
3
dify/api/core/workflow/nodes/datasource/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .datasource_node import DatasourceNode
|
||||
|
||||
__all__ = ["DatasourceNode"]
|
||||
502
dify/api/core/workflow/nodes/datasource/datasource_node.py
Normal file
502
dify/api/core/workflow/nodes/datasource/datasource_node.py
Normal file
@@ -0,0 +1,502 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from core.file import File
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.tool.exc import ToolFileError
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from .entities import DatasourceNodeData
|
||||
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||
|
||||
|
||||
class DatasourceNode(Node):
|
||||
"""
|
||||
Datasource Node
|
||||
"""
|
||||
|
||||
_node_data: DatasourceNodeData
|
||||
node_type = NodeType.DATASOURCE
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = DatasourceNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
"""
|
||||
|
||||
node_data = self._node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
if not datasource_type_segement:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
|
||||
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
|
||||
if not datasource_info_segement:
|
||||
raise DatasourceNodeError("Datasource info is not set")
|
||||
datasource_info_value = datasource_info_segement.value
|
||||
if not isinstance(datasource_info_value, dict):
|
||||
raise DatasourceNodeError("Invalid datasource info format")
|
||||
datasource_info: dict[str, Any] = datasource_info_value
|
||||
# get datasource runtime
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
|
||||
parameters_for_log = datasource_info
|
||||
|
||||
try:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=datasource_info.get("credential_id", ""),
|
||||
)
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_document_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||
workspace_id=datasource_info.get("workspace_id", ""),
|
||||
page_id=datasource_info.get("page", {}).get("page_id", ""),
|
||||
type=datasource_info.get("page", {}).get("type", ""),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
)
|
||||
yield from self._transform_message(
|
||||
messages=online_document_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
)
|
||||
case DatasourceProviderType.ONLINE_DRIVE:
|
||||
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.online_drive_download_file(
|
||||
user_id=self.user_id,
|
||||
request=OnlineDriveDownloadFileRequest(
|
||||
id=datasource_info.get("id", ""),
|
||||
bucket=datasource_info.get("bucket"),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
)
|
||||
yield from self._transform_datasource_file_message(
|
||||
messages=online_drive_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
variable_pool=variable_pool,
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
**datasource_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
related_id = datasource_info.get("related_id")
|
||||
if not related_id:
|
||||
raise DatasourceNodeError("File is not exist")
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
|
||||
if not upload_file:
|
||||
raise ValueError("Invalid upload file Info")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=upload_file.source_url,
|
||||
)
|
||||
variable_pool.add([self._node_id, "file"], file_info)
|
||||
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": file_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to transform datasource message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to invoke datasource: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
datasource_parameters: Sequence[DatasourceParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: DatasourceNodeData,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
parameter = datasource_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
datasource_input = node_data.datasource_parameters[parameter_name]
|
||||
if datasource_input.type == "variable":
|
||||
variable = variable_pool.get(datasource_input.value)
|
||||
if variable is None:
|
||||
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif datasource_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(datasource_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
typed_node_data = DatasourceNodeData.model_validate(node_data)
|
||||
result = {}
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
input = typed_node_data.datasource_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceMessage.MessageType.BINARY_LINK,
|
||||
DatasourceMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
# mark the end of the stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={**variables},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _transform_datasource_file_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
datasource_type: DatasourceProviderType,
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
file = None
|
||||
for message in message_stream:
|
||||
if message.type == DatasourceMessage.MessageType.BINARY_LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
if file:
|
||||
variable_pool.add([self._node_id, "file"], file)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": file,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
41
dify/api/core/workflow/nodes/datasource/entities.py
Normal file
41
dify/api/core/workflow/nodes/datasource/entities.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
plugin_id: str
|
||||
provider_name: str # redundancy
|
||||
provider_type: str
|
||||
datasource_name: str | None = "local_file"
|
||||
datasource_configurations: dict[str, Any] | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"] | None = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
datasource_parameters: dict[str, DatasourceInput] | None = None
|
||||
16
dify/api/core/workflow/nodes/datasource/exc.py
Normal file
16
dify/api/core/workflow/nodes/datasource/exc.py
Normal file
@@ -0,0 +1,16 @@
|
||||
class DatasourceNodeError(ValueError):
|
||||
"""Base exception for datasource node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceParameterError(DatasourceNodeError):
|
||||
"""Exception raised for errors in datasource parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceFileError(DatasourceNodeError):
|
||||
"""Exception raised for errors related to datasource files."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,4 @@
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .node import DocumentExtractorNode
|
||||
|
||||
__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"]
|
||||
@@ -0,0 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class DocumentExtractorNodeData(BaseNodeData):
|
||||
variable_selector: Sequence[str]
|
||||
14
dify/api/core/workflow/nodes/document_extractor/exc.py
Normal file
14
dify/api/core/workflow/nodes/document_extractor/exc.py
Normal file
@@ -0,0 +1,14 @@
|
||||
class DocumentExtractorError(ValueError):
|
||||
"""Base exception for errors related to the DocumentExtractorNode."""
|
||||
|
||||
|
||||
class FileDownloadError(DocumentExtractorError):
|
||||
"""Exception raised when there's an error downloading a file."""
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(DocumentExtractorError):
|
||||
"""Exception raised when trying to extract text from an unsupported file type."""
|
||||
|
||||
|
||||
class TextExtractionError(DocumentExtractorError):
|
||||
"""Exception raised when there's an error during text extraction from a file."""
|
||||
693
dify/api/core/workflow/nodes/document_extractor/node.py
Normal file
693
dify/api/core/workflow/nodes/document_extractor/node.py
Normal file
@@ -0,0 +1,693 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
import chardet
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypandoc
|
||||
import pypdfium2
|
||||
import webvtt
|
||||
import yaml
|
||||
from docx.document import Document
|
||||
from docx.oxml.table import CT_Tbl
|
||||
from docx.oxml.text.paragraph import CT_P
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayStringSegment, FileSegment
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentExtractorNode(Node):
|
||||
"""
|
||||
Extracts text content from various file types.
|
||||
Supports plain text, PDF, and DOC/DOCX files.
|
||||
"""
|
||||
|
||||
node_type = NodeType.DOCUMENT_EXTRACTOR
|
||||
|
||||
_node_data: DocumentExtractorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = DocumentExtractorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
variable_selector = self._node_data.variable_selector
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
|
||||
|
||||
if variable is None:
|
||||
error_message = f"File variable not found for selector: {variable_selector}"
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message)
|
||||
if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment):
|
||||
error_message = f"Variable {variable_selector} is not an ArrayFileSegment"
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message)
|
||||
|
||||
value = variable.value
|
||||
inputs = {"variable_selector": variable_selector}
|
||||
process_data = {"documents": value if isinstance(value, list) else [value]}
|
||||
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
extracted_text_list = list(map(_extract_text_from_file, value))
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
|
||||
)
|
||||
elif isinstance(value, File):
|
||||
extracted_text = _extract_text_from_file(value)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"text": extracted_text},
|
||||
)
|
||||
else:
|
||||
raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
|
||||
except DocumentExtractorError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
|
||||
|
||||
return {node_id + ".files": typed_node_data.variable_selector}
|
||||
|
||||
|
||||
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||
"""Extract text from a file based on its MIME type."""
|
||||
match mime_type:
|
||||
case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml":
|
||||
return _extract_text_from_plain_text(file_content)
|
||||
case "application/pdf":
|
||||
return _extract_text_from_pdf(file_content)
|
||||
case "application/msword":
|
||||
return _extract_text_from_doc(file_content)
|
||||
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
return _extract_text_from_docx(file_content)
|
||||
case "text/csv":
|
||||
return _extract_text_from_csv(file_content)
|
||||
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel":
|
||||
return _extract_text_from_excel(file_content)
|
||||
case "application/vnd.ms-powerpoint":
|
||||
return _extract_text_from_ppt(file_content)
|
||||
case "application/vnd.openxmlformats-officedocument.presentationml.presentation":
|
||||
return _extract_text_from_pptx(file_content)
|
||||
case "application/epub+zip":
|
||||
return _extract_text_from_epub(file_content)
|
||||
case "message/rfc822":
|
||||
return _extract_text_from_eml(file_content)
|
||||
case "application/vnd.ms-outlook":
|
||||
return _extract_text_from_msg(file_content)
|
||||
case "application/json":
|
||||
return _extract_text_from_json(file_content)
|
||||
case "application/x-yaml" | "text/yaml":
|
||||
return _extract_text_from_yaml(file_content)
|
||||
case "text/vtt":
|
||||
return _extract_text_from_vtt(file_content)
|
||||
case "text/properties":
|
||||
return _extract_text_from_properties(file_content)
|
||||
case _:
|
||||
raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
|
||||
|
||||
|
||||
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
|
||||
"""Extract text from a file based on its file extension."""
|
||||
match file_extension:
|
||||
case (
|
||||
".txt"
|
||||
| ".markdown"
|
||||
| ".md"
|
||||
| ".mdx"
|
||||
| ".html"
|
||||
| ".htm"
|
||||
| ".xml"
|
||||
| ".c"
|
||||
| ".h"
|
||||
| ".cpp"
|
||||
| ".hpp"
|
||||
| ".cc"
|
||||
| ".cxx"
|
||||
| ".c++"
|
||||
| ".py"
|
||||
| ".js"
|
||||
| ".ts"
|
||||
| ".jsx"
|
||||
| ".tsx"
|
||||
| ".java"
|
||||
| ".php"
|
||||
| ".rb"
|
||||
| ".go"
|
||||
| ".rs"
|
||||
| ".swift"
|
||||
| ".kt"
|
||||
| ".scala"
|
||||
| ".sh"
|
||||
| ".bash"
|
||||
| ".bat"
|
||||
| ".ps1"
|
||||
| ".sql"
|
||||
| ".r"
|
||||
| ".m"
|
||||
| ".pl"
|
||||
| ".lua"
|
||||
| ".vim"
|
||||
| ".asm"
|
||||
| ".s"
|
||||
| ".css"
|
||||
| ".scss"
|
||||
| ".less"
|
||||
| ".sass"
|
||||
| ".ini"
|
||||
| ".cfg"
|
||||
| ".conf"
|
||||
| ".toml"
|
||||
| ".env"
|
||||
| ".log"
|
||||
| ".vtt"
|
||||
):
|
||||
return _extract_text_from_plain_text(file_content)
|
||||
case ".json":
|
||||
return _extract_text_from_json(file_content)
|
||||
case ".yaml" | ".yml":
|
||||
return _extract_text_from_yaml(file_content)
|
||||
case ".pdf":
|
||||
return _extract_text_from_pdf(file_content)
|
||||
case ".doc":
|
||||
return _extract_text_from_doc(file_content)
|
||||
case ".docx":
|
||||
return _extract_text_from_docx(file_content)
|
||||
case ".csv":
|
||||
return _extract_text_from_csv(file_content)
|
||||
case ".xls" | ".xlsx":
|
||||
return _extract_text_from_excel(file_content)
|
||||
case ".ppt":
|
||||
return _extract_text_from_ppt(file_content)
|
||||
case ".pptx":
|
||||
return _extract_text_from_pptx(file_content)
|
||||
case ".epub":
|
||||
return _extract_text_from_epub(file_content)
|
||||
case ".eml":
|
||||
return _extract_text_from_eml(file_content)
|
||||
case ".msg":
|
||||
return _extract_text_from_msg(file_content)
|
||||
case ".properties":
|
||||
return _extract_text_from_properties(file_content)
|
||||
case _:
|
||||
raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}")
|
||||
|
||||
|
||||
def _extract_text_from_plain_text(file_content: bytes) -> str:
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
encoding = "utf-8"
|
||||
|
||||
return file_content.decode(encoding, errors="ignore")
|
||||
except (UnicodeDecodeError, LookupError) as e:
|
||||
# If decoding fails, try with utf-8 as last resort
|
||||
try:
|
||||
return file_content.decode("utf-8", errors="ignore")
|
||||
except UnicodeDecodeError:
|
||||
raise TextExtractionError(f"Failed to decode plain text file: {e}") from e
|
||||
|
||||
|
||||
def _extract_text_from_json(file_content: bytes) -> str:
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
encoding = "utf-8"
|
||||
|
||||
json_data = json.loads(file_content.decode(encoding, errors="ignore"))
|
||||
return json.dumps(json_data, indent=2, ensure_ascii=False)
|
||||
except (UnicodeDecodeError, LookupError, json.JSONDecodeError) as e:
|
||||
# If decoding fails, try with utf-8 as last resort
|
||||
try:
|
||||
json_data = json.loads(file_content.decode("utf-8", errors="ignore"))
|
||||
return json.dumps(json_data, indent=2, ensure_ascii=False)
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e
|
||||
|
||||
|
||||
def _extract_text_from_yaml(file_content: bytes) -> str:
|
||||
"""Extract the content from yaml file"""
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
encoding = "utf-8"
|
||||
|
||||
yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
|
||||
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
|
||||
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
|
||||
# If decoding fails, try with utf-8 as last resort
|
||||
try:
|
||||
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
|
||||
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
|
||||
except (UnicodeDecodeError, yaml.YAMLError):
|
||||
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
|
||||
|
||||
|
||||
def _extract_text_from_pdf(file_content: bytes) -> str:
|
||||
try:
|
||||
pdf_file = io.BytesIO(file_content)
|
||||
pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True)
|
||||
text = ""
|
||||
for page in pdf_document:
|
||||
text_page = page.get_textpage()
|
||||
text += text_page.get_text_range()
|
||||
text_page.close()
|
||||
page.close()
|
||||
return text
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOC file.
|
||||
"""
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
if not dify_config.UNSTRUCTURED_API_URL:
|
||||
raise TextExtractionError("UNSTRUCTURED_API_URL must be set")
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
with open(temp_file.name, "rb") as file:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e
|
||||
|
||||
|
||||
def parser_docx_part(block, doc: Document, content_items, i):
|
||||
if isinstance(block, CT_P):
|
||||
content_items.append((i, "paragraph", Paragraph(block, doc)))
|
||||
elif isinstance(block, CT_Tbl):
|
||||
content_items.append((i, "table", Table(block, doc)))
|
||||
|
||||
|
||||
def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOCX file.
|
||||
For now support only paragraph and table add more if needed
|
||||
"""
|
||||
try:
|
||||
doc_file = io.BytesIO(file_content)
|
||||
doc = docx.Document(doc_file)
|
||||
text = []
|
||||
|
||||
# Keep track of paragraph and table positions
|
||||
content_items: list[tuple[int, str, Table | Paragraph]] = []
|
||||
|
||||
it = iter(doc.element.body)
|
||||
part = next(it, None)
|
||||
i = 0
|
||||
while part is not None:
|
||||
parser_docx_part(part, doc, content_items, i)
|
||||
i = i + 1
|
||||
part = next(it, None)
|
||||
|
||||
# Process sorted content
|
||||
for _, item_type, item in content_items:
|
||||
if item_type == "paragraph":
|
||||
if isinstance(item, Table):
|
||||
continue
|
||||
text.append(item.text)
|
||||
elif item_type == "table":
|
||||
# Process tables
|
||||
if not isinstance(item, Table):
|
||||
continue
|
||||
try:
|
||||
# Check if any cell in the table has text
|
||||
has_content = False
|
||||
for row in item.rows:
|
||||
if any(cell.text.strip() for cell in row.cells):
|
||||
has_content = True
|
||||
break
|
||||
|
||||
if has_content:
|
||||
cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
|
||||
markdown_table = f"| {' | '.join(cell_texts)} |\n"
|
||||
markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
|
||||
|
||||
for row in item.rows[1:]:
|
||||
# Replace newlines with <br> in each cell
|
||||
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
|
||||
markdown_table += "| " + " | ".join(row_cells) + " |\n"
|
||||
|
||||
text.append(markdown_table)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to extract table from DOC: %s", e)
|
||||
continue
|
||||
|
||||
return "\n".join(text)
|
||||
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e
|
||||
|
||||
|
||||
def _download_file_content(file: File) -> bytes:
|
||||
"""Download the content of a file based on its transfer method."""
|
||||
try:
|
||||
if file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
if file.remote_url is None:
|
||||
raise FileDownloadError("Missing URL for remote file")
|
||||
response = ssrf_proxy.get(file.remote_url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
else:
|
||||
return file_manager.download(file)
|
||||
except Exception as e:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_file(file: File):
|
||||
file_content = _download_file_content(file)
|
||||
if file.extension:
|
||||
extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension)
|
||||
elif file.mime_type:
|
||||
extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type)
|
||||
else:
|
||||
raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing")
|
||||
return extracted_text
|
||||
|
||||
|
||||
def _extract_text_from_csv(file_content: bytes) -> str:
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
encoding = "utf-8"
|
||||
|
||||
try:
|
||||
csv_file = io.StringIO(file_content.decode(encoding, errors="ignore"))
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
# If decoding fails, try with utf-8 as last resort
|
||||
csv_file = io.StringIO(file_content.decode("utf-8", errors="ignore"))
|
||||
|
||||
csv_reader = csv.reader(csv_file)
|
||||
rows = list(csv_reader)
|
||||
|
||||
if not rows:
|
||||
return ""
|
||||
|
||||
# Combine multi-line text in the header row
|
||||
header_row = [cell.replace("\n", " ").replace("\r", "") for cell in rows[0]]
|
||||
|
||||
# Create Markdown table
|
||||
markdown_table = "| " + " | ".join(header_row) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["-" * len(col) for col in rows[0]]) + " |\n"
|
||||
|
||||
# Process each data row and combine multi-line text in each cell
|
||||
for row in rows[1:]:
|
||||
processed_row = [cell.replace("\n", " ").replace("\r", "") for cell in row]
|
||||
markdown_table += "| " + " | ".join(processed_row) + " |\n"
|
||||
|
||||
return markdown_table
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_excel(file_content: bytes) -> str:
|
||||
"""Extract text from an Excel file using pandas."""
|
||||
|
||||
def _construct_markdown_table(df: pd.DataFrame) -> str:
|
||||
"""Manually construct a Markdown table from a DataFrame."""
|
||||
# Construct the header row
|
||||
header_row = "| " + " | ".join(df.columns) + " |"
|
||||
|
||||
# Construct the separator row
|
||||
separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |"
|
||||
|
||||
# Construct the data rows
|
||||
data_rows = []
|
||||
for _, row in df.iterrows():
|
||||
data_row = "| " + " | ".join(map(str, row)) + " |"
|
||||
data_rows.append(data_row)
|
||||
|
||||
# Combine all rows into a single string
|
||||
markdown_table = "\n".join([header_row, separator_row] + data_rows)
|
||||
return markdown_table
|
||||
|
||||
try:
|
||||
excel_file = pd.ExcelFile(io.BytesIO(file_content))
|
||||
markdown_table = ""
|
||||
for sheet_name in excel_file.sheet_names:
|
||||
try:
|
||||
df = excel_file.parse(sheet_name=sheet_name)
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
# Combine multi-line text in each cell into a single line
|
||||
df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x)
|
||||
|
||||
# Combine multi-line text in column names into a single line
|
||||
df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
|
||||
|
||||
# Manually construct the Markdown table
|
||||
markdown_table += _construct_markdown_table(df) + "\n\n"
|
||||
except Exception:
|
||||
continue
|
||||
return markdown_table
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_ppt(file_content: bytes) -> str:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
try:
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
with open(temp_file.name, "rb") as file:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
else:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_ppt(file=file)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
try:
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
with open(temp_file.name, "rb") as file:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
else:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_pptx(file=file)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_epub(file_content: bytes) -> str:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.epub import partition_epub
|
||||
|
||||
try:
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
with open(temp_file.name, "rb") as file:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
else:
|
||||
pypandoc.download_pandoc()
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_epub(file=file)
|
||||
return "\n".join([str(element) for element in elements])
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_eml(file_content: bytes) -> str:
|
||||
from unstructured.partition.email import partition_email
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_email(file=file)
|
||||
return "\n".join([str(element) for element in elements])
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_msg(file_content: bytes) -> str:
|
||||
from unstructured.partition.msg import partition_msg
|
||||
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
elements = partition_msg(file=file)
|
||||
return "\n".join([str(element) for element in elements])
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_vtt(vtt_bytes: bytes) -> str:
|
||||
text = _extract_text_from_plain_text(vtt_bytes)
|
||||
|
||||
# remove bom
|
||||
text = text.lstrip("\ufeff")
|
||||
|
||||
raw_results = []
|
||||
for caption in webvtt.from_string(text):
|
||||
raw_results.append((caption.voice, caption.text))
|
||||
|
||||
# Merge consecutive utterances by the same speaker
|
||||
merged_results = []
|
||||
if raw_results:
|
||||
current_speaker, current_text = raw_results[0]
|
||||
|
||||
for i in range(1, len(raw_results)):
|
||||
spk, txt = raw_results[i]
|
||||
if spk is None:
|
||||
merged_results.append((None, current_text))
|
||||
continue
|
||||
|
||||
if spk == current_speaker:
|
||||
# If it is the same speaker, merge the utterances (joined by space)
|
||||
current_text += " " + txt
|
||||
else:
|
||||
# If the speaker changes, register the utterance so far and move on
|
||||
merged_results.append((current_speaker, current_text))
|
||||
current_speaker, current_text = spk, txt
|
||||
|
||||
# Add the last element
|
||||
merged_results.append((current_speaker, current_text))
|
||||
else:
|
||||
merged_results = raw_results
|
||||
|
||||
# Return the result in the specified format: Speaker "text" style
|
||||
formatted = [f'{spk or ""} "{txt}"' for spk, txt in merged_results]
|
||||
return "\n".join(formatted)
|
||||
|
||||
|
||||
def _extract_text_from_properties(file_content: bytes) -> str:
|
||||
try:
|
||||
text = _extract_text_from_plain_text(file_content)
|
||||
lines = text.splitlines()
|
||||
result = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
# Preserve comments and empty lines
|
||||
if not line or line.startswith("#") or line.startswith("!"):
|
||||
result.append(line)
|
||||
continue
|
||||
|
||||
if "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
elif ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
else:
|
||||
key, value = line, ""
|
||||
|
||||
result.append(f"{key.strip()}: {value.strip()}")
|
||||
|
||||
return "\n".join(result)
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e
|
||||
0
dify/api/core/workflow/nodes/end/__init__.py
Normal file
0
dify/api/core/workflow/nodes/end/__init__.py
Normal file
74
dify/api/core/workflow/nodes/end/end_node.py
Normal file
74
dify/api/core/workflow/nodes/end/end_node.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
|
||||
|
||||
class EndNode(Node):
|
||||
node_type = NodeType.END
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
_node_data: EndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = EndNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node - collect all outputs at once.
|
||||
|
||||
This method runs after streaming is complete (if streaming was enabled).
|
||||
It collects all output variables and returns them.
|
||||
"""
|
||||
output_variables = self._node_data.outputs
|
||||
|
||||
outputs = {}
|
||||
for variable_selector in output_variables:
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
value = variable.to_object() if variable is not None else None
|
||||
outputs[variable_selector.variable] = value
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=outputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
def get_streaming_template(self) -> Template:
|
||||
"""
|
||||
Get the template for streaming.
|
||||
|
||||
Returns:
|
||||
Template instance for this End node
|
||||
"""
|
||||
outputs_config = [
|
||||
{"variable": output.variable, "value_selector": output.value_selector} for output in self._node_data.outputs
|
||||
]
|
||||
return Template.from_end_outputs(outputs_config)
|
||||
25
dify/api/core/workflow/nodes/end/entities.py
Normal file
25
dify/api/core/workflow/nodes/end/entities.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
"""
|
||||
END Node Data.
|
||||
"""
|
||||
|
||||
outputs: list[VariableSelector]
|
||||
|
||||
|
||||
class EndStreamParam(BaseModel):
|
||||
"""
|
||||
EndStreamParam entity
|
||||
"""
|
||||
|
||||
end_dependencies: dict[str, list[str]] = Field(
|
||||
..., description="end dependencies (end node id -> dependent node ids)"
|
||||
)
|
||||
end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
|
||||
..., description="end stream variable selector mapping (end node id -> stream variable selectors)"
|
||||
)
|
||||
4
dify/api/core/workflow/nodes/http_request/__init__.py
Normal file
4
dify/api/core/workflow/nodes/http_request/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
|
||||
from .node import HttpRequestNode
|
||||
|
||||
__all__ = ["BodyData", "HttpRequestNode", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "HttpRequestNodeData"]
|
||||
192
dify/api/core/workflow/nodes/http_request/entities.py
Normal file
192
dify/api/core/workflow/nodes/http_request/entities.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import mimetypes
|
||||
from collections.abc import Sequence
|
||||
from email.message import Message
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class HttpRequestNodeAuthorizationConfig(BaseModel):
|
||||
type: Literal["basic", "bearer", "custom"]
|
||||
api_key: str
|
||||
header: str = ""
|
||||
|
||||
|
||||
class HttpRequestNodeAuthorization(BaseModel):
|
||||
type: Literal["no-auth", "api-key"]
|
||||
config: HttpRequestNodeAuthorizationConfig | None = None
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo):
|
||||
"""
|
||||
Check config, if type is no-auth, config should be None, otherwise it should be a dict.
|
||||
"""
|
||||
if values.data["type"] == "no-auth":
|
||||
return None
|
||||
else:
|
||||
if not v or not isinstance(v, dict):
|
||||
raise ValueError("config should be a dict")
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class BodyData(BaseModel):
|
||||
key: str = ""
|
||||
type: Literal["file", "text"]
|
||||
value: str = ""
|
||||
file: Sequence[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class HttpRequestNodeBody(BaseModel):
|
||||
type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"]
|
||||
data: Sequence[BodyData] = Field(default_factory=list)
|
||||
|
||||
@field_validator("data", mode="before")
|
||||
@classmethod
|
||||
def check_data(cls, v: Any):
|
||||
"""For compatibility, if body is not set, return empty list."""
|
||||
if not v:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [BodyData(key="", type="text", value=v)]
|
||||
return v
|
||||
|
||||
|
||||
class HttpRequestNodeTimeout(BaseModel):
|
||||
connect: int = dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT
|
||||
read: int = dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT
|
||||
write: int = dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT
|
||||
|
||||
|
||||
class HttpRequestNodeData(BaseNodeData):
|
||||
"""
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
method: Literal[
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"patch",
|
||||
"delete",
|
||||
"head",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
]
|
||||
url: str
|
||||
authorization: HttpRequestNodeAuthorization
|
||||
headers: str
|
||||
params: str
|
||||
body: HttpRequestNodeBody | None = None
|
||||
timeout: HttpRequestNodeTimeout | None = None
|
||||
ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
|
||||
|
||||
class Response:
|
||||
headers: dict[str, str]
|
||||
response: httpx.Response
|
||||
|
||||
def __init__(self, response: httpx.Response):
|
||||
self.response = response
|
||||
self.headers = dict(response.headers)
|
||||
|
||||
@property
|
||||
def is_file(self):
|
||||
"""
|
||||
Determine if the response contains a file by checking:
|
||||
1. Content-Disposition header (RFC 6266)
|
||||
2. Content characteristics
|
||||
3. MIME type analysis
|
||||
"""
|
||||
content_type = self.content_type.split(";")[0].strip().lower()
|
||||
parsed_content_disposition = self.parsed_content_disposition
|
||||
|
||||
# Check if it's explicitly marked as an attachment
|
||||
if parsed_content_disposition:
|
||||
disp_type = parsed_content_disposition.get_content_disposition() # Returns 'attachment', 'inline', or None
|
||||
filename = parsed_content_disposition.get_filename() # Returns filename if present, None otherwise
|
||||
if disp_type == "attachment" or filename is not None:
|
||||
return True
|
||||
|
||||
# For 'text/' types, only 'csv' should be downloaded as file
|
||||
if content_type.startswith("text/") and "csv" not in content_type:
|
||||
return False
|
||||
|
||||
# For application types, try to detect if it's a text-based format
|
||||
if content_type.startswith("application/"):
|
||||
# Common text-based application types
|
||||
if any(
|
||||
text_type in content_type
|
||||
for text_type in ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql")
|
||||
):
|
||||
return False
|
||||
|
||||
# Try to detect if content is text-based by sampling first few bytes
|
||||
try:
|
||||
# Sample first 1024 bytes for text detection
|
||||
content_sample = self.response.content[:1024]
|
||||
content_sample.decode("utf-8")
|
||||
# If we can decode as UTF-8 and find common text patterns, likely not a file
|
||||
text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
|
||||
if any(marker in content_sample for marker in text_markers):
|
||||
return False
|
||||
except UnicodeDecodeError:
|
||||
# If we can't decode as UTF-8, likely a binary file
|
||||
return True
|
||||
|
||||
# For other types, use MIME type analysis
|
||||
main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
|
||||
if main_type:
|
||||
return main_type.split("/")[0] in ("application", "image", "audio", "video")
|
||||
|
||||
# For unknown types, check if it's a media type
|
||||
return any(media_type in content_type for media_type in ("image/", "audio/", "video/"))
|
||||
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
return self.headers.get("content-type", "")
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return self.response.text
|
||||
|
||||
@property
|
||||
def content(self) -> bytes:
|
||||
return self.response.content
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
return self.response.status_code
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return len(self.content)
|
||||
|
||||
@property
|
||||
def readable_size(self) -> str:
|
||||
if self.size < 1024:
|
||||
return f"{self.size} bytes"
|
||||
elif self.size < 1024 * 1024:
|
||||
return f"{(self.size / 1024):.2f} KB"
|
||||
else:
|
||||
return f"{(self.size / 1024 / 1024):.2f} MB"
|
||||
|
||||
@property
|
||||
def parsed_content_disposition(self) -> Message | None:
|
||||
content_disposition = self.headers.get("content-disposition", "")
|
||||
if content_disposition:
|
||||
msg = Message()
|
||||
msg["content-disposition"] = content_disposition
|
||||
return msg
|
||||
return None
|
||||
26
dify/api/core/workflow/nodes/http_request/exc.py
Normal file
26
dify/api/core/workflow/nodes/http_request/exc.py
Normal file
@@ -0,0 +1,26 @@
|
||||
class HttpRequestNodeError(ValueError):
|
||||
"""Custom error for HTTP request node."""
|
||||
|
||||
|
||||
class AuthorizationConfigError(HttpRequestNodeError):
|
||||
"""Raised when authorization config is missing or invalid."""
|
||||
|
||||
|
||||
class FileFetchError(HttpRequestNodeError):
|
||||
"""Raised when a file cannot be fetched."""
|
||||
|
||||
|
||||
class InvalidHttpMethodError(HttpRequestNodeError):
|
||||
"""Raised when an invalid HTTP method is used."""
|
||||
|
||||
|
||||
class ResponseSizeError(HttpRequestNodeError):
|
||||
"""Raised when the response size exceeds the allowed threshold."""
|
||||
|
||||
|
||||
class RequestBodyError(HttpRequestNodeError):
|
||||
"""Raised when the request body is invalid."""
|
||||
|
||||
|
||||
class InvalidURLError(HttpRequestNodeError):
|
||||
"""Raised when the URL is invalid."""
|
||||
463
dify/api/core/workflow/nodes/http_request/executor.py
Normal file
463
dify/api/core/workflow/nodes/http_request/executor.py
Normal file
@@ -0,0 +1,463 @@
|
||||
import base64
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import httpx
|
||||
from json_repair import repair_json
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_manager
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from .entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeData,
|
||||
HttpRequestNodeTimeout,
|
||||
Response,
|
||||
)
|
||||
from .exc import (
|
||||
AuthorizationConfigError,
|
||||
FileFetchError,
|
||||
HttpRequestNodeError,
|
||||
InvalidHttpMethodError,
|
||||
InvalidURLError,
|
||||
RequestBodyError,
|
||||
ResponseSizeError,
|
||||
)
|
||||
|
||||
BODY_TYPE_TO_CONTENT_TYPE = {
|
||||
"json": "application/json",
|
||||
"x-www-form-urlencoded": "application/x-www-form-urlencoded",
|
||||
"form-data": "multipart/form-data",
|
||||
"raw-text": "text/plain",
|
||||
}
|
||||
|
||||
|
||||
class Executor:
|
||||
method: Literal[
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
]
|
||||
url: str
|
||||
params: list[tuple[str, str]] | None
|
||||
content: str | bytes | None
|
||||
data: Mapping[str, Any] | None
|
||||
files: list[tuple[str, tuple[str | None, bytes, str]]] | None
|
||||
json: Any
|
||||
headers: dict[str, str]
|
||||
auth: HttpRequestNodeAuthorization
|
||||
timeout: HttpRequestNodeTimeout
|
||||
max_retries: int
|
||||
|
||||
boundary: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
node_data: HttpRequestNodeData,
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: VariablePool,
|
||||
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
):
|
||||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
if node_data.authorization.config is None:
|
||||
raise AuthorizationConfigError("authorization config is required")
|
||||
node_data.authorization.config.api_key = variable_pool.convert_template(
|
||||
node_data.authorization.config.api_key
|
||||
).text
|
||||
|
||||
self.url = node_data.url
|
||||
self.method = node_data.method
|
||||
self.auth = node_data.authorization
|
||||
self.timeout = timeout
|
||||
self.ssl_verify = node_data.ssl_verify
|
||||
self.params = None
|
||||
self.headers = {}
|
||||
self.content = None
|
||||
self.files = None
|
||||
self.data = None
|
||||
self.json = None
|
||||
self.max_retries = max_retries
|
||||
|
||||
# init template
|
||||
self.variable_pool = variable_pool
|
||||
self.node_data = node_data
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self):
|
||||
self._init_url()
|
||||
self._init_params()
|
||||
self._init_headers()
|
||||
self._init_body()
|
||||
|
||||
def _init_url(self):
|
||||
self.url = self.variable_pool.convert_template(self.node_data.url).text
|
||||
|
||||
# check if url is a valid URL
|
||||
if not self.url:
|
||||
raise InvalidURLError("url is required")
|
||||
if not self.url.startswith(("http://", "https://")):
|
||||
raise InvalidURLError("url should start with http:// or https://")
|
||||
|
||||
def _init_params(self):
|
||||
"""
|
||||
Almost same as _init_headers(), difference:
|
||||
1. response a list tuple to support same key, like 'aa=1&aa=2'
|
||||
2. param value may have '\n', we need to splitlines then extract the variable value.
|
||||
"""
|
||||
result = []
|
||||
for line in self.node_data.params.splitlines():
|
||||
if not (line := line.strip()):
|
||||
continue
|
||||
|
||||
key, *value = line.split(":", 1)
|
||||
if not (key := key.strip()):
|
||||
continue
|
||||
|
||||
value_str = value[0].strip() if value else ""
|
||||
result.append(
|
||||
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
|
||||
)
|
||||
|
||||
if result:
|
||||
self.params = result
|
||||
|
||||
def _init_headers(self):
|
||||
"""
|
||||
Convert the header string of frontend to a dictionary.
|
||||
|
||||
Each line in the header string represents a key-value pair.
|
||||
Keys and values are separated by ':'.
|
||||
Empty values are allowed.
|
||||
|
||||
Examples:
|
||||
'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'}
|
||||
'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'}
|
||||
'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'}
|
||||
|
||||
"""
|
||||
headers = self.variable_pool.convert_template(self.node_data.headers).text
|
||||
self.headers = {
|
||||
key.strip(): (value[0].strip() if value else "")
|
||||
for line in headers.splitlines()
|
||||
if line.strip()
|
||||
for key, *value in [line.split(":", 1)]
|
||||
}
|
||||
|
||||
def _init_body(self):
|
||||
body = self.node_data.body
|
||||
if body is not None:
|
||||
data = body.data
|
||||
match body.type:
|
||||
case "none":
|
||||
self.content = ""
|
||||
case "raw-text":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("raw-text body type should have exactly one item")
|
||||
self.content = self.variable_pool.convert_template(data[0].value).text
|
||||
case "json":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("json body type should have exactly one item")
|
||||
json_string = self.variable_pool.convert_template(data[0].value).text
|
||||
try:
|
||||
repaired = repair_json(json_string)
|
||||
json_object = json.loads(repaired, strict=False)
|
||||
except json.JSONDecodeError as e:
|
||||
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
|
||||
self.json = json_object
|
||||
# self.json = self._parse_object_contains_variables(json_object)
|
||||
case "binary":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("binary body type should have exactly one item")
|
||||
file_selector = data[0].file
|
||||
file_variable = self.variable_pool.get_file(file_selector)
|
||||
if file_variable is None:
|
||||
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
|
||||
file = file_variable.value
|
||||
self.content = file_manager.download(file)
|
||||
case "x-www-form-urlencoded":
|
||||
form_data = {
|
||||
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
|
||||
item.value
|
||||
).text
|
||||
for item in data
|
||||
}
|
||||
self.data = form_data
|
||||
case "form-data":
|
||||
form_data = {
|
||||
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
|
||||
item.value
|
||||
).text
|
||||
for item in filter(lambda item: item.type == "text", data)
|
||||
}
|
||||
file_selectors = {
|
||||
self.variable_pool.convert_template(item.key).text: item.file
|
||||
for item in filter(lambda item: item.type == "file", data)
|
||||
}
|
||||
|
||||
# get files from file_selectors, add support for array file variables
|
||||
files_list = []
|
||||
for key, selector in file_selectors.items():
|
||||
segment = self.variable_pool.get(selector)
|
||||
if isinstance(segment, FileSegment):
|
||||
files_list.append((key, [segment.value]))
|
||||
elif isinstance(segment, ArrayFileSegment):
|
||||
files_list.append((key, list(segment.value)))
|
||||
|
||||
# get files from file_manager
|
||||
files: dict[str, list[tuple[str | None, bytes, str]]] = {}
|
||||
for key, files_in_segment in files_list:
|
||||
for file in files_in_segment:
|
||||
if file.related_id is not None or (
|
||||
file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None
|
||||
):
|
||||
file_tuple = (
|
||||
file.filename,
|
||||
file_manager.download(file),
|
||||
file.mime_type or "application/octet-stream",
|
||||
)
|
||||
if key not in files:
|
||||
files[key] = []
|
||||
files[key].append(file_tuple)
|
||||
|
||||
# convert files to list for httpx request
|
||||
# If there are no actual files, we still need to force httpx to use `multipart/form-data`.
|
||||
# This is achieved by inserting a harmless placeholder file that will be ignored by the server.
|
||||
if not files:
|
||||
self.files = [("__multipart_placeholder__", ("", b"", "application/octet-stream"))]
|
||||
if files:
|
||||
self.files = []
|
||||
for key, file_tuples in files.items():
|
||||
for file_tuple in file_tuples:
|
||||
self.files.append((key, file_tuple))
|
||||
|
||||
self.data = form_data
|
||||
|
||||
def _assembling_headers(self) -> dict[str, Any]:
|
||||
authorization = deepcopy(self.auth)
|
||||
headers = deepcopy(self.headers) or {}
|
||||
if self.auth.type == "api-key":
|
||||
if self.auth.config is None:
|
||||
raise AuthorizationConfigError("self.authorization config is required")
|
||||
if authorization.config is None:
|
||||
raise AuthorizationConfigError("authorization config is required")
|
||||
|
||||
if not authorization.config.header:
|
||||
authorization.config.header = "Authorization"
|
||||
|
||||
if self.auth.config.type == "bearer" and authorization.config.api_key:
|
||||
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
|
||||
elif self.auth.config.type == "basic" and authorization.config.api_key:
|
||||
credentials = authorization.config.api_key
|
||||
if ":" in credentials:
|
||||
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
|
||||
else:
|
||||
encoded_credentials = credentials
|
||||
headers[authorization.config.header] = f"Basic {encoded_credentials}"
|
||||
elif self.auth.config.type == "custom":
|
||||
if authorization.config.header and authorization.config.api_key:
|
||||
headers[authorization.config.header] = authorization.config.api_key
|
||||
|
||||
# Handle Content-Type for multipart/form-data requests
|
||||
# Fix for issue #23829: Missing boundary when using multipart/form-data
|
||||
body = self.node_data.body
|
||||
if body and body.type == "form-data":
|
||||
# For multipart/form-data with files (including placeholder files),
|
||||
# remove any manually set Content-Type header to let httpx handle
|
||||
# For multipart/form-data, if any files are present (including placeholder files),
|
||||
# we must remove any manually set Content-Type header. This is because httpx needs to
|
||||
# automatically set the Content-Type and boundary for multipart encoding whenever files
|
||||
# are included, even if they are placeholders, to avoid boundary issues and ensure correct
|
||||
# file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the
|
||||
# boundary, resulting in invalid requests.
|
||||
if self.files:
|
||||
# Remove Content-Type if it was manually set to avoid boundary issues
|
||||
headers = {k: v for k, v in headers.items() if k.lower() != "content-type"}
|
||||
else:
|
||||
# No files at all, set Content-Type manually
|
||||
if "content-type" not in (k.lower() for k in headers):
|
||||
headers["Content-Type"] = "multipart/form-data"
|
||||
elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE:
|
||||
# Set Content-Type for other body types
|
||||
if "content-type" not in (k.lower() for k in headers):
|
||||
headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
|
||||
|
||||
return headers
|
||||
|
||||
def _validate_and_parse_response(self, response: httpx.Response) -> Response:
|
||||
executor_response = Response(response)
|
||||
|
||||
threshold_size = (
|
||||
dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
|
||||
if executor_response.is_file
|
||||
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
|
||||
)
|
||||
if executor_response.size > threshold_size:
|
||||
raise ResponseSizeError(
|
||||
f"{'File' if executor_response.is_file else 'Text'} size is too large,"
|
||||
f" max size is {threshold_size / 1024 / 1024:.2f} MB,"
|
||||
f" but current size is {executor_response.readable_size}."
|
||||
)
|
||||
|
||||
return executor_response
|
||||
|
||||
def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
_METHOD_MAP = {
|
||||
"get": ssrf_proxy.get,
|
||||
"head": ssrf_proxy.head,
|
||||
"post": ssrf_proxy.post,
|
||||
"put": ssrf_proxy.put,
|
||||
"delete": ssrf_proxy.delete,
|
||||
"patch": ssrf_proxy.patch,
|
||||
}
|
||||
method_lc = self.method.lower()
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
"url": self.url,
|
||||
"data": self.data,
|
||||
"files": self.files,
|
||||
"json": self.json,
|
||||
"content": self.content,
|
||||
"headers": headers,
|
||||
"params": self.params,
|
||||
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||
"ssl_verify": self.ssl_verify,
|
||||
"follow_redirects": True,
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response
|
||||
|
||||
def invoke(self) -> Response:
|
||||
# assemble headers
|
||||
headers = self._assembling_headers()
|
||||
# do http request
|
||||
response = self._do_http_request(headers)
|
||||
# validate response
|
||||
return self._validate_and_parse_response(response)
|
||||
|
||||
def to_log(self):
|
||||
url_parts = urlparse(self.url)
|
||||
path = url_parts.path or "/"
|
||||
|
||||
# Add query parameters
|
||||
if self.params:
|
||||
query_string = urlencode(self.params)
|
||||
path += f"?{query_string}"
|
||||
elif url_parts.query:
|
||||
path += f"?{url_parts.query}"
|
||||
|
||||
raw = f"{self.method.upper()} {path} HTTP/1.1\r\n"
|
||||
raw += f"Host: {url_parts.netloc}\r\n"
|
||||
|
||||
headers = self._assembling_headers()
|
||||
body = self.node_data.body
|
||||
boundary = f"----WebKitFormBoundary{_generate_random_string(16)}"
|
||||
if body:
|
||||
if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE:
|
||||
headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
|
||||
if body.type == "form-data":
|
||||
headers["Content-Type"] = f"multipart/form-data; boundary={boundary}"
|
||||
for k, v in headers.items():
|
||||
if self.auth.type == "api-key":
|
||||
authorization_header = "Authorization"
|
||||
if self.auth.config and self.auth.config.header:
|
||||
authorization_header = self.auth.config.header
|
||||
if k.lower() == authorization_header.lower():
|
||||
raw += f"{k}: {'*' * len(v)}\r\n"
|
||||
continue
|
||||
raw += f"{k}: {v}\r\n"
|
||||
|
||||
body_string = ""
|
||||
# Only log actual files if present.
|
||||
# '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file.
|
||||
# This prevents logging meaningless placeholder entries.
|
||||
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
|
||||
for file_entry in self.files:
|
||||
# file_entry should be (key, (filename, content, mime_type)), but handle edge cases
|
||||
if len(file_entry) != 2 or len(file_entry[1]) < 2:
|
||||
continue # skip malformed entries
|
||||
key = file_entry[0]
|
||||
content = file_entry[1][1]
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
# decode content safely
|
||||
try:
|
||||
body_string += content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
body_string += content.decode("utf-8", errors="replace")
|
||||
body_string += "\r\n"
|
||||
body_string += f"--{boundary}--\r\n"
|
||||
elif self.node_data.body:
|
||||
if self.content:
|
||||
if isinstance(self.content, bytes):
|
||||
body_string = self.content.decode("utf-8", errors="replace")
|
||||
else:
|
||||
body_string = self.content
|
||||
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
|
||||
body_string = urlencode(self.data)
|
||||
elif self.data and self.node_data.body.type == "form-data":
|
||||
for key, value in self.data.items():
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
body_string += f"{value}\r\n"
|
||||
body_string += f"--{boundary}--\r\n"
|
||||
elif self.json:
|
||||
body_string = json.dumps(self.json)
|
||||
elif self.node_data.body.type == "raw-text":
|
||||
if len(self.node_data.body.data) != 1:
|
||||
raise RequestBodyError("raw-text body type should have exactly one item")
|
||||
body_string = self.node_data.body.data[0].value
|
||||
if body_string:
|
||||
raw += f"Content-Length: {len(body_string)}\r\n"
|
||||
raw += "\r\n" # Empty line between headers and body
|
||||
raw += body_string
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def _generate_random_string(n: int) -> str:
|
||||
"""
|
||||
Generate a random string of lowercase ASCII letters.
|
||||
|
||||
Args:
|
||||
n (int): The length of the random string to generate.
|
||||
|
||||
Returns:
|
||||
str: A random string of lowercase ASCII letters with length n.
|
||||
|
||||
Example:
|
||||
>>> _generate_random_string(5)
|
||||
'abcde'
|
||||
"""
|
||||
return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n))
|
||||
249
dify/api/core/workflow/nodes/http_request/node.py
Normal file
249
dify/api/core/workflow/nodes/http_request/node.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from factories import file_factory
|
||||
|
||||
from .entities import (
|
||||
HttpRequestNodeData,
|
||||
HttpRequestNodeTimeout,
|
||||
Response,
|
||||
)
|
||||
from .exc import HttpRequestNodeError, RequestBodyError
|
||||
|
||||
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
|
||||
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
||||
write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HttpRequestNode(Node):
|
||||
node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
_node_data: HttpRequestNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = HttpRequestNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "http-request",
|
||||
"config": {
|
||||
"method": "get",
|
||||
"authorization": {
|
||||
"type": "no-auth",
|
||||
},
|
||||
"body": {"type": "none"},
|
||||
"timeout": {
|
||||
**HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
|
||||
"max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
"max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
||||
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
},
|
||||
"ssl_verify": dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
},
|
||||
"retry_config": {
|
||||
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
"retry_interval": 0.5 * (2**2),
|
||||
"retry_enabled": True,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
process_data = {}
|
||||
try:
|
||||
http_executor = Executor(
|
||||
node_data=self._node_data,
|
||||
timeout=self._get_request_timeout(self._node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
max_retries=0,
|
||||
)
|
||||
process_data["request"] = http_executor.to_log()
|
||||
|
||||
response = http_executor.invoke()
|
||||
files = self.extract_files(url=http_executor.url, response=response)
|
||||
if not response.response.is_success and (self.error_strategy or self.retry):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
outputs={
|
||||
"status_code": response.status_code,
|
||||
"body": response.text if not files.value else "",
|
||||
"headers": response.headers,
|
||||
"files": files,
|
||||
},
|
||||
process_data={
|
||||
"request": http_executor.to_log(),
|
||||
},
|
||||
error=f"Request failed with status code {response.status_code}",
|
||||
error_type="HTTPResponseCodeError",
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"status_code": response.status_code,
|
||||
"body": response.text if not files.value else "",
|
||||
"headers": response.headers,
|
||||
"files": files,
|
||||
},
|
||||
process_data={
|
||||
"request": http_executor.to_log(),
|
||||
},
|
||||
)
|
||||
except HttpRequestNodeError as e:
|
||||
logger.warning("http request node %s failed to run: %s", self._node_id, e)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
process_data=process_data,
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout:
|
||||
timeout = node_data.timeout
|
||||
if timeout is None:
|
||||
return HTTP_REQUEST_DEFAULT_TIMEOUT
|
||||
|
||||
timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect
|
||||
timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read
|
||||
timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write
|
||||
return timeout
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = HttpRequestNodeData.model_validate(node_data)
|
||||
|
||||
selectors: list[VariableSelector] = []
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
|
||||
if typed_node_data.body:
|
||||
body_type = typed_node_data.body.type
|
||||
data = typed_node_data.body.data
|
||||
match body_type:
|
||||
case "none":
|
||||
pass
|
||||
case "binary":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("invalid body data, should have only one item")
|
||||
selector = data[0].file
|
||||
selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector))
|
||||
case "json" | "raw-text":
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("invalid body data, should have only one item")
|
||||
selectors += variable_template_parser.extract_selectors_from_template(data[0].key)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(data[0].value)
|
||||
case "x-www-form-urlencoded":
|
||||
for item in data:
|
||||
selectors += variable_template_parser.extract_selectors_from_template(item.key)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(item.value)
|
||||
case "form-data":
|
||||
for item in data:
|
||||
selectors += variable_template_parser.extract_selectors_from_template(item.key)
|
||||
if item.type == "text":
|
||||
selectors += variable_template_parser.extract_selectors_from_template(item.value)
|
||||
elif item.type == "file":
|
||||
selectors.append(
|
||||
VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file)
|
||||
)
|
||||
|
||||
mapping = {}
|
||||
for selector_iter in selectors:
|
||||
mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector
|
||||
|
||||
return mapping
|
||||
|
||||
def extract_files(self, url: str, response: Response) -> ArrayFileSegment:
|
||||
"""
|
||||
Extract files from response by checking both Content-Type header and URL
|
||||
"""
|
||||
files: list[File] = []
|
||||
is_file = response.is_file
|
||||
content_type = response.content_type
|
||||
content = response.content
|
||||
parsed_content_disposition = response.parsed_content_disposition
|
||||
content_disposition_type = None
|
||||
|
||||
if not is_file:
|
||||
return ArrayFileSegment(value=[])
|
||||
|
||||
if parsed_content_disposition:
|
||||
content_disposition_filename = parsed_content_disposition.get_filename()
|
||||
if content_disposition_filename:
|
||||
# If filename is available from content-disposition, use it to guess the content type
|
||||
content_disposition_type = mimetypes.guess_type(content_disposition_filename)[0]
|
||||
|
||||
# Guess file extension from URL or Content-Type header
|
||||
filename = url.split("?")[0].split("/")[-1] or ""
|
||||
mime_type = (
|
||||
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
)
|
||||
tool_file_manager = ToolFileManager()
|
||||
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=content,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file.id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
return ArrayFileSegment(value=files)
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
3
dify/api/core/workflow/nodes/human_input/__init__.py
Normal file
3
dify/api/core/workflow/nodes/human_input/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .human_input_node import HumanInputNode
|
||||
|
||||
__all__ = ["HumanInputNode"]
|
||||
10
dify/api/core/workflow/nodes/human_input/entities.py
Normal file
10
dify/api/core/workflow/nodes/human_input/entities.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pydantic import Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class HumanInputNodeData(BaseNodeData):
|
||||
"""Configuration schema for the HumanInput node."""
|
||||
|
||||
required_variables: list[str] = Field(default_factory=list)
|
||||
pause_reason: str | None = Field(default=None)
|
||||
133
dify/api/core/workflow/nodes/human_input/human_input_node.py
Normal file
133
dify/api/core/workflow/nodes/human_input/human_input_node.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
|
||||
|
||||
class HumanInputNode(Node):
|
||||
node_type = NodeType.HUMAN_INPUT
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
_BRANCH_SELECTION_KEYS: tuple[str, ...] = (
|
||||
"edge_source_handle",
|
||||
"edgeSourceHandle",
|
||||
"source_handle",
|
||||
"selected_branch",
|
||||
"selectedBranch",
|
||||
"branch",
|
||||
"branch_id",
|
||||
"branchId",
|
||||
"handle",
|
||||
)
|
||||
|
||||
_node_data: HumanInputNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = HumanInputNodeData(**data)
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def _run(self): # type: ignore[override]
|
||||
if self._is_completion_ready():
|
||||
branch_handle = self._resolve_branch_selection()
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
edge_source_handle=branch_handle or "source",
|
||||
)
|
||||
|
||||
return self._pause_generator()
|
||||
|
||||
def _pause_generator(self):
|
||||
yield PauseRequestedEvent(reason=HumanInputRequired())
|
||||
|
||||
def _is_completion_ready(self) -> bool:
|
||||
"""Determine whether all required inputs are satisfied."""
|
||||
|
||||
if not self._node_data.required_variables:
|
||||
return False
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
for selector_str in self._node_data.required_variables:
|
||||
parts = selector_str.split(".")
|
||||
if len(parts) != 2:
|
||||
return False
|
||||
segment = variable_pool.get(parts)
|
||||
if segment is None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _resolve_branch_selection(self) -> str | None:
|
||||
"""Determine the branch handle selected by human input if available."""
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
for key in self._BRANCH_SELECTION_KEYS:
|
||||
handle = self._extract_branch_handle(variable_pool.get((self.id, key)))
|
||||
if handle:
|
||||
return handle
|
||||
|
||||
default_values = self._node_data.default_value_dict
|
||||
for key in self._BRANCH_SELECTION_KEYS:
|
||||
handle = self._normalize_branch_value(default_values.get(key))
|
||||
if handle:
|
||||
return handle
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_branch_handle(segment: Any) -> str | None:
|
||||
if segment is None:
|
||||
return None
|
||||
|
||||
candidate = getattr(segment, "to_object", None)
|
||||
raw_value = candidate() if callable(candidate) else getattr(segment, "value", None)
|
||||
if raw_value is None:
|
||||
return None
|
||||
|
||||
return HumanInputNode._normalize_branch_value(raw_value)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_branch_value(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
return stripped or None
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
|
||||
candidate = value.get(key)
|
||||
if isinstance(candidate, str) and candidate:
|
||||
return candidate
|
||||
|
||||
return None
|
||||
3
dify/api/core/workflow/nodes/if_else/__init__.py
Normal file
3
dify/api/core/workflow/nodes/if_else/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .if_else_node import IfElseNode
|
||||
|
||||
__all__ = ["IfElseNode"]
|
||||
26
dify/api/core/workflow/nodes/if_else/entities.py
Normal file
26
dify/api/core/workflow/nodes/if_else/entities.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class IfElseNodeData(BaseNodeData):
|
||||
"""
|
||||
If Else Node Data.
|
||||
"""
|
||||
|
||||
class Case(BaseModel):
|
||||
"""
|
||||
Case entity representing a single logical condition group
|
||||
"""
|
||||
|
||||
case_id: str
|
||||
logical_operator: Literal["and", "or"]
|
||||
conditions: list[Condition]
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
cases: list[Case] | None = None
|
||||
150
dify/api/core/workflow/nodes/if_else/if_else_node.py
Normal file
150
dify/api/core/workflow/nodes/if_else/if_else_node.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class IfElseNode(Node):
|
||||
node_type = NodeType.IF_ELSE
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
_node_data: IfElseNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IfElseNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
node_inputs: dict[str, Sequence[Mapping[str, Any]]] = {"conditions": []}
|
||||
|
||||
process_data: dict[str, list] = {"condition_results": []}
|
||||
|
||||
input_conditions: Sequence[Mapping[str, Any]] = []
|
||||
final_result = False
|
||||
selected_case_id = "false"
|
||||
condition_processor = ConditionProcessor()
|
||||
try:
|
||||
# Check if the new cases structure is used
|
||||
if self._node_data.cases:
|
||||
for case in self._node_data.cases:
|
||||
input_conditions, group_result, final_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=case.conditions,
|
||||
operator=case.logical_operator,
|
||||
)
|
||||
|
||||
process_data["condition_results"].append(
|
||||
{
|
||||
"group": case.model_dump(),
|
||||
"results": group_result,
|
||||
"final_result": final_result,
|
||||
}
|
||||
)
|
||||
|
||||
# Break if a case passes (logical short-circuit)
|
||||
if final_result:
|
||||
selected_case_id = case.case_id # Capture the ID of the passing case
|
||||
break
|
||||
|
||||
else:
|
||||
# TODO: Update database then remove this
|
||||
# Fallback to old structure if cases are not defined
|
||||
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated]
|
||||
condition_processor=condition_processor,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=self._node_data.conditions or [],
|
||||
operator=self._node_data.logical_operator or "and",
|
||||
)
|
||||
|
||||
selected_case_id = "true" if final_result else "false"
|
||||
|
||||
process_data["condition_results"].append(
|
||||
{"group": "default", "results": group_result, "final_result": final_result}
|
||||
)
|
||||
|
||||
node_inputs["conditions"] = input_conditions
|
||||
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_data, error=str(e)
|
||||
)
|
||||
|
||||
outputs = {"result": final_result, "selected_case_id": selected_case_id}
|
||||
|
||||
data = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
edge_source_handle=selected_case_id or "false", # Use case ID or 'default'
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = IfElseNodeData.model_validate(node_data)
|
||||
|
||||
var_mapping: dict[str, list[str]] = {}
|
||||
for case in typed_node_data.cases or []:
|
||||
for condition in case.conditions:
|
||||
key = f"{node_id}.#{'.'.join(condition.variable_selector)}#"
|
||||
var_mapping[key] = condition.variable_selector
|
||||
|
||||
return var_mapping
|
||||
|
||||
|
||||
@deprecated("This function is deprecated. You should use the new cases structure.")
|
||||
def _should_not_use_old_function(
|
||||
*,
|
||||
condition_processor: ConditionProcessor,
|
||||
variable_pool: VariablePool,
|
||||
conditions: list[Condition],
|
||||
operator: Literal["and", "or"],
|
||||
):
|
||||
return condition_processor.process_conditions(
|
||||
variable_pool=variable_pool,
|
||||
conditions=conditions,
|
||||
operator=operator,
|
||||
)
|
||||
5
dify/api/core/workflow/nodes/iteration/__init__.py
Normal file
5
dify/api/core/workflow/nodes/iteration/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .entities import IterationNodeData
|
||||
from .iteration_node import IterationNode
|
||||
from .iteration_start_node import IterationStartNode
|
||||
|
||||
__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"]
|
||||
64
dify/api/core/workflow/nodes/iteration/entities.py
Normal file
64
dify/api/core/workflow/nodes/iteration/entities.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
|
||||
|
||||
class ErrorHandleMode(StrEnum):
|
||||
TERMINATED = "terminated"
|
||||
CONTINUE_ON_ERROR = "continue-on-error"
|
||||
REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"
|
||||
|
||||
|
||||
class IterationNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Iteration Node Data.
|
||||
"""
|
||||
|
||||
parent_loop_id: str | None = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
is_parallel: bool = False # open the parallel mode or not
|
||||
parallel_nums: int = 10 # the numbers of parallel
|
||||
error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error
|
||||
flatten_output: bool = True # whether to flatten the output array if all elements are lists
|
||||
|
||||
|
||||
class IterationStartNodeData(BaseNodeData):
|
||||
"""
|
||||
Iteration Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
"""
|
||||
Iteration State.
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Any = None
|
||||
|
||||
class MetaData(BaseIterationState.MetaData):
|
||||
"""
|
||||
Data.
|
||||
"""
|
||||
|
||||
iterator_length: int
|
||||
|
||||
def get_last_output(self) -> Any:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
if self.outputs:
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Any:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
return self.current_output
|
||||
22
dify/api/core/workflow/nodes/iteration/exc.py
Normal file
22
dify/api/core/workflow/nodes/iteration/exc.py
Normal file
@@ -0,0 +1,22 @@
|
||||
class IterationNodeError(ValueError):
|
||||
"""Base class for iteration node errors."""
|
||||
|
||||
|
||||
class IteratorVariableNotFoundError(IterationNodeError):
|
||||
"""Raised when the iterator variable is not found."""
|
||||
|
||||
|
||||
class InvalidIteratorValueError(IterationNodeError):
|
||||
"""Raised when the iterator value is invalid."""
|
||||
|
||||
|
||||
class StartNodeIdNotFoundError(IterationNodeError):
|
||||
"""Raised when the start node ID is not found."""
|
||||
|
||||
|
||||
class IterationGraphNotFoundError(IterationNodeError):
|
||||
"""Raised when the iteration graph is not found."""
|
||||
|
||||
|
||||
class IterationIndexNotFoundError(IterationNodeError):
|
||||
"""Raised when the iteration index is not found."""
|
||||
667
dify/api/core/workflow/nodes/iteration/iteration_node.py
Normal file
667
dify/api/core/workflow/nodes/iteration/iteration_node.py
Normal file
@@ -0,0 +1,667 @@
|
||||
import contextvars
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
IterationFailedEvent,
|
||||
IterationNextEvent,
|
||||
IterationStartedEvent,
|
||||
IterationSucceededEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .exc import (
|
||||
InvalidIteratorValueError,
|
||||
IterationGraphNotFoundError,
|
||||
IterationIndexNotFoundError,
|
||||
IterationNodeError,
|
||||
IteratorVariableNotFoundError,
|
||||
StartNodeIdNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
||||
|
||||
|
||||
class IterationNode(LLMUsageTrackingMixin, Node):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
|
||||
node_type = NodeType.ITERATION
|
||||
execution_type = NodeExecutionType.CONTAINER
|
||||
_node_data: IterationNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "iteration",
|
||||
"config": {
|
||||
"is_parallel": False,
|
||||
"parallel_nums": 10,
|
||||
"error_handle_mode": ErrorHandleMode.TERMINATED,
|
||||
"flatten_output": True,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore
|
||||
variable = self._get_iterator_variable()
|
||||
|
||||
if self._is_empty_iteration(variable):
|
||||
yield from self._handle_empty_iteration(variable)
|
||||
return
|
||||
|
||||
iterator_list_value = self._validate_and_get_iterator_list(variable)
|
||||
inputs = {"iterator_selector": iterator_list_value}
|
||||
|
||||
self._validate_start_node()
|
||||
|
||||
started_at = naive_utc_now()
|
||||
iter_run_map: dict[str, float] = {}
|
||||
outputs: list[object] = []
|
||||
usage_accumulator = [LLMUsage.empty_usage()]
|
||||
|
||||
yield IterationStartedEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
metadata={"iteration_length": len(iterator_list_value)},
|
||||
)
|
||||
|
||||
try:
|
||||
yield from self._execute_iterations(
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
usage_accumulator=usage_accumulator,
|
||||
)
|
||||
|
||||
self._accumulate_usage(usage_accumulator[0])
|
||||
yield from self._handle_iteration_success(
|
||||
started_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
iterator_list_value=iterator_list_value,
|
||||
iter_run_map=iter_run_map,
|
||||
usage=usage_accumulator[0],
|
||||
)
|
||||
except IterationNodeError as e:
|
||||
self._accumulate_usage(usage_accumulator[0])
|
||||
yield from self._handle_iteration_failure(
|
||||
started_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
iterator_list_value=iterator_list_value,
|
||||
iter_run_map=iter_run_map,
|
||||
usage=usage_accumulator[0],
|
||||
error=e,
|
||||
)
|
||||
|
||||
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
|
||||
|
||||
if not variable:
|
||||
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
||||
|
||||
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||
|
||||
return variable
|
||||
|
||||
def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
|
||||
return isinstance(variable, NoneSegment) or len(variable.value) == 0
|
||||
|
||||
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
|
||||
# Try our best to preserve the type information.
|
||||
if isinstance(variable, ArraySegment):
|
||||
output = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
output = ArrayAnySegment(value=[])
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
||||
# from graph definition?
|
||||
outputs={"output": output},
|
||||
)
|
||||
)
|
||||
|
||||
def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
|
||||
iterator_list_value = variable.to_object()
|
||||
|
||||
if not isinstance(iterator_list_value, list):
|
||||
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
||||
|
||||
return cast(list[object], iterator_list_value)
|
||||
|
||||
def _validate_start_node(self) -> None:
|
||||
if not self._node_data.start_node_id:
|
||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
||||
|
||||
def _execute_iterations(
|
||||
self,
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[object],
|
||||
iter_run_map: dict[str, float],
|
||||
usage_accumulator: list[LLMUsage],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
if self._node_data.is_parallel:
|
||||
# Parallel mode execution
|
||||
yield from self._execute_parallel_iterations(
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
iter_run_map=iter_run_map,
|
||||
usage_accumulator=usage_accumulator,
|
||||
)
|
||||
else:
|
||||
# Sequential mode execution
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
yield IterationNextEvent(index=index)
|
||||
|
||||
graph_engine = self._create_graph_engine(index, item)
|
||||
|
||||
# Run the iteration
|
||||
yield from self._run_single_iter(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||
outputs=outputs,
|
||||
graph_engine=graph_engine,
|
||||
)
|
||||
|
||||
# Sync conversation variables after each iteration completes
|
||||
self._sync_conversation_variables_from_snapshot(
|
||||
self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
)
|
||||
|
||||
# Accumulate usage from this iteration
|
||||
usage_accumulator[0] = self._merge_usage(
|
||||
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
|
||||
)
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
def _execute_parallel_iterations(
|
||||
self,
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[object],
|
||||
iter_run_map: dict[str, float],
|
||||
usage_accumulator: list[LLMUsage],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
# Initialize outputs list with None values to maintain order
|
||||
outputs.extend([None] * len(iterator_list_value))
|
||||
|
||||
# Determine the number of parallel workers
|
||||
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
future_to_index: dict[
|
||||
Future[
|
||||
tuple[
|
||||
datetime,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
dict[str, VariableUnion],
|
||||
LLMUsage,
|
||||
]
|
||||
],
|
||||
int,
|
||||
] = {}
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
yield IterationNextEvent(index=index)
|
||||
future = executor.submit(
|
||||
self._execute_single_iteration_parallel,
|
||||
index=index,
|
||||
item=item,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
context_vars=contextvars.copy_context(),
|
||||
)
|
||||
future_to_index[future] = index
|
||||
|
||||
# Process completed iterations as they finish
|
||||
for future in as_completed(future_to_index):
|
||||
index = future_to_index[future]
|
||||
try:
|
||||
result = future.result()
|
||||
(
|
||||
iter_start_at,
|
||||
events,
|
||||
output_value,
|
||||
conversation_snapshot,
|
||||
iteration_usage,
|
||||
) = result
|
||||
|
||||
# Update outputs at the correct index
|
||||
outputs[index] = output_value
|
||||
|
||||
# Yield all events from this iteration
|
||||
yield from events
|
||||
|
||||
# Update tokens and timing
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
||||
|
||||
# Sync conversation variables after iteration completion
|
||||
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors based on error_handle_mode
|
||||
match self._node_data.error_handle_mode:
|
||||
case ErrorHandleMode.TERMINATED:
|
||||
# Cancel remaining futures and re-raise
|
||||
for f in future_to_index:
|
||||
if f != future:
|
||||
f.cancel()
|
||||
raise IterationNodeError(str(e))
|
||||
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
outputs[index] = None
|
||||
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs[index] = None # Will be filtered later
|
||||
|
||||
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
|
||||
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs[:] = [output for output in outputs if output is not None]
|
||||
|
||||
def _execute_single_iteration_parallel(
|
||||
self,
|
||||
index: int,
|
||||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
events: list[GraphNodeEventBase] = []
|
||||
outputs_temp: list[object] = []
|
||||
|
||||
graph_engine = self._create_graph_engine(index, item)
|
||||
|
||||
# Collect events instead of yielding them directly
|
||||
for event in self._run_single_iter(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||
outputs=outputs_temp,
|
||||
graph_engine=graph_engine,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Get the output value from the temporary outputs list
|
||||
output_value = outputs_temp[0] if outputs_temp else None
|
||||
conversation_snapshot = self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
|
||||
return (
|
||||
iter_start_at,
|
||||
events,
|
||||
output_value,
|
||||
conversation_snapshot,
|
||||
graph_engine.graph_runtime_state.llm_usage,
|
||||
)
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
started_at: datetime,
|
||||
inputs: dict[str, Sequence[object]],
|
||||
outputs: list[object],
|
||||
iterator_list_value: Sequence[object],
|
||||
iter_run_map: dict[str, float],
|
||||
*,
|
||||
usage: LLMUsage,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
# Flatten the list of lists if all outputs are lists
|
||||
flattened_outputs = self._flatten_outputs_if_needed(outputs)
|
||||
|
||||
yield IterationSucceededEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": flattened_outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
)
|
||||
|
||||
# Yield final success event
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": flattened_outputs},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]:
|
||||
"""
|
||||
Flatten the outputs list if all elements are lists.
|
||||
This maintains backward compatibility with version 1.8.1 behavior.
|
||||
|
||||
If flatten_output is False, returns outputs as-is (nested structure).
|
||||
If flatten_output is True (default), flattens the list if all elements are lists.
|
||||
"""
|
||||
# If flatten_output is disabled, return outputs as-is
|
||||
if not self._node_data.flatten_output:
|
||||
return outputs
|
||||
|
||||
if not outputs:
|
||||
return outputs
|
||||
|
||||
# Check if all non-None outputs are lists
|
||||
non_none_outputs = [output for output in outputs if output is not None]
|
||||
if not non_none_outputs:
|
||||
return outputs
|
||||
|
||||
if all(isinstance(output, list) for output in non_none_outputs):
|
||||
# Flatten the list of lists
|
||||
flattened: list[Any] = []
|
||||
for output in outputs:
|
||||
if isinstance(output, list):
|
||||
flattened.extend(output)
|
||||
elif output is not None:
|
||||
# This shouldn't happen based on our check, but handle it gracefully
|
||||
flattened.append(output)
|
||||
return flattened
|
||||
|
||||
return outputs
|
||||
|
||||
def _handle_iteration_failure(
|
||||
self,
|
||||
started_at: datetime,
|
||||
inputs: dict[str, Sequence[object]],
|
||||
outputs: list[object],
|
||||
iterator_list_value: Sequence[object],
|
||||
iter_run_map: dict[str, float],
|
||||
*,
|
||||
usage: LLMUsage,
|
||||
error: IterationNodeError,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
# Flatten the list of lists if all outputs are lists (even in failure case)
|
||||
flattened_outputs = self._flatten_outputs_if_needed(outputs)
|
||||
|
||||
yield IterationFailedEvent(
|
||||
start_at=started_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": flattened_outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||
},
|
||||
error=str(error),
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(error),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = IterationNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping: dict[str, Sequence[str]] = {
|
||||
f"{node_id}.input_selector": typed_node_data.iterator_selector,
|
||||
}
|
||||
iteration_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("iteration_id") == node_id:
|
||||
in_iteration_node_id = node.get("id")
|
||||
if in_iteration_node_id:
|
||||
iteration_node_ids.add(in_iteration_node_id)
|
||||
|
||||
# Get node configs from graph_config instead of non-existent node_id_config_mapping
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
for sub_node_id, sub_node_config in node_configs.items():
|
||||
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
|
||||
continue
|
||||
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
sub_node_variable_mapping = {}
|
||||
|
||||
# remove iteration variables
|
||||
sub_node_variable_mapping = {
|
||||
sub_node_id + "." + key: value
|
||||
for key, value in sub_node_variable_mapping.items()
|
||||
if value[0] != node_id
|
||||
}
|
||||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
# remove variable out from iteration
|
||||
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
|
||||
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
|
||||
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
|
||||
parent_pool = self.graph_runtime_state.variable_pool
|
||||
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
|
||||
current_keys = set(parent_conversations.keys())
|
||||
snapshot_keys = set(snapshot.keys())
|
||||
|
||||
for removed_key in current_keys - snapshot_keys:
|
||||
parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key))
|
||||
|
||||
for name, variable in snapshot.items():
|
||||
parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable)
|
||||
|
||||
def _append_iteration_info_to_event(
|
||||
self,
|
||||
event: GraphNodeEventBase,
|
||||
iter_run_index: int,
|
||||
):
|
||||
event.in_iteration_id = self._node_id
|
||||
iter_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id,
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
|
||||
}
|
||||
|
||||
current_metadata = event.node_run_result.metadata
|
||||
if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
|
||||
event.node_run_result.metadata = {**current_metadata, **iter_metadata}
|
||||
|
||||
def _run_single_iter(
|
||||
self,
|
||||
*,
|
||||
variable_pool: VariablePool,
|
||||
outputs: list[object],
|
||||
graph_engine: "GraphEngine",
|
||||
) -> Generator[GraphNodeEventBase, None, None]:
|
||||
rst = graph_engine.run()
|
||||
# get current iteration index
|
||||
index_variable = variable_pool.get([self._node_id, "index"])
|
||||
if not isinstance(index_variable, IntegerVariable):
|
||||
raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found")
|
||||
current_index = index_variable.value
|
||||
for event in rst:
|
||||
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START:
|
||||
continue
|
||||
|
||||
if isinstance(event, GraphNodeEventBase):
|
||||
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
|
||||
yield event
|
||||
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
|
||||
result = variable_pool.get(self._node_data.output_selector)
|
||||
if result is None:
|
||||
outputs.append(None)
|
||||
else:
|
||||
outputs.append(result.to_object())
|
||||
return
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
match self._node_data.error_handle_mode:
|
||||
case ErrorHandleMode.TERMINATED:
|
||||
raise IterationNodeError(event.error)
|
||||
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
outputs.append(None)
|
||||
return
|
||||
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
return
|
||||
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
invoke_from=self.invoke_from.value,
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
# Create a deep copy of the variable pool for each iteration
|
||||
variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True)
|
||||
|
||||
# append iteration variable (item, index) to variable pool
|
||||
variable_pool_copy.add([self._node_id, "index"], index)
|
||||
variable_pool_copy.add([self._node_id, "item"], item)
|
||||
|
||||
# Create a new GraphRuntimeState for this iteration
|
||||
graph_runtime_state_copy = GraphRuntimeState(
|
||||
variable_pool=variable_pool_copy,
|
||||
start_at=self.graph_runtime_state.start_at,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
|
||||
# Create a new node factory with the new GraphRuntimeState
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
|
||||
)
|
||||
|
||||
# Initialize the iteration graph with the new node factory
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
|
||||
)
|
||||
|
||||
if not iteration_graph:
|
||||
raise IterationGraphNotFoundError("iteration graph not found")
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph=iteration_graph,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
)
|
||||
|
||||
return graph_engine
|
||||
@@ -0,0 +1,49 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import IterationStartNodeData
|
||||
|
||||
|
||||
class IterationStartNode(Node):
|
||||
"""
|
||||
Iteration Start Node.
|
||||
"""
|
||||
|
||||
node_type = NodeType.ITERATION_START
|
||||
|
||||
_node_data: IterationStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = IterationStartNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
3
dify/api/core/workflow/nodes/knowledge_index/__init__.py
Normal file
3
dify/api/core/workflow/nodes/knowledge_index/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .knowledge_index_node import KnowledgeIndexNode
|
||||
|
||||
__all__ = ["KnowledgeIndexNode"]
|
||||
160
dify/api/core/workflow/nodes/knowledge_index/entities.py
Normal file
160
dify/api/core/workflow/nodes/knowledge_index/entities.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
Reranking Model Config.
|
||||
"""
|
||||
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
|
||||
|
||||
class EmbeddingSetting(BaseModel):
|
||||
"""
|
||||
Embedding Setting.
|
||||
"""
|
||||
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class EconomySetting(BaseModel):
|
||||
"""
|
||||
Economy Setting.
|
||||
"""
|
||||
|
||||
keyword_number: int
|
||||
|
||||
|
||||
class RetrievalSetting(BaseModel):
|
||||
"""
|
||||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: RetrievalMethod
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
reranking_mode: str = "reranking_model"
|
||||
reranking_enable: bool = True
|
||||
reranking_model: RerankingModelConfig | None = None
|
||||
weights: WeightedScoreConfig | None = None
|
||||
|
||||
|
||||
class IndexMethod(BaseModel):
|
||||
"""
|
||||
Knowledge Index Setting.
|
||||
"""
|
||||
|
||||
indexing_technique: Literal["high_quality", "economy"]
|
||||
embedding_setting: EmbeddingSetting
|
||||
economy_setting: EconomySetting
|
||||
|
||||
|
||||
class FileInfo(BaseModel):
|
||||
"""
|
||||
File Info.
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
|
||||
|
||||
class OnlineDocumentIcon(BaseModel):
|
||||
"""
|
||||
Document Icon.
|
||||
"""
|
||||
|
||||
icon_url: str
|
||||
icon_type: str
|
||||
icon_emoji: str
|
||||
|
||||
|
||||
class OnlineDocumentInfo(BaseModel):
|
||||
"""
|
||||
Online document info.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
workspace_id: str | None = None
|
||||
page_id: str
|
||||
page_type: str
|
||||
icon: OnlineDocumentIcon | None = None
|
||||
|
||||
|
||||
class WebsiteInfo(BaseModel):
|
||||
"""
|
||||
website import info.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
url: str
|
||||
|
||||
|
||||
class GeneralStructureChunk(BaseModel):
|
||||
"""
|
||||
General Structure Chunk.
|
||||
"""
|
||||
|
||||
general_chunks: list[str]
|
||||
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
||||
|
||||
|
||||
class ParentChildChunk(BaseModel):
|
||||
"""
|
||||
Parent Child Chunk.
|
||||
"""
|
||||
|
||||
parent_content: str
|
||||
child_contents: list[str]
|
||||
|
||||
|
||||
class ParentChildStructureChunk(BaseModel):
|
||||
"""
|
||||
Parent Child Structure Chunk.
|
||||
"""
|
||||
|
||||
parent_child_chunks: list[ParentChildChunk]
|
||||
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
||||
|
||||
|
||||
class KnowledgeIndexNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge index Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-index"
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
22
dify/api/core/workflow/nodes/knowledge_index/exc.py
Normal file
22
dify/api/core/workflow/nodes/knowledge_index/exc.py
Normal file
@@ -0,0 +1,22 @@
|
||||
class KnowledgeIndexNodeError(ValueError):
|
||||
"""Base class for KnowledgeIndexNode errors."""
|
||||
|
||||
|
||||
class ModelNotExistError(KnowledgeIndexNodeError):
|
||||
"""Raised when the model does not exist."""
|
||||
|
||||
|
||||
class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
|
||||
"""Raised when the model credentials are not initialized."""
|
||||
|
||||
|
||||
class ModelNotSupportedError(KnowledgeIndexNodeError):
|
||||
"""Raised when the model is not supported."""
|
||||
|
||||
|
||||
class ModelQuotaExceededError(KnowledgeIndexNodeError):
|
||||
"""Raised when the model provider quota is exceeded."""
|
||||
|
||||
|
||||
class InvalidModelTypeError(KnowledgeIndexNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
@@ -0,0 +1,214 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
KnowledgeIndexNodeError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeIndexNode(Node):
|
||||
_node_data: KnowledgeIndexNodeData
|
||||
node_type = NodeType.KNOWLEDGE_INDEX
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = KnowledgeIndexNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def _run(self) -> NodeRunResult: # type: ignore
|
||||
node_data = self._node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
||||
if not dataset_id:
|
||||
raise KnowledgeIndexNodeError("Dataset ID is required.")
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
|
||||
if not dataset:
|
||||
raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
|
||||
|
||||
# extract variables
|
||||
variable = variable_pool.get(node_data.index_chunk_variable_selector)
|
||||
if not variable:
|
||||
raise KnowledgeIndexNodeError("Index chunk variable is required.")
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
if invoke_from:
|
||||
is_preview = invoke_from.value == InvokeFrom.DEBUGGER
|
||||
else:
|
||||
is_preview = False
|
||||
chunks = variable.value
|
||||
variables = {"chunks": chunks}
|
||||
if not chunks:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
|
||||
)
|
||||
|
||||
# index knowledge
|
||||
try:
|
||||
if is_preview:
|
||||
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
outputs=outputs,
|
||||
)
|
||||
results = self._invoke_knowledge_index(
|
||||
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
|
||||
)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results)
|
||||
|
||||
except KnowledgeIndexNodeError as e:
|
||||
logger.warning("Error when running knowledge index node")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
def _invoke_knowledge_index(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
node_data: KnowledgeIndexNodeData,
|
||||
chunks: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
) -> Any:
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if not document_id:
|
||||
raise KnowledgeIndexNodeError("Document ID is required.")
|
||||
original_document_id = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
|
||||
|
||||
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
|
||||
if not batch:
|
||||
raise KnowledgeIndexNodeError("Batch is required.")
|
||||
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||
if not document:
|
||||
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
|
||||
doc_id_value = document.id
|
||||
ds_id_value = dataset.id
|
||||
dataset_name_value = dataset.name
|
||||
document_name_value = document.name
|
||||
created_at_value = document.created_at
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
|
||||
if original_document_id:
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == original_document_id.value)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
index_processor.index(dataset, document, chunks)
|
||||
indexing_end_at = time.perf_counter()
|
||||
document.indexing_latency = indexing_end_at - indexing_start_at
|
||||
# update document status
|
||||
document.indexing_status = "completed"
|
||||
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.word_count = (
|
||||
db.session.query(func.sum(DocumentSegment.word_count))
|
||||
.where(
|
||||
DocumentSegment.document_id == doc_id_value,
|
||||
DocumentSegment.dataset_id == ds_id_value,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
db.session.add(document)
|
||||
# update document segment status
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == doc_id_value,
|
||||
DocumentSegment.dataset_id == ds_id_value,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"dataset_id": ds_id_value,
|
||||
"dataset_name": dataset_name_value,
|
||||
"batch": batch.value,
|
||||
"document_id": doc_id_value,
|
||||
"document_name": document_name_value,
|
||||
"created_at": created_at_value.timestamp(),
|
||||
"display_status": "completed",
|
||||
}
|
||||
|
||||
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
return index_processor.format_preview(chunks)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def get_streaming_template(self) -> Template:
|
||||
"""
|
||||
Get the template for streaming.
|
||||
|
||||
Returns:
|
||||
Template instance for this knowledge index node
|
||||
"""
|
||||
return Template(segments=[])
|
||||
@@ -0,0 +1,3 @@
|
||||
from .knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
|
||||
__all__ = ["KnowledgeRetrievalNode"]
|
||||
134
dify/api/core/workflow/nodes/knowledge_retrieval/entities.py
Normal file
134
dify/api/core/workflow/nodes/knowledge_retrieval/entities.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
"""
|
||||
Reranking Model Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""
|
||||
Vector Setting.
|
||||
"""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""
|
||||
Keyword Setting.
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
|
||||
|
||||
class MultipleRetrievalConfig(BaseModel):
|
||||
"""
|
||||
Multiple Retrieval Config.
|
||||
"""
|
||||
|
||||
top_k: int
|
||||
score_threshold: float | None = None
|
||||
reranking_mode: str = "reranking_model"
|
||||
reranking_enable: bool = True
|
||||
reranking_model: RerankingModelConfig | None = None
|
||||
weights: WeightedScoreConfig | None = None
|
||||
|
||||
|
||||
class SingleRetrievalConfig(BaseModel):
|
||||
"""
|
||||
Single Retrieval Config.
|
||||
"""
|
||||
|
||||
model: ModelConfig
|
||||
|
||||
|
||||
SupportedComparisonOperator = Literal[
|
||||
# for string or array
|
||||
"contains",
|
||||
"not contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
">",
|
||||
"<",
|
||||
"≥",
|
||||
"≤",
|
||||
# for time
|
||||
"before",
|
||||
"after",
|
||||
]
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Condition detail
|
||||
"""
|
||||
|
||||
name: str
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None | int | float = None
|
||||
|
||||
|
||||
class MetadataFilteringCondition(BaseModel):
|
||||
"""
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-retrieval"
|
||||
query_variable_selector: list[str]
|
||||
dataset_ids: list[str]
|
||||
retrieval_mode: Literal["single", "multiple"]
|
||||
multiple_retrieval_config: MultipleRetrievalConfig | None = None
|
||||
single_retrieval_config: SingleRetrievalConfig | None = None
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
|
||||
metadata_model_config: ModelConfig | None = None
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@property
|
||||
def structured_output_enabled(self) -> bool:
|
||||
# NOTE(QuantumGhost): Temporary workaround for issue #20725
|
||||
# (https://github.com/langgenius/dify/issues/20725).
|
||||
#
|
||||
# The proper fix would be to make `KnowledgeRetrievalNode` inherit
|
||||
# from `BaseNode` instead of `LLMNode`.
|
||||
return False
|
||||
22
dify/api/core/workflow/nodes/knowledge_retrieval/exc.py
Normal file
22
dify/api/core/workflow/nodes/knowledge_retrieval/exc.py
Normal file
@@ -0,0 +1,22 @@
|
||||
class KnowledgeRetrievalNodeError(ValueError):
|
||||
"""Base class for KnowledgeRetrievalNode errors."""
|
||||
|
||||
|
||||
class ModelNotExistError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the model does not exist."""
|
||||
|
||||
|
||||
class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the model credentials are not initialized."""
|
||||
|
||||
|
||||
class ModelNotSupportedError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the model is not supported."""
|
||||
|
||||
|
||||
class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the model provider quota is exceeded."""
|
||||
|
||||
|
||||
class InvalidModelTypeError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
@@ -0,0 +1,792 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.variables import (
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.segments import ArrayObjectSegment
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_1,
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_2,
|
||||
METADATA_FILTER_COMPLETION_PROMPT,
|
||||
METADATA_FILTER_SYSTEM_PROMPT,
|
||||
METADATA_FILTER_USER_PROMPT_1,
|
||||
METADATA_FILTER_USER_PROMPT_2,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
|
||||
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .entities import KnowledgeRetrievalNodeData
|
||||
from .exc import (
|
||||
InvalidModelTypeError,
|
||||
KnowledgeRetrievalNodeError,
|
||||
ModelCredentialsNotInitializedError,
|
||||
ModelNotExistError,
|
||||
ModelNotSupportedError,
|
||||
ModelQuotaExceededError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
|
||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
_node_data: KnowledgeRetrievalNodeData
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs = []
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=graph_init_params.user_id,
|
||||
tenant_id=graph_init_params.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# extract variables
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
|
||||
if not isinstance(variable, StringSegment):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error="Query variable is not string type.",
|
||||
)
|
||||
query = variable.value
|
||||
variables = {"query": query}
|
||||
if not query:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
|
||||
)
|
||||
# TODO(-LAN-): Move this check outside.
|
||||
# check rate limit
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{self.tenant_id}"
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
||||
request_count = redis_client.zcard(key)
|
||||
if request_count > knowledge_rate_limit.limit:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=self.tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
session.add(rate_limit_log)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
|
||||
error_type="RateLimitExceeded",
|
||||
)
|
||||
|
||||
# retrieve knowledge
|
||||
usage = LLMUsage.empty_usage()
|
||||
try:
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
|
||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
process_data={"usage": jsonable_encoder(usage)},
|
||||
outputs=outputs, # type: ignore
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
except KnowledgeRetrievalNodeError as e:
|
||||
logger.warning("Error when running knowledge retrieval node")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
llm_usage=usage,
|
||||
)
|
||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
llm_usage=usage,
|
||||
)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _fetch_dataset_retriever(
|
||||
self, node_data: KnowledgeRetrievalNodeData, query: str
|
||||
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
available_datasets = []
|
||||
dataset_ids = node_data.dataset_ids
|
||||
|
||||
# Subquery: Count the number of available documents for each dataset
|
||||
subquery = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
|
||||
.where(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
)
|
||||
.group_by(Document.dataset_id)
|
||||
.having(func.count(Document.id) > 0)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
results = (
|
||||
db.session.query(Dataset)
|
||||
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
|
||||
.where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
|
||||
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
|
||||
.all()
|
||||
)
|
||||
|
||||
# avoid blocking at retrieval
|
||||
db.session.close()
|
||||
|
||||
for dataset in results:
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
|
||||
[dataset.id for dataset in available_datasets], query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, metadata_usage)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
|
||||
# check model is support tool calling
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
# get model schema
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model, credentials=model_config.credentials
|
||||
)
|
||||
|
||||
if model_schema:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
all_documents = dataset_retrieval.single_retrieve(
|
||||
available_datasets=available_datasets,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
user_from=self.user_from.value,
|
||||
query=query,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
planning_strategy=planning_strategy,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
reranking_model = {
|
||||
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
|
||||
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
|
||||
if node_data.multiple_retrieval_config.weights is None:
|
||||
raise ValueError("weights is required")
|
||||
reranking_model = None
|
||||
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
|
||||
weights = {
|
||||
"vector_setting": {
|
||||
"vector_weight": vector_setting.vector_weight,
|
||||
"embedding_provider_name": vector_setting.embedding_provider_name,
|
||||
"embedding_model_name": vector_setting.embedding_model_name,
|
||||
},
|
||||
"keyword_setting": {
|
||||
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
|
||||
},
|
||||
}
|
||||
else:
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
available_datasets=available_datasets,
|
||||
query=query,
|
||||
top_k=node_data.multiple_retrieval_config.top_k,
|
||||
score_threshold=node_data.multiple_retrieval_config.score_threshold
|
||||
if node_data.multiple_retrieval_config.score_threshold is not None
|
||||
else 0.0,
|
||||
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
|
||||
reranking_model=reranking_model,
|
||||
weights=weights,
|
||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
|
||||
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
retrieval_resource_list = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
source = {
|
||||
"metadata": {
|
||||
"_source": "knowledge",
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": "workflow",
|
||||
"score": item.metadata.get("score"),
|
||||
"doc_metadata": item.metadata,
|
||||
},
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
retrieval_resource_list.append(source)
|
||||
# deal with dify documents
|
||||
if dify_documents:
|
||||
records = RetrievalService.format_retrieval_documents(dify_documents)
|
||||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
|
||||
stmt = select(Document).where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
document = db.session.scalar(stmt)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"metadata": {
|
||||
"_source": "knowledge",
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": "workflow",
|
||||
"score": record.score or 0.0,
|
||||
"child_chunks": [
|
||||
{
|
||||
"id": str(getattr(chunk, "id", "")),
|
||||
"content": str(getattr(chunk, "content", "")),
|
||||
"position": int(getattr(chunk, "position", 0)),
|
||||
"score": float(getattr(chunk, "score", 0.0)),
|
||||
}
|
||||
for chunk in (record.child_chunks or [])
|
||||
],
|
||||
"segment_hit_count": segment.hit_count,
|
||||
"segment_word_count": segment.word_count,
|
||||
"segment_position": segment.position,
|
||||
"segment_index_node_hash": segment.index_node_hash,
|
||||
"doc_metadata": document.doc_metadata,
|
||||
},
|
||||
"title": document.name,
|
||||
}
|
||||
if segment.answer:
|
||||
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
|
||||
else:
|
||||
source["content"] = segment.get_sign_content()
|
||||
retrieval_resource_list.append(source)
|
||||
if retrieval_resource_list:
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||
item["metadata"]["position"] = position
|
||||
return retrieval_resource_list, usage
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
document_query = db.session.query(Document).where(
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
filters: list[Any] = []
|
||||
metadata_condition = None
|
||||
if node_data.metadata_filtering_mode == "disabled":
|
||||
return None, None, usage
|
||||
elif node_data.metadata_filtering_mode == "automatic":
|
||||
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
|
||||
dataset_ids, query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, automatic_usage)
|
||||
if automatic_metadata_filters:
|
||||
conditions = []
|
||||
for sequence, filter in enumerate(automatic_metadata_filters):
|
||||
self._process_metadata_filter_func(
|
||||
sequence,
|
||||
filter.get("condition", ""),
|
||||
filter.get("metadata_name", ""),
|
||||
filter.get("value"),
|
||||
filters,
|
||||
)
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=filter.get("metadata_name"), # type: ignore
|
||||
comparison_operator=filter.get("condition"), # type: ignore
|
||||
value=filter.get("value"),
|
||||
)
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||
if node_data.metadata_filtering_conditions
|
||||
else "or",
|
||||
conditions=conditions,
|
||||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
).value[0]
|
||||
if expected_value.value_type in {"number", "integer", "float"}:
|
||||
expected_value = expected_value.value
|
||||
elif expected_value.value_type == "string":
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
if (
|
||||
node_data.metadata_filtering_conditions
|
||||
and node_data.metadata_filtering_conditions.logical_operator == "and"
|
||||
):
|
||||
document_query = document_query.where(and_(*filters))
|
||||
else:
|
||||
document_query = document_query.where(or_(*filters))
|
||||
documents = document_query.all()
|
||||
# group by dataset_id
|
||||
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
||||
for document in documents:
|
||||
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
||||
return metadata_filter_document_ids, metadata_condition, usage
|
||||
|
||||
def _automatic_metadata_filter_func(
|
||||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
# get all metadata field
|
||||
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
metadata_fields = db.session.scalars(stmt).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
if node_data.metadata_model_config is None:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
# get metadata model instance and fetch model config
|
||||
model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
|
||||
# fetch prompt messages
|
||||
prompt_template = self._get_prompt_template(
|
||||
node_data=node_data,
|
||||
metadata_fields=all_metadata_fields,
|
||||
query=query or "",
|
||||
)
|
||||
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
sys_query=query,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
sys_files=[],
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
jinja2_variables=[],
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
try:
|
||||
# handle invoke result
|
||||
generator = LLMNode.invoke_llm(
|
||||
node_data_model=node_data.metadata_model_config,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=self._node_data.structured_output_enabled,
|
||||
structured_output=None,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
)
|
||||
|
||||
for event in generator:
|
||||
if isinstance(event, ModelInvokeCompletedEvent):
|
||||
result_text = event.text
|
||||
usage = self._merge_usage(usage, event.usage)
|
||||
break
|
||||
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
automatic_metadata_filters = []
|
||||
if "metadata_map" in result_text_json:
|
||||
metadata_map = result_text_json["metadata_map"]
|
||||
for item in metadata_map:
|
||||
if item.get("metadata_field_name") in all_metadata_fields:
|
||||
automatic_metadata_filters.append(
|
||||
{
|
||||
"metadata_name": item.get("metadata_field_name"),
|
||||
"value": item.get("metadata_field_value"),
|
||||
"condition": item.get("comparison_operator"),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
return [], usage
|
||||
return automatic_metadata_filters, usage
|
||||
|
||||
def _process_metadata_filter_func(
|
||||
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||
) -> list[Any]:
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
||||
json_field = Document.doc_metadata[metadata_name].as_string()
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
|
||||
case "start with":
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
|
||||
case "end with":
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
case "in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(False))
|
||||
else:
|
||||
filters.append(json_field.in_(value_list))
|
||||
|
||||
case "not in":
|
||||
if isinstance(value, str):
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(True))
|
||||
else:
|
||||
filters.append(json_field.notin_(value_list))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
filters.append(json_field == value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
|
||||
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(json_field != value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
|
||||
|
||||
case "empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].is_(None))
|
||||
|
||||
case "not empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].isnot(None))
|
||||
|
||||
case "before" | "<":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
|
||||
|
||||
case "after" | ">":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
|
||||
|
||||
case "≤" | "<=":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
|
||||
|
||||
case "≥" | ">=":
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
return filters
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
|
||||
return variable_mapping
|
||||
|
||||
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
model_name = model.name
|
||||
provider_name = model.provider
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_credentials = model_instance.credentials
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
|
||||
# model config
|
||||
completion_params = model.completion_params
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = model.mode
|
||||
if not model_mode:
|
||||
raise ModelNotExistError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
model_schema=model_schema,
|
||||
mode=model_mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
|
||||
model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore
|
||||
input_text = query
|
||||
|
||||
prompt_messages: list[LLMNodeChatModelMessage] = []
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
|
||||
)
|
||||
prompt_messages.append(system_prompt_messages)
|
||||
user_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_1)
|
||||
assistant_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_1)
|
||||
user_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_2)
|
||||
assistant_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_2)
|
||||
user_prompt_message_3 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=METADATA_FILTER_USER_PROMPT_3.format(
|
||||
input_text=input_text,
|
||||
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_3)
|
||||
return prompt_messages
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return LLMNodeCompletionModelPromptTemplate(
|
||||
text=METADATA_FILTER_COMPLETION_PROMPT.format(
|
||||
input_text=input_text,
|
||||
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")
|
||||
@@ -0,0 +1,66 @@
|
||||
METADATA_FILTER_SYSTEM_PROMPT = """
|
||||
### Job Description',
|
||||
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
|
||||
### Task
|
||||
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
|
||||
### Format
|
||||
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
""" # noqa: E501
|
||||
|
||||
METADATA_FILTER_USER_PROMPT_1 = """
|
||||
{ "input_text": "I want to know which company’s email address test@example.com is?",
|
||||
"metadata_fields": ["filename", "email", "phone", "address"]
|
||||
}
|
||||
"""
|
||||
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_1 = """
|
||||
```json
|
||||
{"metadata_map": [
|
||||
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
|
||||
]
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
METADATA_FILTER_USER_PROMPT_2 = """
|
||||
{"input_text": "What are the movies with a score of more than 9 in 2024?",
|
||||
"metadata_fields": ["name", "year", "rating", "country"]}
|
||||
"""
|
||||
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_2 = """
|
||||
```json
|
||||
{"metadata_map": [
|
||||
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
|
||||
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
|
||||
]}
|
||||
```
|
||||
"""
|
||||
|
||||
METADATA_FILTER_USER_PROMPT_3 = """
|
||||
'{{"input_text": "{input_text}",',
|
||||
'"metadata_fields": {metadata_fields}}}'
|
||||
"""
|
||||
|
||||
METADATA_FILTER_COMPLETION_PROMPT = """
|
||||
### Job Description
|
||||
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
|
||||
### Task
|
||||
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
|
||||
### Format
|
||||
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
### Example
|
||||
Here is the chat example between human and assistant, inside <example></example> XML tags.
|
||||
<example>
|
||||
User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
|
||||
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
|
||||
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
|
||||
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
|
||||
</example>
|
||||
### User Input
|
||||
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
|
||||
### Assistant Output
|
||||
""" # noqa: E501
|
||||
3
dify/api/core/workflow/nodes/list_operator/__init__.py
Normal file
3
dify/api/core/workflow/nodes/list_operator/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .node import ListOperatorNode
|
||||
|
||||
__all__ = ["ListOperatorNode"]
|
||||
69
dify/api/core/workflow/nodes/list_operator/entities.py
Normal file
69
dify/api/core/workflow/nodes/list_operator/entities.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class FilterOperator(StrEnum):
|
||||
# string conditions
|
||||
CONTAINS = "contains"
|
||||
START_WITH = "start with"
|
||||
END_WITH = "end with"
|
||||
IS = "is"
|
||||
IN = "in"
|
||||
EMPTY = "empty"
|
||||
NOT_CONTAINS = "not contains"
|
||||
IS_NOT = "is not"
|
||||
NOT_IN = "not in"
|
||||
NOT_EMPTY = "not empty"
|
||||
# number conditions
|
||||
EQUAL = "="
|
||||
NOT_EQUAL = "≠"
|
||||
LESS_THAN = "<"
|
||||
GREATER_THAN = ">"
|
||||
GREATER_THAN_OR_EQUAL = "≥"
|
||||
LESS_THAN_OR_EQUAL = "≤"
|
||||
|
||||
|
||||
class Order(StrEnum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class FilterCondition(BaseModel):
|
||||
key: str = ""
|
||||
comparison_operator: FilterOperator = FilterOperator.CONTAINS
|
||||
# the value is bool if the filter operator is comparing with
|
||||
# a boolean constant.
|
||||
value: str | Sequence[str] | bool = ""
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
|
||||
enabled: bool = False
|
||||
conditions: Sequence[FilterCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrderByConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
key: str = ""
|
||||
value: Order = Order.ASC
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
enabled: bool = False
|
||||
size: int = -1
|
||||
|
||||
|
||||
class ExtractConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
serial: str = "1"
|
||||
|
||||
|
||||
class ListOperatorNodeData(BaseNodeData):
|
||||
variable: Sequence[str] = Field(default_factory=list)
|
||||
filter_by: FilterBy
|
||||
order_by: OrderByConfig
|
||||
limit: Limit
|
||||
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
16
dify/api/core/workflow/nodes/list_operator/exc.py
Normal file
16
dify/api/core/workflow/nodes/list_operator/exc.py
Normal file
@@ -0,0 +1,16 @@
|
||||
class ListOperatorError(ValueError):
|
||||
"""Base class for all ListOperator errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidFilterValueError(ListOperatorError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidKeyError(ListOperatorError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidConditionError(ListOperatorError):
|
||||
pass
|
||||
370
dify/api/core/workflow/nodes/list_operator/node.py
Normal file
370
dify/api/core/workflow/nodes/list_operator/node.py
Normal file
@@ -0,0 +1,370 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, TypeAlias, TypeVar
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .entities import FilterOperator, ListOperatorNodeData, Order
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
|
||||
_SUPPORTED_TYPES_TUPLE = (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayStringSegment,
|
||||
ArrayBooleanSegment,
|
||||
)
|
||||
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
|
||||
"""Returns the negation of a given filter function. If the original filter
|
||||
returns `True` for a value, the negated filter will return `False`, and vice versa.
|
||||
"""
|
||||
|
||||
def wrapper(value: _T) -> bool:
|
||||
return not filter_(value)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ListOperatorNode(Node):
|
||||
node_type = NodeType.LIST_OPERATOR
|
||||
|
||||
_node_data: ListOperatorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ListOperatorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
inputs: dict[str, Sequence[object]] = {}
|
||||
process_data: dict[str, Sequence[object]] = {}
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
|
||||
if variable is None:
|
||||
error_message = f"Variable not found for selector: {self._node_data.variable}"
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
if not variable.value:
|
||||
inputs = {"variable": []}
|
||||
process_data = {"variable": []}
|
||||
if isinstance(variable, ArraySegment):
|
||||
result = variable.model_copy(update={"value": []})
|
||||
else:
|
||||
result = ArrayAnySegment(value=[])
|
||||
outputs = {"result": result, "first_record": None, "last_record": None}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
|
||||
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
inputs = {"variable": [item.to_dict() for item in variable.value]}
|
||||
process_data["variable"] = [item.to_dict() for item in variable.value]
|
||||
else:
|
||||
inputs = {"variable": variable.value}
|
||||
process_data["variable"] = variable.value
|
||||
|
||||
try:
|
||||
# Filter
|
||||
if self._node_data.filter_by.enabled:
|
||||
variable = self._apply_filter(variable)
|
||||
|
||||
# Extract
|
||||
if self._node_data.extract_by.enabled:
|
||||
variable = self._extract_slice(variable)
|
||||
|
||||
# Order
|
||||
if self._node_data.order_by.enabled:
|
||||
variable = self._apply_order(variable)
|
||||
|
||||
# Slice
|
||||
if self._node_data.limit.enabled:
|
||||
variable = self._apply_slice(variable)
|
||||
|
||||
outputs = {
|
||||
"result": variable,
|
||||
"first_record": variable.value[0] if variable.value else None,
|
||||
"last_record": variable.value[-1] if variable.value else None,
|
||||
}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
except ListOperatorError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
filter_func: Callable[[Any], bool]
|
||||
result: list[Any] = []
|
||||
for condition in self._node_data.filter_by.conditions:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
if not isinstance(condition.value, str):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayNumberSegment):
|
||||
if not isinstance(condition.value, str):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
if isinstance(condition.value, str):
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
elif isinstance(condition.value, bool):
|
||||
raise ValueError(f"File filter expects a string value, got {type(condition.value)}")
|
||||
else:
|
||||
value = condition.value
|
||||
filter_func = _get_file_filter_func(
|
||||
key=condition.key,
|
||||
condition=condition.comparison_operator,
|
||||
value=value,
|
||||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
if not isinstance(condition.value, bool):
|
||||
raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}")
|
||||
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
return variable
|
||||
|
||||
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
result = _order_file(
|
||||
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
|
||||
return variable
|
||||
|
||||
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
result = variable.value[: self._node_data.limit.size]
|
||||
return variable.model_copy(update={"value": result})
|
||||
|
||||
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
|
||||
if value < 1:
|
||||
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
|
||||
if value > len(variable.value):
|
||||
raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}")
|
||||
value -= 1
|
||||
result = variable.value[value]
|
||||
return variable.model_copy(update={"value": [result]})
|
||||
|
||||
|
||||
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
|
||||
match key:
|
||||
case "size":
|
||||
return lambda x: x.size
|
||||
case _:
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
|
||||
match key:
|
||||
case "name":
|
||||
return lambda x: x.filename or ""
|
||||
case "type":
|
||||
return lambda x: x.type
|
||||
case "extension":
|
||||
return lambda x: x.extension or ""
|
||||
case "mime_type":
|
||||
return lambda x: x.mime_type or ""
|
||||
case "transfer_method":
|
||||
return lambda x: x.transfer_method
|
||||
case "url":
|
||||
return lambda x: x.remote_url or ""
|
||||
case "related_id":
|
||||
return lambda x: x.related_id or ""
|
||||
case _:
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
|
||||
match condition:
|
||||
case "contains":
|
||||
return _contains(value)
|
||||
case "start with":
|
||||
return _startswith(value)
|
||||
case "end with":
|
||||
return _endswith(value)
|
||||
case "is":
|
||||
return _is(value)
|
||||
case "in":
|
||||
return _in(value)
|
||||
case "empty":
|
||||
return lambda x: x == ""
|
||||
case "not contains":
|
||||
return _negation(_contains(value))
|
||||
case "is not":
|
||||
return _negation(_is(value))
|
||||
case "not in":
|
||||
return _negation(_in(value))
|
||||
case "not empty":
|
||||
return lambda x: x != ""
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
|
||||
match condition:
|
||||
case "in":
|
||||
return _in(value)
|
||||
case "not in":
|
||||
return _negation(_in(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
|
||||
match condition:
|
||||
case "=":
|
||||
return _eq(value)
|
||||
case "≠":
|
||||
return _ne(value)
|
||||
case "<":
|
||||
return _lt(value)
|
||||
case "≤":
|
||||
return _le(value)
|
||||
case ">":
|
||||
return _gt(value)
|
||||
case "≥":
|
||||
return _ge(value)
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
|
||||
match condition:
|
||||
case FilterOperator.IS:
|
||||
return _is(value)
|
||||
case FilterOperator.IS_NOT:
|
||||
return _negation(_is(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
if key in {"type", "transfer_method"}:
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
elif key == "size" and isinstance(value, str):
|
||||
extract_func = _get_file_extract_number_func(key=key)
|
||||
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
|
||||
else:
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _contains(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: value in x
|
||||
|
||||
|
||||
def _startswith(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x.startswith(value)
|
||||
|
||||
|
||||
def _endswith(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x.endswith(value)
|
||||
|
||||
|
||||
def _is(value: _T) -> Callable[[_T], bool]:
|
||||
return lambda x: x == value
|
||||
|
||||
|
||||
def _in(value: str | Sequence[str]) -> Callable[[str], bool]:
|
||||
return lambda x: x in value
|
||||
|
||||
|
||||
def _eq(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x == value
|
||||
|
||||
|
||||
def _ne(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x != value
|
||||
|
||||
|
||||
def _lt(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x < value
|
||||
|
||||
|
||||
def _le(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x <= value
|
||||
|
||||
|
||||
def _gt(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x > value
|
||||
|
||||
|
||||
def _ge(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x >= value
|
||||
|
||||
|
||||
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
|
||||
extract_func: Callable[[File], Any]
|
||||
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}:
|
||||
extract_func = _get_file_extract_string_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
elif order_by == "size":
|
||||
extract_func = _get_file_extract_number_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
else:
|
||||
raise InvalidKeyError(f"Invalid order key: {order_by}")
|
||||
17
dify/api/core/workflow/nodes/llm/__init__.py
Normal file
17
dify/api/core/workflow/nodes/llm/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from .node import LLMNode
|
||||
|
||||
__all__ = [
|
||||
"LLMNode",
|
||||
"LLMNodeChatModelMessage",
|
||||
"LLMNodeCompletionModelPromptTemplate",
|
||||
"LLMNodeData",
|
||||
"ModelConfig",
|
||||
"VisionConfig",
|
||||
]
|
||||
98
dify/api/core/workflow/nodes/llm/entities.py
Normal file
98
dify/api/core/workflow/nodes/llm/entities.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
provider: str
|
||||
name: str
|
||||
mode: LLMMode
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
|
||||
enabled: bool
|
||||
variable_selector: list[str] | None = None
|
||||
|
||||
|
||||
class VisionConfigOptions(BaseModel):
|
||||
variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
|
||||
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
|
||||
class VisionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)
|
||||
|
||||
@field_validator("configs", mode="before")
|
||||
@classmethod
|
||||
def convert_none_configs(cls, v: Any):
|
||||
if v is None:
|
||||
return VisionConfigOptions()
|
||||
return v
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
|
||||
|
||||
@field_validator("jinja2_variables", mode="before")
|
||||
@classmethod
|
||||
def convert_none_jinja2_variables(cls, v: Any):
|
||||
if v is None:
|
||||
return []
|
||||
return v
|
||||
|
||||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
text: str = ""
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
memory: MemoryConfig | None = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||
reasoning_format: Literal["separated", "tagged"] = Field(
|
||||
# Keep tagged as default for backward compatibility
|
||||
default="tagged",
|
||||
description=(
|
||||
"""
|
||||
Strategy for handling model reasoning output.
|
||||
|
||||
separated: Return clean text (without <think> tags) + reasoning_content field.
|
||||
Recommended for new workflows. Enables safe downstream parsing and
|
||||
workflow variable access: {{#node_id.reasoning_content#}}
|
||||
|
||||
tagged : Return original text (with <think> tags) + reasoning_content field.
|
||||
Maintains full backward compatibility while still providing reasoning_content
|
||||
for workflow automation. Frontend thinking panels work as before.
|
||||
"""
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
def convert_none_prompt_config(cls, v: Any):
|
||||
if v is None:
|
||||
return PromptConfig()
|
||||
return v
|
||||
|
||||
@property
|
||||
def structured_output_enabled(self) -> bool:
|
||||
return self.structured_output_switch_on and self.structured_output is not None
|
||||
45
dify/api/core/workflow/nodes/llm/exc.py
Normal file
45
dify/api/core/workflow/nodes/llm/exc.py
Normal file
@@ -0,0 +1,45 @@
|
||||
class LLMNodeError(ValueError):
|
||||
"""Base class for LLM Node errors."""
|
||||
|
||||
|
||||
class VariableNotFoundError(LLMNodeError):
|
||||
"""Raised when a required variable is not found."""
|
||||
|
||||
|
||||
class InvalidContextStructureError(LLMNodeError):
|
||||
"""Raised when the context structure is invalid."""
|
||||
|
||||
|
||||
class InvalidVariableTypeError(LLMNodeError):
|
||||
"""Raised when the variable type is invalid."""
|
||||
|
||||
|
||||
class ModelNotExistError(LLMNodeError):
|
||||
"""Raised when the specified model does not exist."""
|
||||
|
||||
|
||||
class LLMModeRequiredError(LLMNodeError):
|
||||
"""Raised when LLM mode is required but not provided."""
|
||||
|
||||
|
||||
class NoPromptFoundError(LLMNodeError):
|
||||
"""Raised when no prompt is found in the LLM configuration."""
|
||||
|
||||
|
||||
class TemplateTypeNotSupportError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str):
|
||||
super().__init__(f"Prompt type {type_name} is not supported.")
|
||||
|
||||
|
||||
class MemoryRolePrefixRequiredError(LLMNodeError):
|
||||
"""Raised when memory role prefix is required for completion model."""
|
||||
|
||||
|
||||
class FileTypeNotSupportError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str):
|
||||
super().__init__(f"{type_name} type is not supported by this model")
|
||||
|
||||
|
||||
class UnsupportedPromptContentTypeError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str):
|
||||
super().__init__(f"Prompt content type {type_name} is not supported.")
|
||||
157
dify/api/core/workflow/nodes/llm/file_saver.py
Normal file
157
dify/api/core/workflow/nodes/llm/file_saver.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import mimetypes
|
||||
import typing as tp
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
|
||||
class LLMFileSaver(tp.Protocol):
|
||||
"""LLMFileSaver is responsible for save multimodal output returned by
|
||||
LLM.
|
||||
"""
|
||||
|
||||
def save_binary_string(
|
||||
self,
|
||||
data: bytes,
|
||||
mime_type: str,
|
||||
file_type: FileType,
|
||||
extension_override: str | None = None,
|
||||
) -> File:
|
||||
"""save_binary_string saves the inline file data returned by LLM.
|
||||
|
||||
Currently (2025-04-30), only some of Google Gemini models will return
|
||||
multimodal output as inline data.
|
||||
|
||||
:param data: the contents of the file
|
||||
:param mime_type: the media type of the file, specified by rfc6838
|
||||
(https://datatracker.ietf.org/doc/html/rfc6838)
|
||||
:param file_type: The file type of the inline file.
|
||||
:param extension_override: Override the auto-detected file extension while saving this file.
|
||||
|
||||
The default value is `None`, which means do not override the file extension and guessing it
|
||||
from the `mime_type` attribute while saving the file.
|
||||
|
||||
Setting it to values other than `None` means override the file's extension, and
|
||||
will bypass the extension guessing saving the file.
|
||||
|
||||
Specially, setting it to empty string (`""`) will leave the file extension empty.
|
||||
|
||||
When it is not `None` or empty string (`""`), it should be a string beginning with a
|
||||
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
|
||||
and `tar.gz` are not.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
"""save_remote_url saves the file from a remote url returned by LLM.
|
||||
|
||||
Currently (2025-04-30), no model returns multimodel output as a url.
|
||||
|
||||
:param url: the url of the file.
|
||||
:param file_type: the file type of the file, check `FileType` enum for reference.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||
|
||||
|
||||
class FileSaverImpl(LLMFileSaver):
|
||||
_engine_factory: EngineFactory
|
||||
_tenant_id: str
|
||||
_user_id: str
|
||||
|
||||
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||
if engine_factory is None:
|
||||
|
||||
def _factory():
|
||||
return global_db.engine
|
||||
|
||||
engine_factory = _factory
|
||||
self._engine_factory = engine_factory
|
||||
self._user_id = user_id
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _get_tool_file_manager(self):
|
||||
return ToolFileManager(engine=self._engine_factory())
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
http_response = ssrf_proxy.get(url)
|
||||
http_response.raise_for_status()
|
||||
data = http_response.content
|
||||
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
|
||||
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
|
||||
|
||||
def save_binary_string(
|
||||
self,
|
||||
data: bytes,
|
||||
mime_type: str,
|
||||
file_type: FileType,
|
||||
extension_override: str | None = None,
|
||||
) -> File:
|
||||
tool_file_manager = self._get_tool_file_manager()
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self._user_id,
|
||||
tenant_id=self._tenant_id,
|
||||
# TODO(QuantumGhost): what is conversation id?
|
||||
conversation_id=None,
|
||||
file_binary=data,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
extension_override = _validate_extension_override(extension_override)
|
||||
extension = _get_extension(mime_type, extension_override)
|
||||
url = sign_tool_file(tool_file.id, extension)
|
||||
|
||||
return File(
|
||||
tenant_id=self._tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
filename=tool_file.name,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=len(data),
|
||||
related_id=tool_file.id,
|
||||
url=url,
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
|
||||
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
|
||||
"""get_extension return the extension of file.
|
||||
|
||||
If the `extension_override` parameter is set, this function should honor it and
|
||||
return its value.
|
||||
"""
|
||||
if extension_override is not None:
|
||||
return extension_override
|
||||
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
|
||||
|
||||
|
||||
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
|
||||
"""_extract_content_type_and_extension tries to
|
||||
guess content type of file from url and `Content-Type` header in response.
|
||||
"""
|
||||
if content_type_header:
|
||||
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
|
||||
return content_type_header, extension
|
||||
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
|
||||
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
|
||||
return content_type, extension
|
||||
|
||||
|
||||
def _validate_extension_override(extension_override: str | None) -> str | None:
|
||||
# `extension_override` is allow to be `None or `""`.
|
||||
if extension_override is None:
|
||||
return None
|
||||
if extension_override == "":
|
||||
return ""
|
||||
if not extension_override.startswith("."):
|
||||
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
|
||||
return extension_override
|
||||
156
dify/api/core/workflow/nodes/llm/llm_utils.py
Normal file
156
dify/api/core/workflow/nodes/llm/llm_utils.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.file.models import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
|
||||
|
||||
|
||||
def fetch_model_config(
|
||||
tenant_id: str, node_data_model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
if not node_data_model.mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
model = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
)
|
||||
|
||||
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
|
||||
|
||||
# check model
|
||||
provider_model = model.provider_model_bundle.configuration.get_provider_model(
|
||||
model=node_data_model.name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
provider_model.raise_for_status()
|
||||
|
||||
# model config
|
||||
stop: list[str] = []
|
||||
if "stop" in node_data_model.completion_params:
|
||||
stop = node_data_model.completion_params.pop("stop")
|
||||
|
||||
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
return model, ModelConfigWithCredentialsEntity(
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
model_schema=model_schema,
|
||||
mode=node_data_model.mode,
|
||||
provider_model_bundle=model.provider_model_bundle,
|
||||
credentials=model.credentials,
|
||||
parameters=node_data_model.completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
|
||||
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
|
||||
variable = variable_pool.get(selector)
|
||||
if variable is None:
|
||||
return []
|
||||
elif isinstance(variable, FileSegment):
|
||||
return [variable.value]
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
return variable.value
|
||||
elif isinstance(variable, NoneSegment | ArrayAnySegment):
|
||||
return []
|
||||
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
|
||||
) -> TokenBufferMemory | None:
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
return memory
|
||||
|
||||
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
1371
dify/api/core/workflow/nodes/llm/node.py
Normal file
1371
dify/api/core/workflow/nodes/llm/node.py
Normal file
File diff suppressed because it is too large
Load Diff
6
dify/api/core/workflow/nodes/loop/__init__.py
Normal file
6
dify/api/core/workflow/nodes/loop/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .entities import LoopNodeData
|
||||
from .loop_end_node import LoopEndNode
|
||||
from .loop_node import LoopNode
|
||||
from .loop_start_node import LoopStartNode
|
||||
|
||||
__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"]
|
||||
98
dify/api/core/workflow/nodes/loop/entities.py
Normal file
98
dify/api/core/workflow/nodes/loop/entities.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
_VALID_VAR_TYPE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
|
||||
if seg_type not in _VALID_VAR_TYPE:
|
||||
raise ValueError(...)
|
||||
return seg_type
|
||||
|
||||
|
||||
class LoopVariableData(BaseModel):
|
||||
"""
|
||||
Loop Variable Data.
|
||||
"""
|
||||
|
||||
label: str
|
||||
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||
value_type: Literal["variable", "constant"]
|
||||
value: Any | list[str] | None = None
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("outputs", mode="before")
|
||||
@classmethod
|
||||
def validate_outputs(cls, v):
|
||||
if v is None:
|
||||
return {}
|
||||
return v
|
||||
|
||||
|
||||
class LoopStartNodeData(BaseNodeData):
|
||||
"""
|
||||
Loop Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LoopEndNodeData(BaseNodeData):
|
||||
"""
|
||||
Loop End Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LoopState(BaseLoopState):
|
||||
"""
|
||||
Loop State.
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Any = None
|
||||
|
||||
class MetaData(BaseLoopState.MetaData):
|
||||
"""
|
||||
Data.
|
||||
"""
|
||||
|
||||
loop_length: int
|
||||
|
||||
def get_last_output(self) -> Any:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
if self.outputs:
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Any:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
return self.current_output
|
||||
49
dify/api/core/workflow/nodes/loop/loop_end_node.py
Normal file
49
dify/api/core/workflow/nodes/loop/loop_end_node.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.loop.entities import LoopEndNodeData
|
||||
|
||||
|
||||
class LoopEndNode(Node):
|
||||
"""
|
||||
Loop End Node.
|
||||
"""
|
||||
|
||||
node_type = NodeType.LOOP_END
|
||||
|
||||
_node_data: LoopEndNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopEndNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
463
dify/api/core/workflow/nodes/loop/loop_node.py
Normal file
463
dify/api/core/workflow/nodes/loop/loop_node.py
Normal file
@@ -0,0 +1,463 @@
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import Segment, SegmentType
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
GraphRunFailedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
LoopFailedEvent,
|
||||
LoopNextEvent,
|
||||
LoopStartedEvent,
|
||||
LoopSucceededEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopNode(LLMUsageTrackingMixin, Node):
|
||||
"""
|
||||
Loop Node.
|
||||
"""
|
||||
|
||||
node_type = NodeType.LOOP
|
||||
_node_data: LoopNodeData
|
||||
execution_type = NodeExecutionType.CONTAINER
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Run the node."""
|
||||
# Get inputs
|
||||
loop_count = self._node_data.loop_count
|
||||
break_conditions = self._node_data.break_conditions
|
||||
logical_operator = self._node_data.logical_operator
|
||||
|
||||
inputs = {"loop_count": loop_count}
|
||||
|
||||
if not self._node_data.start_node_id:
|
||||
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
||||
|
||||
root_node_id = self._node_data.start_node_id
|
||||
|
||||
# Initialize loop variables in the original variable pool
|
||||
loop_variable_selectors = {}
|
||||
if self._node_data.loop_variables:
|
||||
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
|
||||
if isinstance(var.value, list)
|
||||
else None,
|
||||
}
|
||||
for loop_variable in self._node_data.loop_variables:
|
||||
if loop_variable.value_type not in value_processor:
|
||||
raise ValueError(
|
||||
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
|
||||
)
|
||||
|
||||
processed_segment = value_processor[loop_variable.value_type](loop_variable)
|
||||
if not processed_segment:
|
||||
raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
|
||||
variable_selector = [self._node_id, loop_variable.label]
|
||||
variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
|
||||
self.graph_runtime_state.variable_pool.add(variable_selector, variable.value)
|
||||
loop_variable_selectors[loop_variable.label] = variable_selector
|
||||
inputs[loop_variable.label] = processed_segment.value
|
||||
|
||||
start_at = naive_utc_now()
|
||||
condition_processor = ConditionProcessor()
|
||||
|
||||
loop_duration_map: dict[str, float] = {}
|
||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||
loop_usage = LLMUsage.empty_usage()
|
||||
|
||||
# Start Loop event
|
||||
yield LoopStartedEvent(
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
metadata={"loop_length": loop_count},
|
||||
)
|
||||
|
||||
try:
|
||||
reach_break_condition = False
|
||||
if break_conditions:
|
||||
with contextlib.suppress(ValueError):
|
||||
_, _, reach_break_condition = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
|
||||
if reach_break_condition:
|
||||
loop_count = 0
|
||||
|
||||
for i in range(loop_count):
|
||||
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
|
||||
|
||||
loop_start_time = naive_utc_now()
|
||||
reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i)
|
||||
# Track loop duration
|
||||
loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds()
|
||||
|
||||
# Accumulate outputs from the sub-graph's response nodes
|
||||
for key, value in graph_engine.graph_runtime_state.outputs.items():
|
||||
if key == "answer":
|
||||
# Concatenate answer outputs with newline
|
||||
existing_answer = self.graph_runtime_state.get_output("answer", "")
|
||||
if existing_answer:
|
||||
self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}")
|
||||
else:
|
||||
self.graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
# For other outputs, just update
|
||||
self.graph_runtime_state.set_output(key, value)
|
||||
|
||||
# Accumulate usage from the sub-graph execution
|
||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||
|
||||
# Collect loop variable values after iteration
|
||||
single_loop_variable = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
segment = self.graph_runtime_state.variable_pool.get(selector)
|
||||
single_loop_variable[key] = segment.value if segment else None
|
||||
|
||||
single_loop_variable_map[str(i)] = single_loop_variable
|
||||
|
||||
if reach_break_node:
|
||||
break
|
||||
|
||||
if break_conditions:
|
||||
_, _, reach_break_condition = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if reach_break_condition:
|
||||
break
|
||||
|
||||
yield LoopNextEvent(
|
||||
index=i + 1,
|
||||
pre_loop_output=self._node_data.outputs,
|
||||
)
|
||||
|
||||
self._accumulate_usage(loop_usage)
|
||||
# Loop completed successfully
|
||||
yield LoopSucceededEvent(
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs=self._node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
outputs=self._node_data.outputs,
|
||||
inputs=inputs,
|
||||
llm_usage=loop_usage,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._accumulate_usage(loop_usage)
|
||||
yield LoopFailedEvent(
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
"completed_reason": "error",
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
llm_usage=loop_usage,
|
||||
)
|
||||
)
|
||||
|
||||
def _run_single_loop(
|
||||
self,
|
||||
*,
|
||||
graph_engine: "GraphEngine",
|
||||
current_index: int,
|
||||
) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]:
|
||||
reach_break_node = False
|
||||
for event in graph_engine.run():
|
||||
if isinstance(event, GraphNodeEventBase):
|
||||
self._append_loop_info_to_event(event=event, loop_run_index=current_index)
|
||||
|
||||
if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.LOOP_START:
|
||||
continue
|
||||
if isinstance(event, GraphNodeEventBase):
|
||||
yield event
|
||||
if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END:
|
||||
reach_break_node = True
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
raise Exception(event.error)
|
||||
|
||||
for loop_var in self._node_data.loop_variables or []:
|
||||
key, sel = loop_var.label, [self._node_id, loop_var.label]
|
||||
segment = self.graph_runtime_state.variable_pool.get(sel)
|
||||
self._node_data.outputs[key] = segment.value if segment else None
|
||||
self._node_data.outputs["loop_round"] = current_index + 1
|
||||
|
||||
return reach_break_node
|
||||
|
||||
def _append_loop_info_to_event(
|
||||
self,
|
||||
event: GraphNodeEventBase,
|
||||
loop_run_index: int,
|
||||
):
|
||||
event.in_loop_id = self._node_id
|
||||
loop_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index,
|
||||
}
|
||||
|
||||
current_metadata = event.node_run_result.metadata
|
||||
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
|
||||
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LoopNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
# Extract loop node IDs statically from graph_config
|
||||
|
||||
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
|
||||
|
||||
# Get node configs from graph_config
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
for sub_node_id, sub_node_config in node_configs.items():
|
||||
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
||||
continue
|
||||
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
sub_node_variable_mapping = {}
|
||||
|
||||
# remove loop variables
|
||||
sub_node_variable_mapping = {
|
||||
sub_node_id + "." + key: value
|
||||
for key, value in sub_node_variable_mapping.items()
|
||||
if value[0] != node_id
|
||||
}
|
||||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
for loop_variable in typed_node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
selector = loop_variable.value
|
||||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||
|
||||
# remove variable out from loop
|
||||
variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
|
||||
"""
|
||||
Extract node IDs that belong to a specific loop from graph configuration.
|
||||
|
||||
This method statically analyzes the graph configuration to find all nodes
|
||||
that are part of the specified loop, without creating actual node instances.
|
||||
|
||||
:param graph_config: the complete graph configuration
|
||||
:param loop_node_id: the ID of the loop node
|
||||
:return: set of node IDs that belong to the loop
|
||||
"""
|
||||
loop_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("loop_id") == loop_node_id:
|
||||
node_id = node.get("id")
|
||||
if node_id:
|
||||
loop_node_ids.add(node_id)
|
||||
|
||||
return loop_node_ids
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
# TODO: Refactor for maintainability:
|
||||
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
|
||||
# 2. Consider moving this method to LoopVariableData class for better encapsulation
|
||||
if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN:
|
||||
value = original_value
|
||||
elif var_type in [
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
]:
|
||||
if original_value and isinstance(original_value, str):
|
||||
value = json.loads(original_value)
|
||||
else:
|
||||
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
|
||||
value = []
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
try:
|
||||
return build_segment_with_type(var_type, value=value)
|
||||
except TypeMismatchError as type_exc:
|
||||
# Attempt to parse the value as a JSON-encoded string, if applicable.
|
||||
if not isinstance(original_value, str):
|
||||
raise
|
||||
try:
|
||||
value = json.loads(original_value)
|
||||
except ValueError:
|
||||
raise type_exc
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
# Import dependencies
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
invoke_from=self.invoke_from.value,
|
||||
call_depth=self.workflow_call_depth,
|
||||
)
|
||||
|
||||
# Create a new GraphRuntimeState for this iteration
|
||||
graph_runtime_state_copy = GraphRuntimeState(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
start_at=start_at.timestamp(),
|
||||
)
|
||||
|
||||
# Create a new node factory with the new GraphRuntimeState
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
|
||||
)
|
||||
|
||||
# Initialize the loop graph with the new node factory
|
||||
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
|
||||
# Create a new GraphEngine for this iteration
|
||||
graph_engine = GraphEngine(
|
||||
workflow_id=self.workflow_id,
|
||||
graph=loop_graph,
|
||||
graph_runtime_state=graph_runtime_state_copy,
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
)
|
||||
|
||||
return graph_engine
|
||||
49
dify/api/core/workflow/nodes/loop/loop_start_node.py
Normal file
49
dify/api/core/workflow/nodes/loop/loop_start_node.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.loop.entities import LoopStartNodeData
|
||||
|
||||
|
||||
class LoopStartNode(Node):
|
||||
"""
|
||||
Loop Start Node.
|
||||
"""
|
||||
|
||||
node_type = NodeType.LOOP_START
|
||||
|
||||
_node_data: LoopStartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = LoopStartNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
85
dify/api/core/workflow/nodes/node_factory.py
Normal file
85
dify/api/core/workflow/nodes/node_factory.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Default implementation of NodeFactory that uses the traditional node mapping.
|
||||
|
||||
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
|
||||
and instantiating the appropriate node class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data using the traditional mapping.
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
"""
|
||||
# Get node_id from config
|
||||
node_id = node_config.get("id")
|
||||
if not is_str(node_id):
|
||||
raise ValueError("Node config missing id")
|
||||
|
||||
# Get node type from config
|
||||
node_data = node_config.get("data", {})
|
||||
if not is_str_dict(node_data):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
|
||||
node_type_str = node_data.get("type")
|
||||
if not is_str(node_type_str):
|
||||
raise ValueError(f"Node {node_id} missing or invalid type information")
|
||||
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
except ValueError:
|
||||
raise ValueError(f"Unknown node type: {node_type_str}")
|
||||
|
||||
# Get node class
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
node_class = node_mapping.get(LATEST_VERSION)
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
|
||||
# Create node instance
|
||||
node_instance = node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node with provided data
|
||||
node_data = node_config.get("data", {})
|
||||
if not is_str_dict(node_data):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
node_instance.init_node_data(node_data)
|
||||
|
||||
return node_instance
|
||||
165
dify/api/core/workflow/nodes/node_mapping.py
Normal file
165
dify/api/core/workflow/nodes/node_mapping.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request import HttpRequestNode
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.if_else import IfElseNode
|
||||
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
|
||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.list_operator import ListOperatorNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
|
||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from core.workflow.nodes.start import StartNode
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.trigger_plugin import TriggerEventNode
|
||||
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
|
||||
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
|
||||
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
|
||||
# Specifically, if you have introduced new node types, you should add them here.
|
||||
#
|
||||
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
|
||||
# hook. Try to avoid duplication of node information.
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
|
||||
NodeType.START: {
|
||||
LATEST_VERSION: StartNode,
|
||||
"1": StartNode,
|
||||
},
|
||||
NodeType.END: {
|
||||
LATEST_VERSION: EndNode,
|
||||
"1": EndNode,
|
||||
},
|
||||
NodeType.ANSWER: {
|
||||
LATEST_VERSION: AnswerNode,
|
||||
"1": AnswerNode,
|
||||
},
|
||||
NodeType.LLM: {
|
||||
LATEST_VERSION: LLMNode,
|
||||
"1": LLMNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: {
|
||||
LATEST_VERSION: KnowledgeRetrievalNode,
|
||||
"1": KnowledgeRetrievalNode,
|
||||
},
|
||||
NodeType.IF_ELSE: {
|
||||
LATEST_VERSION: IfElseNode,
|
||||
"1": IfElseNode,
|
||||
},
|
||||
NodeType.CODE: {
|
||||
LATEST_VERSION: CodeNode,
|
||||
"1": CodeNode,
|
||||
},
|
||||
NodeType.TEMPLATE_TRANSFORM: {
|
||||
LATEST_VERSION: TemplateTransformNode,
|
||||
"1": TemplateTransformNode,
|
||||
},
|
||||
NodeType.QUESTION_CLASSIFIER: {
|
||||
LATEST_VERSION: QuestionClassifierNode,
|
||||
"1": QuestionClassifierNode,
|
||||
},
|
||||
NodeType.HTTP_REQUEST: {
|
||||
LATEST_VERSION: HttpRequestNode,
|
||||
"1": HttpRequestNode,
|
||||
},
|
||||
NodeType.TOOL: {
|
||||
LATEST_VERSION: ToolNode,
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use two different versions to point to the same class here,
|
||||
# but in order to maintain compatibility with historical data, this approach has been retained.
|
||||
"2": ToolNode,
|
||||
"1": ToolNode,
|
||||
},
|
||||
NodeType.VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
},
|
||||
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
}, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: {
|
||||
LATEST_VERSION: IterationNode,
|
||||
"1": IterationNode,
|
||||
},
|
||||
NodeType.ITERATION_START: {
|
||||
LATEST_VERSION: IterationStartNode,
|
||||
"1": IterationStartNode,
|
||||
},
|
||||
NodeType.LOOP: {
|
||||
LATEST_VERSION: LoopNode,
|
||||
"1": LoopNode,
|
||||
},
|
||||
NodeType.LOOP_START: {
|
||||
LATEST_VERSION: LoopStartNode,
|
||||
"1": LoopStartNode,
|
||||
},
|
||||
NodeType.LOOP_END: {
|
||||
LATEST_VERSION: LoopEndNode,
|
||||
"1": LoopEndNode,
|
||||
},
|
||||
NodeType.PARAMETER_EXTRACTOR: {
|
||||
LATEST_VERSION: ParameterExtractorNode,
|
||||
"1": ParameterExtractorNode,
|
||||
},
|
||||
NodeType.VARIABLE_ASSIGNER: {
|
||||
LATEST_VERSION: VariableAssignerNodeV2,
|
||||
"1": VariableAssignerNodeV1,
|
||||
"2": VariableAssignerNodeV2,
|
||||
},
|
||||
NodeType.DOCUMENT_EXTRACTOR: {
|
||||
LATEST_VERSION: DocumentExtractorNode,
|
||||
"1": DocumentExtractorNode,
|
||||
},
|
||||
NodeType.LIST_OPERATOR: {
|
||||
LATEST_VERSION: ListOperatorNode,
|
||||
"1": ListOperatorNode,
|
||||
},
|
||||
NodeType.AGENT: {
|
||||
LATEST_VERSION: AgentNode,
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use two different versions to point to the same class here,
|
||||
# but in order to maintain compatibility with historical data, this approach has been retained.
|
||||
"2": AgentNode,
|
||||
"1": AgentNode,
|
||||
},
|
||||
NodeType.HUMAN_INPUT: {
|
||||
LATEST_VERSION: HumanInputNode,
|
||||
"1": HumanInputNode,
|
||||
},
|
||||
NodeType.DATASOURCE: {
|
||||
LATEST_VERSION: DatasourceNode,
|
||||
"1": DatasourceNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_INDEX: {
|
||||
LATEST_VERSION: KnowledgeIndexNode,
|
||||
"1": KnowledgeIndexNode,
|
||||
},
|
||||
NodeType.TRIGGER_WEBHOOK: {
|
||||
LATEST_VERSION: TriggerWebhookNode,
|
||||
"1": TriggerWebhookNode,
|
||||
},
|
||||
NodeType.TRIGGER_PLUGIN: {
|
||||
LATEST_VERSION: TriggerEventNode,
|
||||
"1": TriggerEventNode,
|
||||
},
|
||||
NodeType.TRIGGER_SCHEDULE: {
|
||||
LATEST_VERSION: TriggerScheduleNode,
|
||||
"1": TriggerScheduleNode,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
from .parameter_extractor_node import ParameterExtractorNode
|
||||
|
||||
__all__ = ["ParameterExtractorNode"]
|
||||
129
dify/api/core/workflow/nodes/parameter_extractor/entities.py
Normal file
129
dify/api/core/workflow/nodes/parameter_extractor/entities.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
_OLD_BOOL_TYPE_NAME = "bool"
|
||||
_OLD_SELECT_TYPE_NAME = "select"
|
||||
|
||||
_VALID_PARAMETER_TYPES = frozenset(
|
||||
[
|
||||
SegmentType.STRING, # "string",
|
||||
SegmentType.NUMBER, # "number",
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
|
||||
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(parameter_type: str) -> SegmentType:
|
||||
if parameter_type not in _VALID_PARAMETER_TYPES:
|
||||
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
|
||||
|
||||
if parameter_type == _OLD_BOOL_TYPE_NAME:
|
||||
return SegmentType.BOOLEAN
|
||||
elif parameter_type == _OLD_SELECT_TYPE_NAME:
|
||||
return SegmentType.STRING
|
||||
return SegmentType(parameter_type)
|
||||
|
||||
|
||||
class ParameterConfig(BaseModel):
|
||||
"""
|
||||
Parameter Config.
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
|
||||
options: list[str] | None = None
|
||||
description: str
|
||||
required: bool
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def validate_name(cls, value) -> str:
|
||||
if not value:
|
||||
raise ValueError("Parameter name is required")
|
||||
if value in {"__reason", "__is_success"}:
|
||||
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
|
||||
return str(value)
|
||||
|
||||
def is_array_type(self) -> bool:
|
||||
return self.type.is_array_type()
|
||||
|
||||
def element_type(self) -> SegmentType:
|
||||
"""Return the element type of the parameter.
|
||||
|
||||
Raises a ValueError if the parameter's type is not an array type.
|
||||
"""
|
||||
element_type = self.type.element_type()
|
||||
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
|
||||
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
|
||||
#
|
||||
# See: _VALID_PARAMETER_TYPES for reference.
|
||||
assert element_type is not None, f"the element type should not be None, {self.type=}"
|
||||
return element_type
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
"""
|
||||
Parameter Extractor Node Data.
|
||||
"""
|
||||
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
instruction: str | None = None
|
||||
memory: MemoryConfig | None = None
|
||||
reasoning_mode: Literal["function_call", "prompt"]
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@field_validator("reasoning_mode", mode="before")
|
||||
@classmethod
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or "function_call"
|
||||
|
||||
def get_parameter_json_schema(self):
|
||||
"""
|
||||
Get parameter json schema.
|
||||
|
||||
:return: parameter json schema
|
||||
"""
|
||||
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
|
||||
if parameter.type == SegmentType.STRING:
|
||||
parameter_schema["type"] = "string"
|
||||
elif parameter.type.is_array_type():
|
||||
parameter_schema["type"] = "array"
|
||||
element_type = parameter.type.element_type()
|
||||
if element_type is None:
|
||||
raise AssertionError("element type should not be None.")
|
||||
parameter_schema["items"] = {"type": element_type.value}
|
||||
else:
|
||||
parameter_schema["type"] = parameter.type
|
||||
|
||||
if parameter.options:
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
parameters["properties"][parameter.name] = parameter_schema
|
||||
|
||||
if parameter.required:
|
||||
parameters["required"].append(parameter.name)
|
||||
|
||||
return parameters
|
||||
75
dify/api/core/workflow/nodes/parameter_extractor/exc.py
Normal file
75
dify/api/core/workflow/nodes/parameter_extractor/exc.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Any
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class ParameterExtractorNodeError(ValueError):
|
||||
"""Base error for ParameterExtractorNode."""
|
||||
|
||||
|
||||
class InvalidModelTypeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
|
||||
|
||||
class ModelSchemaNotFoundError(ParameterExtractorNodeError):
|
||||
"""Raised when the model schema is not found."""
|
||||
|
||||
|
||||
class InvalidInvokeResultError(ParameterExtractorNodeError):
|
||||
"""Raised when the invoke result is invalid."""
|
||||
|
||||
|
||||
class InvalidTextContentTypeError(ParameterExtractorNodeError):
|
||||
"""Raised when the text content type is invalid."""
|
||||
|
||||
|
||||
class InvalidNumberOfParametersError(ParameterExtractorNodeError):
|
||||
"""Raised when the number of parameters is invalid."""
|
||||
|
||||
|
||||
class RequiredParameterMissingError(ParameterExtractorNodeError):
|
||||
"""Raised when a required parameter is missing."""
|
||||
|
||||
|
||||
class InvalidSelectValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a select value is invalid."""
|
||||
|
||||
|
||||
class InvalidNumberValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a number value is invalid."""
|
||||
|
||||
|
||||
class InvalidBoolValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a bool value is invalid."""
|
||||
|
||||
|
||||
class InvalidStringValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a string value is invalid."""
|
||||
|
||||
|
||||
class InvalidArrayValueError(ParameterExtractorNodeError):
|
||||
"""Raised when an array value is invalid."""
|
||||
|
||||
|
||||
class InvalidModelModeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model mode is invalid."""
|
||||
|
||||
|
||||
class InvalidValueTypeError(ParameterExtractorNodeError):
|
||||
def __init__(
|
||||
self,
|
||||
/,
|
||||
parameter_name: str,
|
||||
expected_type: SegmentType,
|
||||
actual_type: SegmentType | None,
|
||||
value: Any,
|
||||
):
|
||||
message = (
|
||||
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
|
||||
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
|
||||
)
|
||||
super().__init__(message)
|
||||
self.parameter_name = parameter_name
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
self.value = value
|
||||
@@ -0,0 +1,858 @@
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.llm import ModelConfig, llm_utils
|
||||
from core.workflow.runtime import VariablePool
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
InvalidModelModeError,
|
||||
InvalidModelTypeError,
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidSelectValueError,
|
||||
InvalidTextContentTypeError,
|
||||
InvalidValueTypeError,
|
||||
ModelSchemaNotFoundError,
|
||||
ParameterExtractorNodeError,
|
||||
RequiredParameterMissingError,
|
||||
)
|
||||
from .prompts import (
|
||||
CHAT_EXAMPLE,
|
||||
CHAT_GENERATE_JSON_PROMPT,
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
|
||||
COMPLETION_GENERATE_JSON_PROMPT,
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
|
||||
FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT,
|
||||
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_json(text):
|
||||
"""
|
||||
From a given JSON started from '{' or '[' extract the complete JSON object.
|
||||
"""
|
||||
stack = []
|
||||
for i, c in enumerate(text):
|
||||
if c in {"{", "["}:
|
||||
stack.append(c)
|
||||
elif c in {"}", "]"}:
|
||||
# check if stack is empty
|
||||
if not stack:
|
||||
return text[:i]
|
||||
# check if the last element in stack is matching
|
||||
if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
|
||||
stack.pop()
|
||||
if not stack:
|
||||
return text[: i + 1]
|
||||
else:
|
||||
return text[:i]
|
||||
return None
|
||||
|
||||
|
||||
class ParameterExtractorNode(Node):
|
||||
"""
|
||||
Parameter Extractor Node.
|
||||
"""
|
||||
|
||||
node_type = NodeType.PARAMETER_EXTRACTOR
|
||||
|
||||
_node_data: ParameterExtractorNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ParameterExtractorNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
_model_instance: ModelInstance | None = None
|
||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"model": {
|
||||
"prompt_templates": {
|
||||
"completion_model": {
|
||||
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||
"stop": ["Human:"],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
node_data = self._node_data
|
||||
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
|
||||
query = variable.text if variable else ""
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
files = (
|
||||
llm_utils.fetch_files(
|
||||
variable_pool=variable_pool,
|
||||
selector=node_data.vision.configs.variable_selector,
|
||||
)
|
||||
if node_data.vision.enabled
|
||||
else []
|
||||
)
|
||||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||
|
||||
llm_model = model_instance.model_type_instance
|
||||
model_schema = llm_model.get_model_schema(
|
||||
model=model_config.model,
|
||||
credentials=model_config.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ModelSchemaNotFoundError("Model schema not found")
|
||||
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
if (
|
||||
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
|
||||
and node_data.reasoning_mode == "function_call"
|
||||
):
|
||||
# use function call
|
||||
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
|
||||
node_data=node_data,
|
||||
query=query,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
model_config=model_config,
|
||||
memory=memory,
|
||||
files=files,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
)
|
||||
else:
|
||||
# use prompt engineering
|
||||
prompt_messages = self._generate_prompt_engineering_prompt(
|
||||
data=node_data,
|
||||
query=query,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
model_config=model_config,
|
||||
memory=memory,
|
||||
files=files,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
)
|
||||
|
||||
prompt_message_tools = []
|
||||
|
||||
inputs = {
|
||||
"query": query,
|
||||
"files": [f.to_dict() for f in files],
|
||||
"parameters": jsonable_encoder(node_data.parameters),
|
||||
"instruction": jsonable_encoder(node_data.instruction),
|
||||
}
|
||||
|
||||
process_data = {
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode, prompt_messages=prompt_messages
|
||||
),
|
||||
"usage": None,
|
||||
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
|
||||
"tool_call": None,
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
|
||||
try:
|
||||
text, usage, tool_call = self._invoke(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=prompt_message_tools,
|
||||
stop=model_config.stop,
|
||||
)
|
||||
process_data["usage"] = jsonable_encoder(usage)
|
||||
process_data["tool_call"] = jsonable_encoder(tool_call)
|
||||
process_data["llm_text"] = text
|
||||
except ParameterExtractorNodeError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"__is_success": 0, "__reason": str(e)},
|
||||
error=str(e),
|
||||
metadata={},
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
|
||||
error=str(e),
|
||||
metadata={},
|
||||
)
|
||||
|
||||
error = None
|
||||
|
||||
if tool_call:
|
||||
result = self._extract_json_from_tool_call(tool_call)
|
||||
else:
|
||||
result = self._extract_complete_json_response(text)
|
||||
if not result:
|
||||
result = self._generate_default_result(node_data)
|
||||
error = "Failed to extract result from function call or text response, using empty result."
|
||||
|
||||
try:
|
||||
result = self._validate_result(data=node_data, result=result or {})
|
||||
except ParameterExtractorNodeError as e:
|
||||
error = str(e)
|
||||
|
||||
# transform result into standard format
|
||||
result = self._transform_result(data=node_data, result=result or {})
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={
|
||||
"__is_success": 1 if not error else 0,
|
||||
"__reason": error,
|
||||
"__usage": jsonable_encoder(usage),
|
||||
**result,
|
||||
},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=self.user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
|
||||
text = invoke_result.message.content or ""
|
||||
if not isinstance(text, str):
|
||||
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||
|
||||
usage = invoke_result.usage
|
||||
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage, tool_call
|
||||
|
||||
def _generate_function_call_prompt(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
"""
|
||||
Generate function call prompt.
|
||||
"""
|
||||
query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(
|
||||
content=query, structure=json.dumps(node_data.get_parameter_json_schema())
|
||||
)
|
||||
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
|
||||
prompt_template = self._get_function_calling_prompt_template(
|
||||
node_data, query, variable_pool, memory, rest_token
|
||||
)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
image_detail_config=vision_detail,
|
||||
)
|
||||
|
||||
# find last user message
|
||||
last_user_message_idx = -1
|
||||
for i, prompt_message in enumerate(prompt_messages):
|
||||
if prompt_message.role == PromptMessageRole.USER:
|
||||
last_user_message_idx = i
|
||||
|
||||
# add function call messages before last user message
|
||||
example_messages = []
|
||||
for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE:
|
||||
id = uuid.uuid4().hex
|
||||
example_messages.extend(
|
||||
[
|
||||
UserPromptMessage(content=example["user"]["query"]),
|
||||
AssistantPromptMessage(
|
||||
content=example["assistant"]["text"],
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=example["assistant"]["function_call"]["name"],
|
||||
arguments=json.dumps(example["assistant"]["function_call"]["parameters"]),
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
ToolPromptMessage(
|
||||
content="Great! You have called the function with the correct parameters.", tool_call_id=id
|
||||
),
|
||||
AssistantPromptMessage(
|
||||
content="I have extracted the parameters, let's move on.",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
prompt_messages = (
|
||||
prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:]
|
||||
)
|
||||
|
||||
# generate tool
|
||||
tool = PromptMessageTool(
|
||||
name=FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
description="Extract parameters from the natural language text",
|
||||
parameters=node_data.get_parameter_json_schema(),
|
||||
)
|
||||
|
||||
return prompt_messages, [tool]
|
||||
|
||||
def _generate_prompt_engineering_prompt(
|
||||
self,
|
||||
data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate prompt engineering prompt.
|
||||
"""
|
||||
model_mode = ModelMode(data.model.mode)
|
||||
|
||||
if model_mode == ModelMode.COMPLETION:
|
||||
return self._generate_prompt_engineering_completion_prompt(
|
||||
node_data=data,
|
||||
query=query,
|
||||
variable_pool=variable_pool,
|
||||
model_config=model_config,
|
||||
memory=memory,
|
||||
files=files,
|
||||
vision_detail=vision_detail,
|
||||
)
|
||||
elif model_mode == ModelMode.CHAT:
|
||||
return self._generate_prompt_engineering_chat_prompt(
|
||||
node_data=data,
|
||||
query=query,
|
||||
variable_pool=variable_pool,
|
||||
model_config=model_config,
|
||||
memory=memory,
|
||||
files=files,
|
||||
vision_detail=vision_detail,
|
||||
)
|
||||
else:
|
||||
raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
|
||||
|
||||
def _generate_prompt_engineering_completion_prompt(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate completion prompt.
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(
|
||||
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
|
||||
)
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(
|
||||
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
|
||||
)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
|
||||
query="",
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
image_detail_config=vision_detail,
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _generate_prompt_engineering_chat_prompt(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: TokenBufferMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate chat prompt.
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(
|
||||
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
|
||||
)
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(
|
||||
node_data=node_data,
|
||||
query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
|
||||
structure=json.dumps(node_data.get_parameter_json_schema()), text=query
|
||||
),
|
||||
variable_pool=variable_pool,
|
||||
memory=memory,
|
||||
max_token_limit=rest_token,
|
||||
)
|
||||
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
image_detail_config=vision_detail,
|
||||
)
|
||||
|
||||
# find last user message
|
||||
last_user_message_idx = -1
|
||||
for i, prompt_message in enumerate(prompt_messages):
|
||||
if prompt_message.role == PromptMessageRole.USER:
|
||||
last_user_message_idx = i
|
||||
|
||||
# add example messages before last user message
|
||||
example_messages = []
|
||||
for example in CHAT_EXAMPLE:
|
||||
example_messages.extend(
|
||||
[
|
||||
UserPromptMessage(
|
||||
content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
|
||||
structure=json.dumps(example["user"]["json"]),
|
||||
text=example["user"]["query"],
|
||||
)
|
||||
),
|
||||
AssistantPromptMessage(
|
||||
content=json.dumps(example["assistant"]["json"]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
prompt_messages = (
|
||||
prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
if len(data.parameters) != len(result):
|
||||
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||
|
||||
for parameter in data.parameters:
|
||||
if parameter.required and parameter.name not in result:
|
||||
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
|
||||
|
||||
param_value = result.get(parameter.name)
|
||||
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
|
||||
inferred_type = SegmentType.infer_segment_type(param_value)
|
||||
raise InvalidValueTypeError(
|
||||
parameter_name=parameter.name,
|
||||
expected_type=parameter.type,
|
||||
actual_type=inferred_type,
|
||||
value=param_value,
|
||||
)
|
||||
if parameter.type == SegmentType.STRING and parameter.options:
|
||||
if param_value not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _transform_number(value: int | float | str | bool) -> int | float | None:
|
||||
"""
|
||||
Attempts to transform the input into an integer or float.
|
||||
|
||||
Returns:
|
||||
int or float: The transformed number if the conversion is successful.
|
||||
None: If the transformation fails.
|
||||
|
||||
Note:
|
||||
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
|
||||
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
|
||||
"""
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
if "." in value:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
param_value = result[parameter.name]
|
||||
# transform value
|
||||
if parameter.type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(param_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
if isinstance(result[parameter.name], (bool, int)):
|
||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ["true", "false"]:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
|
||||
elif parameter.type == SegmentType.STRING:
|
||||
if isinstance(param_value, str):
|
||||
transformed_result[parameter.name] = param_value
|
||||
elif parameter.is_array_type():
|
||||
if isinstance(param_value, list):
|
||||
nested_type = parameter.element_type()
|
||||
assert nested_type is not None
|
||||
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
|
||||
transformed_result[parameter.name] = segment_value
|
||||
for item in param_value:
|
||||
if nested_type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(item)
|
||||
if transformed is not None:
|
||||
segment_value.value.append(transformed)
|
||||
elif nested_type == SegmentType.STRING:
|
||||
if isinstance(item, str):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == SegmentType.OBJECT:
|
||||
if isinstance(item, dict):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == SegmentType.BOOLEAN:
|
||||
if isinstance(item, bool):
|
||||
segment_value.value.append(item)
|
||||
|
||||
if parameter.name not in transformed_result:
|
||||
if parameter.type.is_array_type():
|
||||
transformed_result[parameter.name] = build_segment_with_type(
|
||||
segment_type=SegmentType(parameter.type), value=[]
|
||||
)
|
||||
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type == SegmentType.NUMBER:
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
transformed_result[parameter.name] = False
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
def _extract_complete_json_response(self, result: str) -> dict | None:
|
||||
"""
|
||||
Extract complete json response.
|
||||
"""
|
||||
|
||||
# extract json from the text
|
||||
for idx in range(len(result)):
|
||||
if result[idx] == "{" or result[idx] == "[":
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
with contextlib.suppress(Exception):
|
||||
return cast(dict, json.loads(json_str))
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
|
||||
"""
|
||||
Extract json from tool call.
|
||||
"""
|
||||
if not tool_call or not tool_call.function.arguments:
|
||||
return None
|
||||
|
||||
result = tool_call.function.arguments
|
||||
# extract json from the arguments
|
||||
for idx in range(len(result)):
|
||||
if result[idx] == "{" or result[idx] == "[":
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
with contextlib.suppress(Exception):
|
||||
return cast(dict, json.loads(json_str))
|
||||
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData):
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == "number":
|
||||
result[parameter.name] = 0
|
||||
elif parameter.type == "boolean":
|
||||
result[parameter.name] = False
|
||||
elif parameter.type in {"string", "select"}:
|
||||
result[parameter.name] = ""
|
||||
|
||||
return result
|
||||
|
||||
def _get_function_calling_prompt_template(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
instruction = variable_pool.convert_template(node_data.instruction or "").text
|
||||
|
||||
if memory and node_data.memory and node_data.memory.window:
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
else:
|
||||
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _get_prompt_engineering_prompt_template(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
instruction = variable_pool.convert_template(node_data.instruction or "").text
|
||||
|
||||
if memory and node_data.memory and node_data.memory.window:
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return CompletionModelPromptTemplate(
|
||||
text=COMPLETION_GENERATE_JSON_PROMPT.format(
|
||||
histories=memory_str, text=input_text, instruction=instruction
|
||||
)
|
||||
.replace("{γγγ", "")
|
||||
.replace("}γγγ", "")
|
||||
)
|
||||
else:
|
||||
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _calculate_rest_token(
|
||||
self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: str | None,
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||
|
||||
llm_model = model_instance.model_type_instance
|
||||
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
|
||||
if not model_schema:
|
||||
raise ModelSchemaNotFoundError("Model schema not found")
|
||||
|
||||
if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
|
||||
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
|
||||
else:
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
|
||||
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
curr_message_tokens = (
|
||||
model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000
|
||||
) # add 1000 to ensure tool call messages
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template or "")
|
||||
) or 0
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def _fetch_model_config(
|
||||
self, node_data_model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config.
|
||||
"""
|
||||
if not self._model_instance or not self._model_config:
|
||||
self._model_instance, self._model_config = llm_utils.fetch_model_config(
|
||||
tenant_id=self.tenant_id, node_data_model=node_data_model
|
||||
)
|
||||
|
||||
return self._model_instance, self._model_config
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
|
||||
|
||||
if typed_node_data.instruction:
|
||||
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
|
||||
for selector in selectors:
|
||||
variable_mapping[selector.variable] = selector.value_selector
|
||||
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
return variable_mapping
|
||||
184
dify/api/core/workflow/nodes/parameter_extractor/prompts.py
Normal file
184
dify/api/core/workflow/nodes/parameter_extractor/prompts.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from typing import Any
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
|
||||
### Task
|
||||
Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria.
|
||||
### Memory
|
||||
Here is the chat history between the human and assistant, provided within <histories> tags:
|
||||
<histories>
|
||||
\x7bhistories\x7d
|
||||
</histories>
|
||||
### Instructions:
|
||||
Some additional information is provided below. Always adhere to these instructions as closely as possible:
|
||||
<instruction>
|
||||
\x7binstruction\x7d
|
||||
</instruction>
|
||||
Steps:
|
||||
1. Review the chat history provided within the <histories> tags.
|
||||
2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
|
||||
3. Generate a well-formatted output using the defined functions and arguments.
|
||||
4. Use the `extract_parameter` function to create structured outputs with appropriate parameters.
|
||||
5. Do not include any XML tags in your output.
|
||||
### Example
|
||||
To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples.
|
||||
### Final Output
|
||||
Produce well-formatted function calls in json without XML tags, as shown in the example.
|
||||
""" # noqa: E501
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside <context></context> XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside <structure></structure> XML tags.
|
||||
<context>
|
||||
\x7bcontent\x7d
|
||||
</context>
|
||||
|
||||
<structure>
|
||||
\x7bstructure\x7d
|
||||
</structure>
|
||||
""" # noqa: E501
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [
|
||||
{
|
||||
"user": {
|
||||
"query": "What is the weather today in SF?",
|
||||
"function": {
|
||||
"name": FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The location to get the weather information",
|
||||
"required": True,
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"text": "I need always call the function with the correct parameters."
|
||||
" in this case, I need to call the function with the location parameter.",
|
||||
"function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"user": {
|
||||
"query": "I want to eat some apple pie.",
|
||||
"function": {
|
||||
"name": FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"food": {"type": "string", "description": "The food to eat", "required": True}},
|
||||
"required": ["food"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"assistant": {
|
||||
"text": "I need always call the function with the correct parameters."
|
||||
" in this case, I need to call the function with the food parameter.",
|
||||
"function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
COMPLETION_GENERATE_JSON_PROMPT = """### Instructions:
|
||||
Some extra information are provided below, I should always follow the instructions as possible as I can.
|
||||
<instructions>
|
||||
{instruction}
|
||||
</instructions>
|
||||
|
||||
### Extract parameter Workflow
|
||||
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
|
||||
<information to be extracted>
|
||||
{{ structure }}
|
||||
</information to be extracted>
|
||||
|
||||
Step 1: Carefully read the input and understand the structure of the expected output.
|
||||
Step 2: Extract relevant parameters from the provided text based on the name and description of object.
|
||||
Step 3: Structure the extracted parameters to JSON object as specified in <structure>.
|
||||
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
|
||||
|
||||
### Memory
|
||||
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
|
||||
### Structure
|
||||
Here is the structure of the expected output, I should always follow the output structure.
|
||||
{{γγγ
|
||||
'properties1': 'relevant text extracted from input',
|
||||
'properties2': 'relevant text extracted from input',
|
||||
}}γγγ
|
||||
|
||||
### Input Text
|
||||
Inside <text></text> XML tags, there is a text that I should extract parameters and convert to a JSON object.
|
||||
<text>
|
||||
{text}
|
||||
</text>
|
||||
|
||||
### Answer
|
||||
I should always output a valid JSON object. Output nothing other than the JSON object.
|
||||
```JSON
|
||||
""" # noqa: E501
|
||||
|
||||
CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object.
|
||||
The structure of the JSON object you can found in the instructions.
|
||||
|
||||
### Memory
|
||||
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
|
||||
### Instructions:
|
||||
Some extra information are provided below, you should always follow the instructions as possible as you can.
|
||||
<instructions>
|
||||
{instructions}
|
||||
</instructions>
|
||||
"""
|
||||
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure
|
||||
Here is the structure of the JSON object, you should always follow the structure.
|
||||
<structure>
|
||||
{structure}
|
||||
</structure>
|
||||
|
||||
### Text to be converted to JSON
|
||||
Inside <text></text> XML tags, there is a text that you should convert to a JSON object.
|
||||
<text>
|
||||
{text}
|
||||
</text>
|
||||
"""
|
||||
|
||||
CHAT_EXAMPLE = [
|
||||
{
|
||||
"user": {
|
||||
"query": "What is the weather today in SF?",
|
||||
"json": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The location to get the weather information",
|
||||
"required": True,
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
"assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}},
|
||||
},
|
||||
{
|
||||
"user": {
|
||||
"query": "I want to eat some apple pie.",
|
||||
"json": {
|
||||
"type": "object",
|
||||
"properties": {"food": {"type": "string", "description": "The food to eat", "required": True}},
|
||||
"required": ["food"],
|
||||
},
|
||||
},
|
||||
"assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}},
|
||||
},
|
||||
]
|
||||
@@ -0,0 +1,4 @@
|
||||
from .entities import QuestionClassifierNodeData
|
||||
from .question_classifier_node import QuestionClassifierNode
|
||||
|
||||
__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"]
|
||||
28
dify/api/core/workflow/nodes/question_classifier/entities.py
Normal file
28
dify/api/core/workflow/nodes/question_classifier/entities.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
|
||||
|
||||
class ClassConfig(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class QuestionClassifierNodeData(BaseNodeData):
|
||||
query_variable_selector: list[str]
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
instruction: str | None = None
|
||||
memory: MemoryConfig | None = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@property
|
||||
def structured_output_enabled(self) -> bool:
|
||||
# NOTE(QuantumGhost): Temporary workaround for issue #20725
|
||||
# (https://github.com/langgenius/dify/issues/20725).
|
||||
#
|
||||
# The proper fix would be to make `QuestionClassifierNode` inherit
|
||||
# from `BaseNode` instead of `LLMNode`.
|
||||
return False
|
||||
6
dify/api/core/workflow/nodes/question_classifier/exc.py
Normal file
6
dify/api/core/workflow/nodes/question_classifier/exc.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class QuestionClassifierNodeError(ValueError):
|
||||
"""Base class for QuestionClassifierNode errors."""
|
||||
|
||||
|
||||
class InvalidModelTypeError(QuestionClassifierNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
@@ -0,0 +1,398 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
|
||||
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
||||
from .entities import QuestionClassifierNodeData
|
||||
from .exc import InvalidModelTypeError
|
||||
from .template_prompts import (
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
|
||||
QUESTION_CLASSIFIER_COMPLETION_PROMPT,
|
||||
QUESTION_CLASSIFIER_SYSTEM_PROMPT,
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_1,
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_2,
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_3,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class QuestionClassifierNode(Node):
|
||||
node_type = NodeType.QUESTION_CLASSIFIER
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
_node_data: QuestionClassifierNodeData
|
||||
|
||||
_file_outputs: list["File"]
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs = []
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=graph_init_params.user_id,
|
||||
tenant_id=graph_init_params.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = QuestionClassifierNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
node_data = self._node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# extract variables
|
||||
variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
|
||||
query = variable.value if variable else None
|
||||
variables = {"query": query}
|
||||
# fetch model config
|
||||
model_instance, model_config = llm_utils.fetch_model_config(
|
||||
tenant_id=self.tenant_id,
|
||||
node_data_model=node_data.model,
|
||||
)
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
# fetch instruction
|
||||
node_data.instruction = node_data.instruction or ""
|
||||
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
|
||||
|
||||
files = (
|
||||
llm_utils.fetch_files(
|
||||
variable_pool=variable_pool,
|
||||
selector=node_data.vision.configs.variable_selector,
|
||||
)
|
||||
if node_data.vision.enabled
|
||||
else []
|
||||
)
|
||||
|
||||
# fetch prompt messages
|
||||
rest_token = self._calculate_rest_token(
|
||||
node_data=node_data,
|
||||
query=query or "",
|
||||
model_config=model_config,
|
||||
context="",
|
||||
)
|
||||
prompt_template = self._get_prompt_template(
|
||||
node_data=node_data,
|
||||
query=query or "",
|
||||
memory=memory,
|
||||
max_token_limit=rest_token,
|
||||
)
|
||||
# Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...).
|
||||
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
|
||||
# two consecutive user prompts will be generated, causing model's error.
|
||||
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
|
||||
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
sys_query="",
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
sys_files=files,
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
|
||||
try:
|
||||
# handle invoke result
|
||||
generator = LLMNode.invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=False,
|
||||
structured_output=None,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
)
|
||||
|
||||
for event in generator:
|
||||
if isinstance(event, ModelInvokeCompletedEvent):
|
||||
result_text = event.text
|
||||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
break
|
||||
|
||||
rendered_classes = [
|
||||
c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes
|
||||
]
|
||||
|
||||
category_name = rendered_classes[0].name
|
||||
category_id = rendered_classes[0].id
|
||||
if "<think>" in result_text:
|
||||
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
# result_text_json = json.loads(result_text.strip('```JSON\n'))
|
||||
if "category_name" in result_text_json and "category_id" in result_text_json:
|
||||
category_id_result = result_text_json["category_id"]
|
||||
classes = rendered_classes
|
||||
classes_map = {class_.id: class_.name for class_ in classes}
|
||||
category_ids = [_class.id for _class in classes]
|
||||
if category_id_result in category_ids:
|
||||
category_name = classes_map[category_id_result]
|
||||
category_id = category_id_result
|
||||
process_data = {
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode, prompt_messages=prompt_messages
|
||||
),
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
outputs = {
|
||||
"class_name": category_name,
|
||||
"class_id": category_id,
|
||||
"usage": jsonable_encoder(usage),
|
||||
}
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
edge_source_handle=category_id,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
except ValueError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {"query": typed_node_data.query_variable_selector}
|
||||
variable_selectors: list[VariableSelector] = []
|
||||
if typed_node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
|
||||
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters (not used in this implementation).
|
||||
:return:
|
||||
"""
|
||||
# filters parameter is not used in this node type
|
||||
return {"type": "question-classifier", "config": {"instructions": ""}}
|
||||
|
||||
def _calculate_rest_token(
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: str | None,
|
||||
) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
)
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template or "")
|
||||
) or 0
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def _get_prompt_template(
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
memory: TokenBufferMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
classes = node_data.classes
|
||||
categories = []
|
||||
for class_ in classes:
|
||||
category = {"category_id": class_.id, "category_name": class_.name}
|
||||
categories.append(category)
|
||||
instruction = node_data.instruction or ""
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
if memory:
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
|
||||
)
|
||||
prompt_messages: list[LLMNodeChatModelMessage] = []
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
|
||||
)
|
||||
prompt_messages.append(system_prompt_messages)
|
||||
user_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_1)
|
||||
assistant_prompt_message_1 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_1)
|
||||
user_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_2)
|
||||
assistant_prompt_message_2 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
|
||||
)
|
||||
prompt_messages.append(assistant_prompt_message_2)
|
||||
user_prompt_message_3 = LLMNodeChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(
|
||||
input_text=input_text,
|
||||
categories=json.dumps(categories, ensure_ascii=False),
|
||||
classification_instructions=instruction,
|
||||
),
|
||||
)
|
||||
prompt_messages.append(user_prompt_message_3)
|
||||
return prompt_messages
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return LLMNodeCompletionModelPromptTemplate(
|
||||
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
|
||||
histories=memory_str,
|
||||
input_text=input_text,
|
||||
categories=json.dumps(categories, ensure_ascii=False),
|
||||
classification_instructions=instruction,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")
|
||||
@@ -0,0 +1,76 @@
|
||||
QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
|
||||
### Job Description',
|
||||
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
|
||||
### Task
|
||||
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
|
||||
### Format
|
||||
The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy.
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
### Memory
|
||||
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
""" # noqa: E501
|
||||
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_1 = """
|
||||
{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
|
||||
"categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}],
|
||||
"classification_instructions": ["classify the text based on the feedback provided by customer"]}
|
||||
""" # noqa: E501
|
||||
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
|
||||
```json
|
||||
{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
|
||||
"category_id": "f5660049-284f-41a7-b301-fd24176a711c",
|
||||
"category_name": "Customer Service"}
|
||||
```
|
||||
"""
|
||||
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_2 = """
|
||||
{"input_text": ["bad service, slow to bring the food"],
|
||||
"categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}],
|
||||
"classification_instructions": []}
|
||||
""" # noqa: E501
|
||||
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
|
||||
```json
|
||||
{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
|
||||
"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f",
|
||||
"category_name": "Experience"}
|
||||
```
|
||||
"""
|
||||
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_3 = """
|
||||
{{"input_text": ["{input_text}"],
|
||||
"categories": {categories},
|
||||
"classification_instructions": ["{classification_instructions}"]}}
|
||||
"""
|
||||
|
||||
QUESTION_CLASSIFIER_COMPLETION_PROMPT = """
|
||||
### Job Description
|
||||
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
|
||||
### Task
|
||||
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
|
||||
### Format
|
||||
The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy.
|
||||
### Constraint
|
||||
DO NOT include anything other than the JSON array in your response.
|
||||
### Example
|
||||
Here is the chat example between human and assistant, inside <example></example> XML tags.
|
||||
<example>
|
||||
User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}}
|
||||
Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}}
|
||||
User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}}
|
||||
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
|
||||
</example>
|
||||
### Memory
|
||||
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
### User Input
|
||||
{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}}
|
||||
### Assistant Output
|
||||
""" # noqa: E501
|
||||
3
dify/api/core/workflow/nodes/start/__init__.py
Normal file
3
dify/api/core/workflow/nodes/start/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .start_node import StartNode
|
||||
|
||||
__all__ = ["StartNode"]
|
||||
14
dify/api/core/workflow/nodes/start/entities.py
Normal file
14
dify/api/core/workflow/nodes/start/entities.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class StartNodeData(BaseNodeData):
|
||||
"""
|
||||
Start Node Data
|
||||
"""
|
||||
|
||||
variables: Sequence[VariableEntity] = Field(default_factory=list)
|
||||
53
dify/api/core/workflow/nodes/start/start_node.py
Normal file
53
dify/api/core/workflow/nodes/start/start_node.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
|
||||
|
||||
class StartNode(Node):
|
||||
node_type = NodeType.START
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
_node_data: StartNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = StartNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
|
||||
|
||||
# TODO: System variables should be directly accessible, no need for special handling
|
||||
# Set system variables as node outputs.
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
||||
outputs = dict(node_inputs)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .template_transform_node import TemplateTransformNode
|
||||
|
||||
__all__ = ["TemplateTransformNode"]
|
||||
11
dify/api/core/workflow/nodes/template_transform/entities.py
Normal file
11
dify/api/core/workflow/nodes/template_transform/entities.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class TemplateTransformNodeData(BaseNodeData):
|
||||
"""
|
||||
Template Transform Node Data.
|
||||
"""
|
||||
|
||||
variables: list[VariableSelector]
|
||||
template: str
|
||||
@@ -0,0 +1,93 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
|
||||
|
||||
class TemplateTransformNode(Node):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
|
||||
_node_data: TemplateTransformNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = TemplateTransformNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"type": "template-transform",
|
||||
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get variables
|
||||
variables: dict[str, Any] = {}
|
||||
for variable_selector in self._node_data.variables:
|
||||
variable_name = variable_selector.variable
|
||||
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
variables[variable_name] = value.to_object() if value else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters",
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = TemplateTransformNodeData.model_validate(node_data)
|
||||
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in typed_node_data.variables
|
||||
}
|
||||
3
dify/api/core/workflow/nodes/tool/__init__.py
Normal file
3
dify/api/core/workflow/nodes/tool/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .tool_node import ToolNode
|
||||
|
||||
__all__ = ["ToolNode"]
|
||||
84
dify/api/core/workflow/nodes/tool/entities.py
Normal file
84
dify/api/core/workflow/nodes/tool/entities.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: ToolProviderType
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
credential_id: str | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("tool_configurations must be a dictionary")
|
||||
|
||||
for key in values.data.get("tool_configurations", {}):
|
||||
value = values.data.get("tool_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f"{key} must be a string")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
|
||||
if value is None:
|
||||
return typ
|
||||
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool | dict):
|
||||
raise ValueError("value must be a string, int, float, bool or dict")
|
||||
return typ
|
||||
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
# The version of the tool parameter.
|
||||
# If this value is None, it indicates this is a previous version
|
||||
# and requires using the legacy parameter parsing rules.
|
||||
tool_node_version: str | None = None
|
||||
|
||||
@field_validator("tool_parameters", mode="before")
|
||||
@classmethod
|
||||
def filter_none_tool_inputs(cls, value):
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
return {
|
||||
key: tool_input
|
||||
for key, tool_input in value.items()
|
||||
if tool_input is not None and cls._has_valid_value(tool_input)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _has_valid_value(tool_input):
|
||||
"""Check if the value is valid"""
|
||||
if isinstance(tool_input, dict):
|
||||
return tool_input.get("value") is not None
|
||||
return getattr(tool_input, "value", None) is not None
|
||||
16
dify/api/core/workflow/nodes/tool/exc.py
Normal file
16
dify/api/core/workflow/nodes/tool/exc.py
Normal file
@@ -0,0 +1,16 @@
|
||||
class ToolNodeError(ValueError):
|
||||
"""Base exception for tool node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ToolParameterError(ToolNodeError):
|
||||
"""Exception raised for errors in tool parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ToolFileError(ToolNodeError):
|
||||
"""Exception raised for errors related to tool files."""
|
||||
|
||||
pass
|
||||
521
dify/api/core/workflow/nodes/tool/tool_node.py
Normal file
521
dify/api/core/workflow/nodes/tool/tool_node.py
Normal file
@@ -0,0 +1,521 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import ToolNodeData
|
||||
from .exc import (
|
||||
ToolFileError,
|
||||
ToolNodeError,
|
||||
ToolParameterError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
||||
class ToolNode(Node):
|
||||
"""
|
||||
Tool Node
|
||||
"""
|
||||
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
_node_data: ToolNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = ToolNodeData.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
|
||||
node_data = self._node_data
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
"provider_type": node_data.provider_type.value,
|
||||
"provider_id": node_data.provider_id,
|
||||
"plugin_unique_identifier": node_data.plugin_unique_identifier,
|
||||
}
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use the node_data.version field for judgment
|
||||
# But for backward compatibility with historical data
|
||||
# this version field judgment is still preserved here.
|
||||
variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to get tool runtime: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
for_log=True,
|
||||
)
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = ToolEngine.generic_invoke(
|
||||
tool=tool_runtime,
|
||||
tool_parameters=parameters,
|
||||
user_id=self.user_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# convert tool messages
|
||||
_ = yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
tool_info=tool_info,
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_id=self._node_id,
|
||||
tool_runtime=tool_runtime,
|
||||
)
|
||||
except ToolInvokeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except PluginInvokeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=e.to_user_friendly_error(plugin_name=node_data.provider_name),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool, error: {e.description}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
tool_parameters: Sequence[ToolParameter],
|
||||
variable_pool: "VariablePool",
|
||||
node_data: ToolNodeData,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
parameter = tool_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
tool_runtime: Tool,
|
||||
) -> Generator[NodeEventBase, None, LLMUsage]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
# JSON message handling for tool node
|
||||
if message.message.json_object:
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
# Check if this LINK message is a file link
|
||||
file_obj = (message.meta or {}).get("file")
|
||||
if isinstance(file_obj, File):
|
||||
files.append(file_obj)
|
||||
stream_text = f"File: {message.message.text}\n"
|
||||
else:
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
# Validate that meta contains a 'file' key
|
||||
if "file" not in message.meta:
|
||||
raise ToolNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
# Validate that the file is an instance of File
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json:
|
||||
json_output.extend(json)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
usage = self._extract_tool_usage(tool_runtime)
|
||||
|
||||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
}
|
||||
if usage.total_tokens > 0:
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
||||
metadata=metadata,
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
|
||||
if isinstance(tool_runtime, WorkflowTool):
|
||||
return tool_runtime.latest_usage
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ToolNodeData.model_validate(node_data)
|
||||
|
||||
result = {}
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
3
dify/api/core/workflow/nodes/trigger_plugin/__init__.py
Normal file
3
dify/api/core/workflow/nodes/trigger_plugin/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .trigger_event_node import TriggerEventNode
|
||||
|
||||
__all__ = ["TriggerEventNode"]
|
||||
77
dify/api/core/workflow/nodes/trigger_plugin/entities.py
Normal file
77
dify/api/core/workflow/nodes/trigger_plugin/entities.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.trigger.entities.entities import EventParameter
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.trigger_plugin.exc import TriggerEventParameterError
|
||||
|
||||
|
||||
class TriggerEventNodeData(BaseNodeData):
|
||||
"""Plugin trigger node data"""
|
||||
|
||||
class TriggerEventInput(BaseModel):
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
type = value
|
||||
value = validation_info.data.get("value")
|
||||
|
||||
if value is None:
|
||||
return type
|
||||
|
||||
if type == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
|
||||
if type == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
|
||||
if type == "constant" and not isinstance(value, str | int | float | bool | dict | list):
|
||||
raise ValueError("value must be a string, int, float, bool or dict")
|
||||
return type
|
||||
|
||||
title: str
|
||||
desc: str | None = None
|
||||
plugin_id: str = Field(..., description="Plugin ID")
|
||||
provider_id: str = Field(..., description="Provider ID")
|
||||
event_name: str = Field(..., description="Event name")
|
||||
subscription_id: str = Field(..., description="Subscription ID")
|
||||
plugin_unique_identifier: str = Field(..., description="Plugin unique identifier")
|
||||
event_parameters: Mapping[str, TriggerEventInput] = Field(default_factory=dict, description="Trigger parameters")
|
||||
|
||||
def resolve_parameters(
|
||||
self,
|
||||
*,
|
||||
parameter_schemas: Mapping[str, EventParameter],
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given plugin trigger parameters.
|
||||
|
||||
Args:
|
||||
parameter_schemas (Mapping[str, EventParameter]): The mapping of parameter schemas.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in self.event_parameters:
|
||||
parameter: EventParameter | None = parameter_schemas.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
event_input = self.event_parameters[parameter_name]
|
||||
|
||||
# trigger node only supports constant input
|
||||
if event_input.type != "constant":
|
||||
raise TriggerEventParameterError(f"Unknown plugin trigger input type '{event_input.type}'")
|
||||
result[parameter_name] = event_input.value
|
||||
return result
|
||||
10
dify/api/core/workflow/nodes/trigger_plugin/exc.py
Normal file
10
dify/api/core/workflow/nodes/trigger_plugin/exc.py
Normal file
@@ -0,0 +1,10 @@
|
||||
class TriggerEventNodeError(ValueError):
|
||||
"""Base exception for plugin trigger node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TriggerEventParameterError(TriggerEventNodeError):
|
||||
"""Exception raised for errors in plugin trigger parameters."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,89 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .entities import TriggerEventNodeData
|
||||
|
||||
|
||||
class TriggerEventNode(Node):
|
||||
node_type = NodeType.TRIGGER_PLUGIN
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
_node_data: TriggerEventNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = TriggerEventNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "plugin",
|
||||
"config": {
|
||||
"title": "",
|
||||
"plugin_id": "",
|
||||
"provider_id": "",
|
||||
"event_name": "",
|
||||
"subscription_id": "",
|
||||
"plugin_unique_identifier": "",
|
||||
"event_parameters": {},
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the plugin trigger node.
|
||||
|
||||
This node invokes the trigger to convert request data into events
|
||||
and makes them available to downstream nodes.
|
||||
"""
|
||||
|
||||
# Get trigger data passed when workflow was triggered
|
||||
metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||
"provider_id": self._node_data.provider_id,
|
||||
"event_name": self._node_data.event_name,
|
||||
"plugin_unique_identifier": self._node_data.plugin_unique_identifier,
|
||||
},
|
||||
}
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
|
||||
|
||||
# TODO: System variables should be directly accessible, no need for special handling
|
||||
# Set system variables as node outputs.
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
||||
outputs = dict(node_inputs)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=node_inputs,
|
||||
outputs=outputs,
|
||||
metadata=metadata,
|
||||
)
|
||||
@@ -0,0 +1,3 @@
|
||||
from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode
|
||||
|
||||
__all__ = ["TriggerScheduleNode"]
|
||||
49
dify/api/core/workflow/nodes/trigger_schedule/entities.py
Normal file
49
dify/api/core/workflow/nodes/trigger_schedule/entities.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class TriggerScheduleNodeData(BaseNodeData):
|
||||
"""
|
||||
Trigger Schedule Node Data
|
||||
"""
|
||||
|
||||
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
|
||||
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
|
||||
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
|
||||
visual_config: dict | None = Field(default=None, description="Visual configuration details")
|
||||
timezone: str = Field(default="UTC", description="Timezone for schedule execution")
|
||||
|
||||
|
||||
class ScheduleConfig(BaseModel):
|
||||
node_id: str
|
||||
cron_expression: str
|
||||
timezone: str = "UTC"
|
||||
|
||||
|
||||
class SchedulePlanUpdate(BaseModel):
|
||||
node_id: str | None = None
|
||||
cron_expression: str | None = None
|
||||
timezone: str | None = None
|
||||
|
||||
|
||||
class VisualConfig(BaseModel):
|
||||
"""Visual configuration for schedule trigger"""
|
||||
|
||||
# For hourly frequency
|
||||
on_minute: int | None = Field(default=0, ge=0, le=59, description="Minute of the hour (0-59)")
|
||||
|
||||
# For daily, weekly, monthly frequencies
|
||||
time: str | None = Field(default="12:00 AM", description="Time in 12-hour format (e.g., '2:30 PM')")
|
||||
|
||||
# For weekly frequency
|
||||
weekdays: list[Literal["sun", "mon", "tue", "wed", "thu", "fri", "sat"]] | None = Field(
|
||||
default=None, description="List of weekdays to run on"
|
||||
)
|
||||
|
||||
# For monthly frequency
|
||||
monthly_days: list[Union[int, Literal["last"]]] | None = Field(
|
||||
default=None, description="Days of month to run on (1-31 or 'last')"
|
||||
)
|
||||
31
dify/api/core/workflow/nodes/trigger_schedule/exc.py
Normal file
31
dify/api/core/workflow/nodes/trigger_schedule/exc.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from core.workflow.nodes.base.exc import BaseNodeError
|
||||
|
||||
|
||||
class ScheduleNodeError(BaseNodeError):
|
||||
"""Base schedule node error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ScheduleNotFoundError(ScheduleNodeError):
|
||||
"""Schedule not found error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ScheduleConfigError(ScheduleNodeError):
|
||||
"""Schedule configuration error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ScheduleExecutionError(ScheduleNodeError):
|
||||
"""Schedule execution error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TenantOwnerNotFoundError(ScheduleExecutionError):
|
||||
"""Tenant owner not found error for schedule execution."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,69 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData
|
||||
|
||||
|
||||
class TriggerScheduleNode(Node):
|
||||
node_type = NodeType.TRIGGER_SCHEDULE
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
_node_data: TriggerScheduleNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = TriggerScheduleNodeData(**data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "trigger-schedule",
|
||||
"config": {
|
||||
"mode": "visual",
|
||||
"frequency": "daily",
|
||||
"visual_config": {"time": "12:00 AM", "on_minute": 0, "weekdays": ["sun"], "monthly_days": [1]},
|
||||
"timezone": "UTC",
|
||||
},
|
||||
}
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
|
||||
|
||||
# TODO: System variables should be directly accessible, no need for special handling
|
||||
# Set system variables as node outputs.
|
||||
for var in system_inputs:
|
||||
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
|
||||
outputs = dict(node_inputs)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=node_inputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user