This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from core.workflow.enums import NodeType
__all__ = ["NodeType"]

View File

@@ -0,0 +1,3 @@
from .agent_node import AgentNode
__all__ = ["AgentNode"]

View File

@@ -0,0 +1,756 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
AgentInputTypeError,
AgentInvocationError,
AgentMessageTransformError,
AgentNodeError,
AgentVariableNotFoundError,
AgentVariableTypeError,
ToolFileNotFoundError,
)
if TYPE_CHECKING:
from core.agent.strategy.plugin import PluginAgentStrategy
from core.plugin.entities.request import InvokeCredentials
class AgentNode(Node):
    """
    Agent Node.

    Resolves a pluggable agent strategy, renders the node's configured
    parameters against the variable pool, invokes the strategy, and streams
    the resulting tool-invoke messages back as workflow node events.
    """

    node_type = NodeType.AGENT
    # Typed node configuration, populated by init_node_data().
    _node_data: AgentNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw node config mapping into typed AgentNodeData."""
        self._node_data = AgentNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> Generator[NodeEventBase, None, None]:
        """
        Execute the agent strategy and yield node events.

        Failures during strategy resolution, invocation, or message
        transformation are reported via a StreamCompletedEvent with a FAILED
        result instead of propagating as exceptions.
        """
        from core.plugin.impl.exc import PluginDaemonClientSideError
        try:
            strategy = get_plugin_agent_strategy(
                tenant_id=self.tenant_id,
                agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
                agent_strategy_name=self._node_data.agent_strategy_name,
            )
        except Exception as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    error=f"Failed to get agent strategy: {str(e)}",
                ),
            )
            return
        agent_parameters = strategy.get_parameters()
        # get parameters
        parameters = self._generate_agent_parameters(
            agent_parameters=agent_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self._node_data,
            strategy=strategy,
        )
        # Second render with for_log=True produces a log-safe view of the inputs.
        parameters_for_log = self._generate_agent_parameters(
            agent_parameters=agent_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self._node_data,
            for_log=True,
            strategy=strategy,
        )
        credentials = self._generate_credentials(parameters=parameters)
        # get conversation id
        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
        try:
            message_stream = strategy.invoke(
                params=parameters,
                user_id=self.user_id,
                app_id=self.app_id,
                conversation_id=conversation_id.text if conversation_id else None,
                credentials=credentials,
            )
        except Exception as e:
            error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    error=str(error),
                )
            )
            return
        try:
            yield from self._transform_message(
                messages=message_stream,
                tool_info={
                    "icon": self.agent_strategy_icon,
                    "agent_strategy": self._node_data.agent_strategy_name,
                },
                parameters_for_log=parameters_for_log,
                user_id=self.user_id,
                tenant_id=self.tenant_id,
                node_type=self.node_type,
                node_id=self._node_id,
                node_execution_id=self.id,
            )
        except PluginDaemonClientSideError as e:
            transform_error = AgentMessageTransformError(
                f"Failed to transform agent message: {str(e)}", original_error=e
            )
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    error=str(transform_error),
                )
            )

    def _generate_agent_parameters(
        self,
        *,
        agent_parameters: Sequence[AgentStrategyParameter],
        variable_pool: VariablePool,
        node_data: AgentNodeData,
        for_log: bool = False,
        strategy: "PluginAgentStrategy",
    ) -> dict[str, Any]:
        """
        Generate parameters based on the given tool parameters, variable pool, and node data.

        Args:
            agent_parameters (Sequence[AgentStrategyParameter]): The list of agent parameters.
            variable_pool (VariablePool): The variable pool containing the variables.
            node_data (AgentNodeData): The data associated with the agent node.
            for_log: When True, render a log-safe view and skip the expensive
                tool-runtime / model-schema resolution.
            strategy: The resolved plugin agent strategy (used for MCP filtering).

        Returns:
            Mapping[str, Any]: A dictionary containing the generated parameters.

        Raises:
            AgentVariableNotFoundError: if a "variable"-typed input is missing from the pool.
            AgentInputTypeError: if an input declares an unknown type.
        """
        agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
        result: dict[str, Any] = {}
        for parameter_name in node_data.agent_parameters:
            parameter = agent_parameters_dictionary.get(parameter_name)
            if not parameter:
                # Configured parameter no longer declared by the strategy; keep a None placeholder.
                result[parameter_name] = None
                continue
            agent_input = node_data.agent_parameters[parameter_name]
            if agent_input.type == "variable":
                variable = variable_pool.get(agent_input.value)  # type: ignore
                if variable is None:
                    raise AgentVariableNotFoundError(str(agent_input.value))
                parameter_value = variable.value
            elif agent_input.type in {"mixed", "constant"}:
                # variable_pool.convert_template expects a string template,
                # but if passing a dict, convert to JSON string first before rendering
                try:
                    if not isinstance(agent_input.value, str):
                        parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
                    else:
                        parameter_value = str(agent_input.value)
                except TypeError:
                    # Non-JSON-serializable value: fall back to its string representation.
                    parameter_value = str(agent_input.value)
                segment_group = variable_pool.convert_template(parameter_value)
                parameter_value = segment_group.log if for_log else segment_group.text
                # variable_pool.convert_template returns a string,
                # so we need to convert it back to a dictionary
                try:
                    if not isinstance(agent_input.value, str):
                        parameter_value = json.loads(parameter_value)
                except json.JSONDecodeError:
                    # Rendered template is not valid JSON; keep the string as-is.
                    parameter_value = parameter_value
            else:
                raise AgentInputTypeError(agent_input.type)
            value = parameter_value
            if parameter.type == "array[tools]":
                # Normalize the tool selector list: drop disabled tools, filter MCP
                # tools unsupported by old strategy plugins, and flatten settings.
                value = cast(list[dict[str, Any]], value)
                value = [tool for tool in value if tool.get("enabled", False)]
                value = self._filter_mcp_type_tool(strategy, value)
                for tool in value:
                    if "schemas" in tool:
                        tool.pop("schemas")
                    parameters = tool.get("parameters", {})
                    if all(isinstance(v, dict) for _, v in parameters.items()):
                        # Structured form: keep manually-entered values, null out auto-generated ones.
                        params = {}
                        for key, param in parameters.items():
                            if param.get("auto", ParamsAutoGenerated.OPEN) in (
                                ParamsAutoGenerated.CLOSE,
                                0,
                            ):
                                value_param = param.get("value", {})
                                params[key] = value_param.get("value", "") if value_param is not None else None
                            else:
                                params[key] = None
                        parameters = params
                    tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
                    tool["parameters"] = parameters
            if not for_log:
                if parameter.type == "array[tools]":
                    # Resolve each selected tool into a full runtime description for the plugin.
                    value = cast(list[dict[str, Any]], value)
                    tool_value = []
                    for tool in value:
                        provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
                        setting_params = tool.get("settings", {})
                        parameters = tool.get("parameters", {})
                        # Parameters with a non-None value were filled in manually by the user.
                        manual_input_params = [key for key, value in parameters.items() if value is not None]
                        parameters = {**parameters, **setting_params}
                        entity = AgentToolEntity(
                            provider_id=tool.get("provider_name", ""),
                            provider_type=provider_type,
                            tool_name=tool.get("tool_name", ""),
                            tool_parameters=parameters,
                            plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
                            credential_id=tool.get("credential_id", None),
                        )
                        extra = tool.get("extra", {})
                        # This is an issue that caused problems before.
                        # Logically, we shouldn't use the node_data.version field for judgment
                        # But for backward compatibility with historical data
                        # this version field judgment is still preserved here.
                        runtime_variable_pool: VariablePool | None = None
                        if node_data.version != "1" or node_data.tool_node_version is not None:
                            runtime_variable_pool = variable_pool
                        tool_runtime = ToolManager.get_agent_tool_runtime(
                            self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
                        )
                        if tool_runtime.entity.description:
                            # A user-provided description overrides the tool's own LLM description.
                            tool_runtime.entity.description.llm = (
                                extra.get("description", "") or tool_runtime.entity.description.llm
                            )
                        for tool_runtime_params in tool_runtime.entity.parameters:
                            # Manually-entered parameters are presented as FORM inputs.
                            tool_runtime_params.form = (
                                ToolParameter.ToolParameterForm.FORM
                                if tool_runtime_params.name in manual_input_params
                                else tool_runtime_params.form
                            )
                        manual_input_value = {}
                        if tool_runtime.entity.parameters:
                            manual_input_value = {
                                key: value for key, value in parameters.items() if key in manual_input_params
                            }
                        # Manual values take precedence over the runtime's defaults.
                        runtime_parameters = {
                            **tool_runtime.runtime.runtime_parameters,
                            **manual_input_value,
                        }
                        tool_value.append(
                            {
                                **tool_runtime.entity.model_dump(mode="json"),
                                "runtime_parameters": runtime_parameters,
                                "credential_id": tool.get("credential_id", None),
                                "provider_type": provider_type.value,
                            }
                        )
                    value = tool_value
                if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
                    # Enrich the model selector with conversation history and the model schema.
                    value = cast(dict[str, Any], value)
                    model_instance, model_schema = self._fetch_model(value)
                    # memory config
                    history_prompt_messages = []
                    if node_data.memory:
                        memory = self._fetch_memory(model_instance)
                        if memory:
                            prompt_messages = memory.get_history_prompt_messages(
                                message_limit=node_data.memory.window.size or None
                            )
                            history_prompt_messages = [
                                prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
                            ]
                    value["history_prompt_messages"] = history_prompt_messages
                    if model_schema:
                        # remove structured output feature to support old version agent plugin
                        model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
                        value["entity"] = model_schema.model_dump(mode="json")
                    else:
                        value["entity"] = None
            result[parameter_name] = value
        return result

    def _generate_credentials(
        self,
        parameters: dict[str, Any],
    ) -> "InvokeCredentials":
        """
        Generate credentials based on the given agent parameters.

        Builds a provider -> credential_id map from the resolved "tools"
        parameter so the plugin daemon can invoke tools with the right
        credentials.
        """
        from core.plugin.entities.request import InvokeCredentials
        credentials = InvokeCredentials()
        # generate credentials for tools selector
        credentials.tool_credentials = {}
        for tool in parameters.get("tools", []):
            if tool.get("credential_id"):
                try:
                    identity = ToolIdentity.model_validate(tool.get("identity", {}))
                    credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
                except ValidationError:
                    # Malformed identity payload: skip this tool rather than failing the run.
                    continue
        return credentials

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Collect the variable selectors this node's parameters reference, keyed by '<node_id>.<name>'."""
        # Create typed NodeData from dict
        typed_node_data = AgentNodeData.model_validate(node_data)
        result: dict[str, Any] = {}
        for parameter_name in typed_node_data.agent_parameters:
            input = typed_node_data.agent_parameters[parameter_name]
            if input.type in ["mixed", "constant"]:
                # Template values may embed several variable references.
                selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
                for selector in selectors:
                    result[selector.variable] = selector.value_selector
            elif input.type == "variable":
                result[parameter_name] = input.value
        result = {node_id + "." + key: value for key, value in result.items()}
        return result

    @property
    def agent_strategy_icon(self) -> str | None:
        """
        Get agent strategy icon

        :return: the icon declared by the strategy's plugin, or None if the
            plugin is not installed for this tenant
        """
        from core.plugin.impl.plugin import PluginInstaller
        manager = PluginInstaller()
        plugins = manager.list_plugins(self.tenant_id)
        try:
            current_plugin = next(
                plugin
                for plugin in plugins
                if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
            )
            icon = current_plugin.declaration.icon
        except StopIteration:
            # Plugin not found among installed plugins.
            icon = None
        return icon

    def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
        """Build conversation memory for the current sys.conversation_id, or None if unavailable."""
        # get conversation id
        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
            ["sys", SystemVariableKey.CONVERSATION_ID]
        )
        if not isinstance(conversation_id_variable, StringSegment):
            return None
        conversation_id = conversation_id_variable.value
        # expire_on_commit=False lets the Conversation object be used after the session closes.
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
            conversation = session.scalar(stmt)
        if not conversation:
            return None
        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
        return memory

    def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
        """Resolve the model-selector dict into a ModelInstance and its (optional) schema."""
        provider_manager = ProviderManager()
        provider_model_bundle = provider_manager.get_provider_model_bundle(
            tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
        )
        model_name = value.get("model", "")
        model_credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=ModelType.LLM, model=model_name
        )
        provider_name = provider_model_bundle.configuration.provider.provider
        model_type_instance = provider_model_bundle.model_type_instance
        model_instance = ModelManager().get_model_instance(
            tenant_id=self.tenant_id,
            provider=provider_name,
            model_type=ModelType(value.get("model_type", "")),
            model=model_name,
        )
        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
        return model_instance, model_schema

    def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
        """Strip model features unknown to old-version agent plugins (mutates and returns the schema)."""
        if model_schema.features:
            for feature in model_schema.features[:]:  # Create a copy to safely modify during iteration
                try:
                    AgentOldVersionModelFeatures(feature.value)  # Try to create enum member from value
                except ValueError:
                    model_schema.features.remove(feature)
        return model_schema

    def _filter_mcp_type_tool(
        self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """
        Filter MCP type tool

        MCP tools are only supported by strategy plugins with meta_version > 0.0.1;
        older plugins get the MCP entries removed.

        :param strategy: plugin agent strategy
        :param tool: tool
        :return: filtered tool dict
        """
        meta_version = strategy.meta_version
        if meta_version and Version(meta_version) > Version("0.0.1"):
            return tools
        else:
            return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]

    def _transform_message(
        self,
        messages: Generator[ToolInvokeMessage, None, None],
        tool_info: Mapping[str, Any],
        parameters_for_log: dict[str, Any],
        user_id: str,
        tenant_id: str,
        node_type: NodeType,
        node_id: str,
        node_execution_id: str,
    ) -> Generator[NodeEventBase, None, None]:
        """
        Consume the strategy's ToolInvokeMessage stream and yield node events.

        Accumulates text, files, JSON payloads, variables, agent logs and LLM
        usage from the stream, emitting StreamChunkEvent/AgentLogEvent along
        the way and a final StreamCompletedEvent with all collected outputs.

        Raises:
            ToolFileNotFoundError: when a referenced tool file id is unknown.
            AgentVariableTypeError: when a streamed variable is not a string.
            AgentNodeError: when a FILE message has invalid metadata.
        """
        # transform message and handle file storage
        from core.plugin.impl.plugin import PluginInstaller
        message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=messages,
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=None,
        )
        text = ""
        files: list[File] = []
        json_list: list[dict] = []
        agent_logs: list[AgentLogEvent] = []
        agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
        llm_usage = LLMUsage.empty_usage()
        variables: dict[str, Any] = {}
        for message in message_stream:
            if message.type in {
                ToolInvokeMessage.MessageType.IMAGE_LINK,
                ToolInvokeMessage.MessageType.BINARY_LINK,
                ToolInvokeMessage.MessageType.IMAGE,
            }:
                # Link-style file messages carry a URL whose last path segment encodes the tool file id.
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                url = message.message.text
                if message.meta:
                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
                else:
                    transfer_method = FileTransferMethod.TOOL_FILE
                tool_file_id = str(url).split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
                    if tool_file is None:
                        raise ToolFileNotFoundError(tool_file_id)
                mapping = {
                    "tool_file_id": tool_file_id,
                    "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=tenant_id,
                )
                files.append(file)
            elif message.type == ToolInvokeMessage.MessageType.BLOB:
                # get tool file id
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                assert message.meta
                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
                    if tool_file is None:
                        raise ToolFileNotFoundError(tool_file_id)
                mapping = {
                    "tool_file_id": tool_file_id,
                    "transfer_method": FileTransferMethod.TOOL_FILE,
                }
                files.append(
                    file_factory.build_from_mapping(
                        mapping=mapping,
                        tenant_id=tenant_id,
                    )
                )
            elif message.type == ToolInvokeMessage.MessageType.TEXT:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                text += message.message.text
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
                    chunk=message.message.text,
                    is_final=False,
                )
            elif message.type == ToolInvokeMessage.MessageType.JSON:
                assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
                if node_type == NodeType.AGENT:
                    # Pull LLM usage / execution metadata out of the JSON payload before storing it.
                    msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
                    llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
                    agent_execution_metadata = {
                        WorkflowNodeExecutionMetadataKey(key): value
                        for key, value in msg_metadata.items()
                        if key in WorkflowNodeExecutionMetadataKey.__members__.values()
                    }
                if message.message.json_object:
                    json_list.append(message.message.json_object)
            elif message.type == ToolInvokeMessage.MessageType.LINK:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                stream_text = f"Link: {message.message.text}\n"
                text += stream_text
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
                    chunk=stream_text,
                    is_final=False,
                )
            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
                variable_name = message.message.variable_name
                variable_value = message.message.variable_value
                if message.message.stream:
                    # Streamed variables are accumulated as strings and re-emitted as chunks.
                    if not isinstance(variable_value, str):
                        raise AgentVariableTypeError(
                            "When 'stream' is True, 'variable_value' must be a string.",
                            variable_name=variable_name,
                            expected_type="str",
                            actual_type=type(variable_value).__name__,
                        )
                    if variable_name not in variables:
                        variables[variable_name] = ""
                    variables[variable_name] += variable_value
                    yield StreamChunkEvent(
                        selector=[node_id, variable_name],
                        chunk=variable_value,
                        is_final=False,
                    )
                else:
                    variables[variable_name] = variable_value
            elif message.type == ToolInvokeMessage.MessageType.FILE:
                assert message.meta is not None
                assert isinstance(message.meta, dict)
                # Validate that meta contains a 'file' key
                if "file" not in message.meta:
                    raise AgentNodeError("File message is missing 'file' key in meta")
                # Validate that the file is an instance of File
                if not isinstance(message.meta["file"], File):
                    raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
                files.append(message.meta["file"])
            elif message.type == ToolInvokeMessage.MessageType.LOG:
                assert isinstance(message.message, ToolInvokeMessage.LogMessage)
                if message.message.metadata:
                    # Enrich log metadata with provider icons for the UI.
                    icon = tool_info.get("icon", "")
                    dict_metadata = dict(message.message.metadata)
                    if dict_metadata.get("provider"):
                        manager = PluginInstaller()
                        plugins = manager.list_plugins(tenant_id)
                        try:
                            current_plugin = next(
                                plugin
                                for plugin in plugins
                                if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
                            )
                            icon = current_plugin.declaration.icon
                        except StopIteration:
                            pass
                        icon_dark = None
                        try:
                            builtin_tool = next(
                                provider
                                for provider in BuiltinToolManageService.list_builtin_tools(
                                    user_id,
                                    tenant_id,
                                )
                                if provider.name == dict_metadata["provider"]
                            )
                            icon = builtin_tool.icon
                            icon_dark = builtin_tool.icon_dark
                        except StopIteration:
                            pass
                        dict_metadata["icon"] = icon
                        dict_metadata["icon_dark"] = icon_dark
                        message.message.metadata = dict_metadata
                agent_log = AgentLogEvent(
                    message_id=message.message.id,
                    node_execution_id=node_execution_id,
                    parent_id=message.message.parent_id,
                    error=message.message.error,
                    status=message.message.status.value,
                    data=message.message.data,
                    label=message.message.label,
                    metadata=message.message.metadata,
                    node_id=node_id,
                )
                # check if the agent log is already in the list
                for log in agent_logs:
                    if log.message_id == agent_log.message_id:
                        # update the log
                        log.data = agent_log.data
                        log.status = agent_log.status
                        log.error = agent_log.error
                        log.label = agent_log.label
                        log.metadata = agent_log.metadata
                        break
                else:
                    agent_logs.append(agent_log)
                yield agent_log
        # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
        json_output: list[dict[str, Any]] = []
        # Step 1: append each agent log as its own dict.
        if agent_logs:
            for log in agent_logs:
                json_output.append(
                    {
                        "id": log.message_id,
                        "parent_id": log.parent_id,
                        "error": log.error,
                        "status": log.status,
                        "data": log.data,
                        "label": log.label,
                        "metadata": log.metadata,
                        "node_id": log.node_id,
                    }
                )
        # Step 2: append collected JSON payloads, or a {"data": []} placeholder if there were none.
        if json_list:
            json_output.extend(json_list)
        else:
            json_output.append({"data": []})
        # Send final chunk events for all streamed outputs
        # Final chunk for text stream
        yield StreamChunkEvent(
            selector=[node_id, "text"],
            chunk="",
            is_final=True,
        )
        # Final chunks for any streamed variables
        for var_name in variables:
            yield StreamChunkEvent(
                selector=[node_id, var_name],
                chunk="",
                is_final=True,
            )
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={
                    "text": text,
                    "usage": jsonable_encoder(llm_usage),
                    "files": ArrayFileSegment(value=files),
                    "json": json_output,
                    **variables,
                },
                metadata={
                    **agent_execution_metadata,
                    WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
                    WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
                },
                inputs=parameters_for_log,
                llm_usage=llm_usage,
            )
        )

View File

@@ -0,0 +1,45 @@
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.nodes.base.entities import BaseNodeData
class AgentNodeData(BaseNodeData):
    """Configuration payload for an Agent workflow node."""

    # Strategy provider identifier, "<plugin_id>/<provider_name>" (stored redundantly alongside the name).
    agent_strategy_provider_name: str  # redundancy
    agent_strategy_name: str
    agent_strategy_label: str  # redundancy
    # Optional conversation-memory configuration; None disables history injection.
    memory: MemoryConfig | None = None
    # The version of the tool parameter.
    # If this value is None, it indicates this is a previous version
    # and requires using the legacy parameter parsing rules.
    tool_node_version: str | None = None

    class AgentInput(BaseModel):
        """One configured parameter value plus how it should be interpreted."""

        # "variable": value is a variable selector; "mixed"/"constant": template string or literal.
        value: Union[list[str], list[ToolSelector], Any]
        type: Literal["mixed", "variable", "constant"]

    # Parameter name -> configured input.
    agent_parameters: dict[str, AgentInput]
class ParamsAutoGenerated(IntEnum):
    """Whether a tool parameter value is auto-generated by the agent (OPEN) or supplied manually (CLOSE)."""

    CLOSE = 0
    OPEN = 1
class AgentOldVersionModelFeatures(StrEnum):
"""
Enum class for old SDK version llm feature.
"""
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = auto()
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = auto()
VIDEO = auto()
AUDIO = auto()

View File

@@ -0,0 +1,121 @@
class AgentNodeError(Exception):
    """Root of the agent-node exception hierarchy; keeps the message on the instance."""

    def __init__(self, message: str):
        # Exception.__init__ receives the same message so str(err) matches err.message.
        super().__init__(message)
        self.message = message
class AgentStrategyError(AgentNodeError):
    """Raised when something is wrong with the configured agent strategy."""

    def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
        super().__init__(message)
        # Keep the strategy context for structured error reporting.
        self.strategy_name = strategy_name
        self.provider_name = provider_name
class AgentStrategyNotFoundError(AgentStrategyError):
    """Raised when the requested agent strategy cannot be resolved."""

    def __init__(self, strategy_name: str, provider_name: str | None = None):
        detail = f"Agent strategy '{strategy_name}' not found"
        if provider_name:
            detail += f" for provider '{provider_name}'"
        super().__init__(detail, strategy_name, provider_name)
class AgentInvocationError(AgentNodeError):
    """Raised when invoking the agent strategy fails."""

    def __init__(self, message: str, original_error: Exception | None = None):
        super().__init__(message)
        # Preserve the underlying exception for diagnostics.
        self.original_error = original_error
class AgentParameterError(AgentNodeError):
    """Raised for invalid or unusable agent parameters."""

    def __init__(self, message: str, parameter_name: str | None = None):
        super().__init__(message)
        self.parameter_name = parameter_name
class AgentVariableError(AgentNodeError):
    """Raised for variable-related failures inside the agent node."""

    def __init__(self, message: str, variable_name: str | None = None):
        super().__init__(message)
        self.variable_name = variable_name
class AgentVariableNotFoundError(AgentVariableError):
    """Raised when a referenced variable is missing from the variable pool."""

    def __init__(self, variable_name: str):
        detail = f"Variable '{variable_name}' does not exist"
        super().__init__(detail, variable_name)
class AgentInputTypeError(AgentNodeError):
    """Raised when an agent input declares a type that is not recognized."""

    def __init__(self, input_type: str):
        detail = f"Unknown agent input type '{input_type}'"
        super().__init__(detail)
class ToolFileError(AgentNodeError):
    """Raised for problems with a tool-produced file."""

    def __init__(self, message: str, file_id: str | None = None):
        super().__init__(message)
        self.file_id = file_id
class ToolFileNotFoundError(ToolFileError):
    """Raised when a tool file id cannot be found in storage."""

    def __init__(self, file_id: str):
        detail = f"Tool file '{file_id}' does not exist"
        super().__init__(detail, file_id)
class AgentMessageTransformError(AgentNodeError):
    """Raised when an agent message cannot be transformed into node events."""

    def __init__(self, message: str, original_error: Exception | None = None):
        super().__init__(message)
        # Preserve the underlying exception for diagnostics.
        self.original_error = original_error
class AgentModelError(AgentNodeError):
    """Raised when the model backing the agent cannot be used."""

    def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
        super().__init__(message)
        self.model_name = model_name
        self.provider = provider
class AgentMemoryError(AgentNodeError):
    """Raised when conversation memory cannot be loaded or used."""

    def __init__(self, message: str, conversation_id: str | None = None):
        super().__init__(message)
        self.conversation_id = conversation_id
class AgentVariableTypeError(AgentNodeError):
    """Raised when a variable value does not have the type the node expects."""

    def __init__(
        self,
        message: str,
        variable_name: str | None = None,
        expected_type: str | None = None,
        actual_type: str | None = None,
    ):
        super().__init__(message)
        # Retain the type mismatch details for structured error reporting.
        self.variable_name = variable_name
        self.expected_type = expected_type
        self.actual_type = actual_type

View File

@@ -0,0 +1,96 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
class AnswerNode(Node):
    """Workflow node that renders the answer template and collects any files it references."""

    node_type = NodeType.ANSWER
    execution_type = NodeExecutionType.RESPONSE
    _node_data: AnswerNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw node config mapping into typed AnswerNodeData."""
        self._node_data = AnswerNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        """Render the answer template against the variable pool and return the result."""
        rendered = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
        collected_files = self._extract_files_from_segments(rendered.value)
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"answer": rendered.markdown, "files": ArrayFileSegment(value=collected_files)},
        )

    def _extract_files_from_segments(self, segments: Sequence[Segment]):
        """Flatten every file carried by FileSegment/ArrayFileSegment entries into one list.

        FileSegment holds a single file; ArrayFileSegment holds several.
        """
        collected = []
        for seg in segments:
            if isinstance(seg, FileSegment):
                collected.append(seg.value)  # single file
            elif isinstance(seg, ArrayFileSegment):
                collected.extend(seg.value)  # list of files
        return collected

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Map '<node_id>.<variable>' to its value selector for every reference in the template."""
        typed = AnswerNodeData.model_validate(node_data)
        selectors = VariableTemplateParser(template=typed.answer).extract_variable_selectors()
        return {node_id + "." + sel.variable: sel.value_selector for sel in selectors}

    def get_streaming_template(self) -> Template:
        """
        Get the template for streaming.

        Returns:
            Template instance for this Answer node
        """
        return Template.from_answer_template(self._node_data.answer)

View File

@@ -0,0 +1,65 @@
from collections.abc import Sequence
from enum import StrEnum, auto
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
class AnswerNodeData(BaseNodeData):
    """
    Answer Node Data.
    """

    # Template rendered against the variable pool to produce the final answer.
    answer: str = Field(..., description="answer template string")
class GenerateRouteChunk(BaseModel):
    """
    Generate Route Chunk.

    Base class for one piece of a precomputed answer-generation route.
    """

    class ChunkType(StrEnum):
        # auto() derives the lowercased member name: VAR -> "var", TEXT -> "text".
        VAR = auto()
        TEXT = auto()

    type: ChunkType = Field(..., description="generate route chunk type")
class VarGenerateRouteChunk(GenerateRouteChunk):
    """
    Var Generate Route Chunk.

    A route piece whose content comes from a variable lookup.
    """

    type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
    """generate route chunk type"""
    # Selector path identifying the variable to substitute at this position.
    value_selector: Sequence[str] = Field(..., description="value selector")
class TextGenerateRouteChunk(GenerateRouteChunk):
    """
    Text Generate Route Chunk.

    A route chunk holding literal text emitted verbatim.
    """

    type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
    """generate route chunk type"""
    text: str = Field(..., description="text")
class AnswerNodeDoubleLink(BaseModel):
    """Doubly-linked adjacency record for an answer node in the graph."""

    node_id: str = Field(..., description="node id")
    source_node_ids: list[str] = Field(..., description="source node ids")
    target_node_ids: list[str] = Field(..., description="target node ids")
class AnswerStreamGenerateRoute(BaseModel):
    """
    AnswerStreamGenerateRoute entity

    Aggregates, per answer node, both its dependency list and its ordered
    generate-route chunks used for streaming output.
    """

    answer_dependencies: dict[str, list[str]] = Field(
        ..., description="answer dependencies (answer node id -> dependent answer node ids)"
    )
    answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
        ..., description="answer generate route (answer node id -> generate route chunks)"
    )

View File

@@ -0,0 +1,11 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [
"BaseIterationNodeData",
"BaseIterationState",
"BaseLoopNodeData",
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
]

View File

@@ -0,0 +1,171 @@
import json
from abc import ABC
from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, model_validator
from core.workflow.enums import ErrorStrategy
from .exc import DefaultValueTypeError
# Numeric union accepted wherever a "number" is validated (ints and floats alike).
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
    """node retry config"""

    max_retries: int = 0  # max retry times
    retry_interval: int = 0  # retry interval in milliseconds
    retry_enabled: bool = False  # whether retry is enabled

    @property
    def retry_interval_seconds(self) -> float:
        # Stored in milliseconds; exposed in seconds for sleep()-style callers.
        return self.retry_interval / 1000
class VariableSelector(BaseModel):
    """
    Variable Selector.

    `variable` is the template key (e.g. "#node.var#"); `value_selector`
    is the resolved path, typically [node_id, variable_name, ...].
    """

    variable: str
    value_selector: Sequence[str]
class DefaultValueType(StrEnum):
    """Declared type tag for a node's default (fallback) output value."""

    STRING = "string"
    NUMBER = "number"
    OBJECT = "object"
    ARRAY_NUMBER = "array[number]"
    ARRAY_STRING = "array[string]"
    ARRAY_OBJECT = "array[object]"
    ARRAY_FILES = "array[file]"
class DefaultValue(BaseModel):
    """Typed fallback output value used by a node's default-value error strategy.

    `value` is coerced (from a JSON/number string where applicable) and then
    validated against the declared `type`; a mismatch raises
    DefaultValueTypeError.
    """

    value: Any = None
    type: DefaultValueType
    key: str

    @staticmethod
    def _parse_json(value: str):
        """Unified JSON parsing handler"""
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")

    @staticmethod
    def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
        """Unified array type validation"""
        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)

    @staticmethod
    def _convert_number(value: str) -> float:
        """Unified number conversion handler"""
        try:
            return float(value)
        except ValueError:
            raise DefaultValueTypeError(f"Cannot convert to number: {value}")

    @model_validator(mode="after")
    def validate_value_type(self) -> "DefaultValue":
        """Coerce string input and validate `value` against the declared `type`."""
        # Type validation configuration
        type_validators = {
            DefaultValueType.STRING: {
                "type": str,
                "converter": lambda x: x,
            },
            DefaultValueType.NUMBER: {
                "type": _NumberType,
                "converter": self._convert_number,
            },
            DefaultValueType.OBJECT: {
                "type": dict,
                "converter": self._parse_json,
            },
            DefaultValueType.ARRAY_NUMBER: {
                "type": list,
                "element_type": _NumberType,
                "converter": self._parse_json,
            },
            DefaultValueType.ARRAY_STRING: {
                "type": list,
                "element_type": str,
                "converter": self._parse_json,
            },
            DefaultValueType.ARRAY_OBJECT: {
                "type": list,
                "element_type": dict,
                "converter": self._parse_json,
            },
        }
        validator: dict[str, Any] = type_validators.get(self.type, {})
        if not validator:
            if self.type == DefaultValueType.ARRAY_FILES:
                # Handle files type
                return self
            raise DefaultValueTypeError(f"Unsupported type: {self.type}")
        # Handle string input cases
        if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
            self.value = validator["converter"](self.value)
        # Validate base type
        # NOTE(review): for NUMBER types `validator["type"]` is a typing.Union,
        # so `.__name__` in the failure message below may not resolve on every
        # Python version — confirm before relying on this error text.
        if not isinstance(self.value, validator["type"]):
            raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
        # Validate array element types
        if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
            raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
        return self
class BaseNodeData(ABC, BaseModel):
    """Common configuration shared by every workflow node's typed data model."""

    title: str
    desc: str | None = None
    version: str = "1"
    error_strategy: ErrorStrategy | None = None
    default_value: list[DefaultValue] | None = None
    retry_config: RetryConfig = RetryConfig()

    @property
    def default_value_dict(self) -> dict[str, Any]:
        """Return configured default values as a key -> value mapping (empty if unset)."""
        if self.default_value:
            return {item.key: item.value for item in self.default_value}
        return {}
class BaseIterationNodeData(BaseNodeData):
    """Node data common to iteration containers; points at the first inner node."""

    start_node_id: str | None = None
class BaseIterationState(BaseModel):
    """Per-run mutable state tracked while an iteration node executes."""

    iteration_node_id: str
    index: int
    inputs: dict

    class MetaData(BaseModel):
        """Extension point for iteration-specific metadata; empty by default."""

        pass

    metadata: MetaData
class BaseLoopNodeData(BaseNodeData):
    """Node data common to loop containers; points at the first inner node."""

    start_node_id: str | None = None
class BaseLoopState(BaseModel):
    """Per-run mutable state tracked while a loop node executes."""

    loop_node_id: str
    index: int
    inputs: dict

    class MetaData(BaseModel):
        """Extension point for loop-specific metadata; empty by default."""

        pass

    metadata: MetaData

View File

@@ -0,0 +1,10 @@
class BaseNodeError(ValueError):
    """Root of the workflow-node error hierarchy; catchable as ValueError."""
class DefaultValueTypeError(BaseNodeError):
    """Signals that a DefaultValue payload does not match its declared type."""

View File

@@ -0,0 +1,538 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from typing import Any, ClassVar
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import (
AgentLogEvent,
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
IterationSucceededEvent,
LoopFailedEvent,
LoopNextEvent,
LoopStartedEvent,
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
PauseRequestedEvent,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
from .entities import BaseNodeData, RetryConfig
logger = logging.getLogger(__name__)
class Node:
    """Base class for all workflow node implementations.

    Subclasses implement `init_node_data` and `_run`, plus the typed
    accessors at the bottom of this class. `run()` wraps `_run()` with
    lifecycle events (start / success / failure) and converts node-level
    events into graph-level events via `_dispatch`.
    """

    # Concrete subclasses must set this to their NodeType.
    node_type: ClassVar["NodeType"]
    # Most nodes are directly executable; container nodes may override.
    execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
    ) -> None:
        self.id = id
        self.tenant_id = graph_init_params.tenant_id
        self.app_id = graph_init_params.app_id
        self.workflow_id = graph_init_params.workflow_id
        self.graph_config = graph_init_params.graph_config
        self.user_id = graph_init_params.user_id
        self.user_from = UserFrom(graph_init_params.user_from)
        self.invoke_from = InvokeFrom(graph_init_params.invoke_from)
        self.workflow_call_depth = graph_init_params.call_depth
        self.graph_runtime_state = graph_runtime_state
        self.state: NodeState = NodeState.UNKNOWN  # node execution state
        node_id = config.get("id")
        if not node_id:
            raise ValueError("Node ID is required.")
        self._node_id = node_id
        self._node_execution_id: str = ""
        self._start_at = naive_utc_now()

    @abstractmethod
    def init_node_data(self, data: Mapping[str, Any]) -> None: ...

    @abstractmethod
    def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
        """
        Run node

        Subclasses return either a single NodeRunResult or a generator of
        node events ending in a terminal event.
        :return:
        """
        raise NotImplementedError

    def run(self) -> Generator[GraphNodeEventBase, None, None]:
        """Execute the node, yielding graph-level lifecycle and stream events."""
        # Generate a single node execution ID to use for all events
        if not self._node_execution_id:
            self._node_execution_id = str(uuid4())
        self._start_at = naive_utc_now()
        # Create and push start event with required fields
        start_event = NodeRunStartedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.title,
            in_iteration_id=None,
            start_at=self._start_at,
        )
        # === FIXME(-LAN-): Needs to refactor.
        # Local imports below avoid circular imports between node modules.
        from core.workflow.nodes.tool.tool_node import ToolNode

        if isinstance(self, ToolNode):
            start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
            start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
        from core.workflow.nodes.datasource.datasource_node import DatasourceNode

        if isinstance(self, DatasourceNode):
            plugin_id = getattr(self.get_base_node_data(), "plugin_id", "")
            provider_name = getattr(self.get_base_node_data(), "provider_name", "")
            start_event.provider_id = f"{plugin_id}/{provider_name}"
            start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
        from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode

        if isinstance(self, TriggerEventNode):
            start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
            start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
        from typing import cast

        from core.workflow.nodes.agent.agent_node import AgentNode
        from core.workflow.nodes.agent.entities import AgentNodeData

        if isinstance(self, AgentNode):
            start_event.agent_strategy = AgentNodeStrategyInit(
                name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
                icon=self.agent_strategy_icon,
            )
        # ===
        yield start_event
        try:
            result = self._run()
            # Handle NodeRunResult
            if isinstance(result, NodeRunResult):
                yield self._convert_node_run_result_to_graph_node_event(result)
                return
            # Handle event stream
            for event in result:
                # NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
                if isinstance(event, NodeEventBase):  # pyright: ignore[reportUnnecessaryIsInstance]
                    yield self._dispatch(event)
                elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id:  # pyright: ignore[reportUnnecessaryIsInstance]
                    # Re-stamp top-level graph events with this node's execution id.
                    event.id = self._node_execution_id
                    yield event
                else:
                    yield event
        except Exception as e:
            logger.exception("Node %s failed to run", self._node_id)
            result = NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                error_type="WorkflowNodeError",
            )
            yield NodeRunFailedEvent(
                id=self._node_execution_id,
                node_id=self._node_id,
                node_type=self.node_type,
                start_at=self._start_at,
                node_run_result=result,
                error=str(e),
            )

    @classmethod
    def extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        config: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Extracts references variable selectors from node configuration.

        The `config` parameter represents the configuration for a specific node type and corresponds
        to the `data` field in the node definition object.

        The returned mapping has the following structure:

            {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}

        For loop and iteration nodes, the mapping may look like this:

            {
                "1748332301644.input_selector": ["1748332363630", "result"],
                "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"],
            }

        where `1748332301644` is the ID of the loop / iteration node,
        and `1748332325079` is the ID of the node inside the loop or iteration node.

        Here, the key consists of two parts: the current node ID (provided as the `node_id`
        parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
        enclosed in `#` symbols. These two parts are separated by a dot (`.`).

        The value is a list of string representing the variable selector, where the first element is the node ID
        of the referenced variable, and the second element is the variable name within that node.

        The meaning of the above response is:

        The node with ID `1747829548239` references the variable `result` from the node with
        ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a
        reference to the `result` output variable of node `1747829667553`.

        :param graph_config: graph config
        :param config: node config
        :return:
        """
        node_id = config.get("id")
        if not node_id:
            raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
        # Pass raw dict data instead of creating NodeData instance
        data = cls._extract_variable_selector_to_variable_mapping(
            graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
        )
        return data

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        # Default: no variable references; subclasses override as needed.
        return {}

    def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
        """
        Check if this node blocks the output of specific variables.

        This method is used to determine if a node must complete execution before
        the specified variables can be used in streaming output.

        :param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str'))
        :return: True if this node blocks output of any of the specified variables, False otherwise
        """
        return False

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """Return the node type's default configuration; empty unless overridden."""
        return {}

    @classmethod
    @abstractmethod
    def version(cls) -> str:
        """`node_version` returns the version of current node type."""
        # NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
        #
        # If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
        # in `api/core/workflow/nodes/__init__.py`.
        raise NotImplementedError("subclasses of BaseNode must implement `version` method.")

    @property
    def retry(self) -> bool:
        # By default nodes are not retryable; subclasses opt in.
        return False

    # Abstract methods that subclasses must implement to provide access
    # to BaseNodeData properties in a type-safe way
    @abstractmethod
    def _get_error_strategy(self) -> ErrorStrategy | None:
        """Get the error strategy for this node."""
        ...

    @abstractmethod
    def _get_retry_config(self) -> RetryConfig:
        """Get the retry configuration for this node."""
        ...

    @abstractmethod
    def _get_title(self) -> str:
        """Get the node title."""
        ...

    @abstractmethod
    def _get_description(self) -> str | None:
        """Get the node description."""
        ...

    @abstractmethod
    def _get_default_value_dict(self) -> dict[str, Any]:
        """Get the default values dictionary for this node."""
        ...

    @abstractmethod
    def get_base_node_data(self) -> BaseNodeData:
        """Get the BaseNodeData object for this node."""
        ...

    # Public interface properties that delegate to abstract methods
    @property
    def error_strategy(self) -> ErrorStrategy | None:
        """Get the error strategy for this node."""
        return self._get_error_strategy()

    @property
    def retry_config(self) -> RetryConfig:
        """Get the retry configuration for this node."""
        return self._get_retry_config()

    @property
    def title(self) -> str:
        """Get the node title."""
        return self._get_title()

    @property
    def description(self) -> str | None:
        """Get the node description."""
        return self._get_description()

    @property
    def default_value_dict(self) -> dict[str, Any]:
        """Get the default values dictionary for this node."""
        return self._get_default_value_dict()

    def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
        """Wrap a terminal NodeRunResult in the matching graph-level event."""
        match result.status:
            case WorkflowNodeExecutionStatus.FAILED:
                return NodeRunFailedEvent(
                    id=self._node_execution_id,
                    # NOTE(review): other event constructors in this class use
                    # self._node_id; confirm self.id here is intentional.
                    node_id=self.id,
                    node_type=self.node_type,
                    start_at=self._start_at,
                    node_run_result=result,
                    error=result.error,
                )
            case WorkflowNodeExecutionStatus.SUCCEEDED:
                return NodeRunSucceededEvent(
                    id=self._node_execution_id,
                    node_id=self.id,
                    node_type=self.node_type,
                    start_at=self._start_at,
                    node_run_result=result,
                )
            case _:
                raise Exception(f"result status {result.status} not supported")

    # Converts node-level events to graph-level events; one overload per event type.
    @singledispatchmethod
    def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase:
        raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")

    @_dispatch.register
    def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
        return NodeRunStreamChunkEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            selector=event.selector,
            chunk=event.chunk,
            is_final=event.is_final,
        )

    @_dispatch.register
    def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
        match event.node_run_result.status:
            case WorkflowNodeExecutionStatus.SUCCEEDED:
                return NodeRunSucceededEvent(
                    id=self._node_execution_id,
                    node_id=self._node_id,
                    node_type=self.node_type,
                    start_at=self._start_at,
                    node_run_result=event.node_run_result,
                )
            case WorkflowNodeExecutionStatus.FAILED:
                return NodeRunFailedEvent(
                    id=self._node_execution_id,
                    node_id=self._node_id,
                    node_type=self.node_type,
                    start_at=self._start_at,
                    node_run_result=event.node_run_result,
                    error=event.node_run_result.error,
                )
            case _:
                raise NotImplementedError(
                    f"Node {self._node_id} does not support status {event.node_run_result.status}"
                )

    @_dispatch.register
    def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
        return NodeRunPauseRequestedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
            reason=event.reason,
        )

    @_dispatch.register
    def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
        return NodeRunAgentLogEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            message_id=event.message_id,
            label=event.label,
            node_execution_id=event.node_execution_id,
            parent_id=event.parent_id,
            error=event.error,
            status=event.status,
            data=event.data,
            metadata=event.metadata,
        )

    @_dispatch.register
    def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
        return NodeRunLoopStartedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            metadata=event.metadata,
            predecessor_node_id=event.predecessor_node_id,
        )

    @_dispatch.register
    def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
        return NodeRunLoopNextEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            index=event.index,
            pre_loop_output=event.pre_loop_output,
        )

    @_dispatch.register
    def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
        return NodeRunLoopSucceededEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
            metadata=event.metadata,
            steps=event.steps,
        )

    @_dispatch.register
    def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
        return NodeRunLoopFailedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
            metadata=event.metadata,
            steps=event.steps,
            error=event.error,
        )

    @_dispatch.register
    def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
        return NodeRunIterationStartedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            metadata=event.metadata,
            predecessor_node_id=event.predecessor_node_id,
        )

    @_dispatch.register
    def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
        return NodeRunIterationNextEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            index=event.index,
            pre_iteration_output=event.pre_iteration_output,
        )

    @_dispatch.register
    def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
        return NodeRunIterationSucceededEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
            metadata=event.metadata,
            steps=event.steps,
        )

    @_dispatch.register
    def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
        return NodeRunIterationFailedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
            metadata=event.metadata,
            steps=event.steps,
            error=event.error,
        )

    @_dispatch.register
    def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
        return NodeRunRetrieverResourceEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            retriever_resources=event.retriever_resources,
            context=event.context,
            node_version=self.version(),
        )

View File

@@ -0,0 +1,148 @@
"""Template structures for Response nodes (Answer and End).
This module provides a unified template structure for both Answer and End nodes,
similar to SegmentGroup but focused on template representation without values.
"""
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Union
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
@dataclass(frozen=True)
class TemplateSegment(ABC):
    """Base class for template segments."""

    @abstractmethod
    def __str__(self) -> str:
        """String representation of the segment."""
        pass
@dataclass(frozen=True)
class TextSegment(TemplateSegment):
    """A text segment in a template; rendered verbatim."""

    text: str

    def __str__(self) -> str:
        return self.text
@dataclass(frozen=True)
class VariableSegment(TemplateSegment):
    """A variable reference segment in a template."""

    selector: Sequence[str]
    variable_name: str | None = None  # Optional variable name for End nodes

    def __str__(self) -> str:
        # Rendered back in placeholder form, e.g. "{{#node1.name#}}".
        return "{{#" + ".".join(self.selector) + "#}}"
# Type alias covering the two concrete segment kinds a Template may hold.
TemplateSegmentUnion = Union[TextSegment, VariableSegment]
@dataclass(frozen=True)
class Template:
    """Unified template structure for Response nodes.

    Similar to SegmentGroup, but represents the template structure
    without variable values - only marking variable selectors.
    """

    segments: list[TemplateSegmentUnion]

    @classmethod
    def from_answer_template(cls, template_str: str) -> "Template":
        """Create a Template from an Answer node template string.

        Example:
            "Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])]

        Args:
            template_str: The answer template string
        Returns:
            Template instance
        """
        parser = VariableTemplateParser(template_str)
        segments: list[TemplateSegmentUnion] = []
        # Extract variable selectors to find all variables
        variable_selectors = parser.extract_variable_selectors()
        var_map = {var.variable: var.value_selector for var in variable_selectors}
        # Parse template to get ordered segments
        # We need to split the template by variable placeholders while preserving order
        import re

        # Create a regex pattern that matches variable placeholders
        pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}"
        # Split template while keeping the delimiters (variable placeholders)
        parts = re.split(pattern, template_str)
        for i, part in enumerate(parts):
            if not part:
                continue
            # Check if this part is a variable reference (odd indices after split)
            if i % 2 == 1:  # Odd indices are variable keys
                # Remove the # symbols from the variable key
                var_key = part
                if var_key in var_map:
                    segments.append(VariableSegment(selector=list(var_map[var_key])))
                else:
                    # This shouldn't happen with valid templates
                    segments.append(TextSegment(text="{{" + part + "}}"))
            else:
                # Even indices are text segments
                segments.append(TextSegment(text=part))
        return cls(segments=segments)

    @classmethod
    def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
        """Create a Template from an End node outputs configuration.

        End nodes are treated as templates of concatenated variables with newlines.

        Example:
            [{"variable": "text", "value_selector": ["node1", "text"]},
             {"variable": "result", "value_selector": ["node2", "result"]}]
            ->
            [VariableSegment(["node1", "text"]),
             TextSegment("\n"),
             VariableSegment(["node2", "result"])]

        Args:
            outputs_config: List of output configurations with variable and value_selector
        Returns:
            Template instance
        """
        segments: list[TemplateSegmentUnion] = []
        for i, output in enumerate(outputs_config):
            if i > 0:
                # Add newline separator between variables
                segments.append(TextSegment(text="\n"))
            value_selector = output.get("value_selector", [])
            variable_name = output.get("variable", "")
            if value_selector:
                segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name))
        # Drop a trailing separator left behind when the last output had no selector.
        if len(segments) > 0 and isinstance(segments[-1], TextSegment):
            segments = segments[:-1]
        return cls(segments=segments)

    def __str__(self) -> str:
        """String representation of the template."""
        return "".join(str(segment) for segment in self.segments)

View File

@@ -0,0 +1,28 @@
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState
class LLMUsageTrackingMixin:
    """Provides shared helpers for merging and recording LLM usage within workflow nodes."""

    # Host classes (workflow nodes) are expected to provide this attribute.
    graph_runtime_state: GraphRuntimeState

    @staticmethod
    def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
        """Return a combined usage snapshot, preserving zero-value inputs."""
        if new_usage is None or new_usage.total_tokens <= 0:
            return current
        if current.total_tokens == 0:
            return new_usage
        return current.plus(new_usage)

    def _accumulate_usage(self, usage: LLMUsage) -> None:
        """Push usage into the graph runtime accumulator for downstream reporting."""
        if usage.total_tokens <= 0:
            return
        current_usage = self.graph_runtime_state.llm_usage
        if current_usage.total_tokens == 0:
            # First contribution: store a copy so later mutation of `usage` cannot leak in.
            self.graph_runtime_state.llm_usage = usage.model_copy()
        else:
            self.graph_runtime_state.llm_usage = current_usage.plus(usage)

View File

@@ -0,0 +1,130 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any
from .entities import VariableSelector
# Matches `{{#node_id.var[.sub...]#}}`. The nested capturing group makes
# re.findall return tuples; callers use only the first group (the full `#...#` key).
REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
# Same shape but with a non-capturing inner group, so re.split keeps only the full key.
SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]:
    """Return a VariableSelector for each `{{#node.var#}}` reference in *template*."""
    # re.split keeps captured groups, so variable keys appear among the text fragments.
    fragments = SELECTOR_PATTERN.split(template)
    return [
        VariableSelector(variable=fragment, value_selector=fragment[1:-1].split("."))
        for fragment in fragments
        if fragment and "." in fragment and fragment.startswith("#") and fragment.endswith("#")
    ]
class VariableTemplateParser:
    """
    !NOTE: Consider to use the new `segments` module instead of this class.

    A class for parsing and manipulating template variables in a string.

    Rules:

    1. Template variables must be enclosed in `{{}}`.
    2. The template variable Key can only be: #node_id.var1.var2#.
    3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.

    Example usage:

        template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}."
        parser = VariableTemplateParser(template)

        # Extract template variable keys
        variable_keys = parser.extract()
        print(variable_keys)
        # Output: ['#node_id.query.name#', '#node_id.query.age#']

        # Extract variable selectors
        variable_selectors = parser.extract_variable_selectors()
        print(variable_selectors)
        # Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']),
        #          VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])]

        # Format the template string
        inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}
        formatted_string = parser.format(inputs)
        print(formatted_string)
        # Output: "Hello, John! Your age is 25."
    """

    def __init__(self, template: str):
        self.template = template
        # Cache extracted keys; extraction order is not preserved (set-based dedupe).
        self.variable_keys = self.extract()

    def extract(self):
        """
        Extracts all the template variable keys from the template string.

        Returns:
            A list of template variable keys (deduplicated, order not guaranteed).
        """
        # Regular expression to match the template rules
        matches = re.findall(REGEX, self.template)
        first_group_matches = [match[0] for match in matches]
        return list(set(first_group_matches))

    def extract_variable_selectors(self) -> list[VariableSelector]:
        """
        Extracts the variable selectors from the template variable keys.

        Returns:
            A list of VariableSelector objects representing the variable selectors.
        """
        variable_selectors = []
        for variable_key in self.variable_keys:
            remove_hash = variable_key.replace("#", "")
            split_result = remove_hash.split(".")
            if len(split_result) < 2:
                continue
            variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result))
        return variable_selectors

    def format(self, inputs: Mapping[str, Any]) -> str:
        """
        Formats the template string by replacing the template variables with their corresponding values.

        Args:
            inputs: A dictionary containing the values for the template variables.

        Returns:
            The formatted string with template variables replaced by their values.
        """

        def replacer(match):
            key = match.group(1)
            value = inputs.get(key, match.group(0))  # return original matched string if key not found
            if value is None:
                value = ""
            # convert the value to string
            if isinstance(value, list | dict | bool | int | float):
                value = str(value)
            # remove template variables if required
            return VariableTemplateParser.remove_template_variables(value)

        prompt = re.sub(REGEX, replacer, self.template)
        # Strip special control tokens such as "<|...|>" from the result.
        return re.sub(r"<\|.*?\|>", "", prompt)

    @classmethod
    def remove_template_variables(cls, text: str):
        """
        Removes the template variables from the given text.

        Args:
            text: The text from which to remove the template variables.

        Returns:
            The text with template variables removed.
        """
        # Defuses nested placeholders by rewriting "{{#...#}}" to "{#...#}".
        return re.sub(REGEX, r"{\1}", text)

View File

@@ -0,0 +1,3 @@
from .code_node import CodeNode
__all__ = ["CodeNode"]

View File

@@ -0,0 +1,444 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from .exc import (
CodeNodeError,
DepthLimitError,
OutputValidationError,
)
class CodeNode(Node):
node_type = NodeType.CODE
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]):
    """Validate the raw node config dict into typed CodeNodeData."""
    self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
    """Expose the configured error strategy from the typed node data."""
    return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
    """Expose the configured retry settings from the typed node data."""
    return self._node_data.retry_config
def _get_title(self) -> str:
    """Expose the node title from the typed node data."""
    return self._node_data.title
def _get_description(self) -> str | None:
    """Expose the optional node description from the typed node data."""
    return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
    """Expose the default (fallback) output values keyed by output name."""
    return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
    """Return the typed node data viewed as BaseNodeData."""
    return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
    """
    Get default config of node.

    :param filters: filter by node config parameters; `code_language` selects
        the provider (defaults to Python 3).
    :return: default config produced by the matching code provider
    """
    code_language = CodeLanguage.PYTHON3
    if filters:
        code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
    providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
    # Pick the first provider that accepts the requested language.
    code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
    return code_provider.get_default_config()
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
# Get code language
code_language = self._node_data.code_language
code = self._node_data.code
# Get variables
variables = {}
for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment):
variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None
else:
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=code_language,
code=code,
inputs=variables,
)
# Transform result
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise OutputValidationError(
f"The length of output variable `{variable}` must be"
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
)
return value.replace("\x00", "")
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
if value is None:
return None
return value
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
"""
Check number
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise OutputValidationError(
f"Output variable `{variable}` is out of range,"
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
)
if isinstance(value, float):
decimal_value = Decimal(str(value)).normalize()
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
# raise error if precision is too high
if precision > dify_config.CODE_MAX_PRECISION:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
)
return value
def _transform_result(
self,
result: Mapping[str, Any],
output_schema: dict[str, CodeNodeData.Output] | None,
prefix: str = "",
depth: int = 1,
):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result: dict[str, Any] = {}
if output_schema is None:
# validate output thought instance type
for output_name, output_value in result.items():
if isinstance(output_value, dict):
self._transform_result(
result=output_value,
output_schema=None,
prefix=f"{prefix}.{output_name}" if prefix else output_name,
depth=depth + 1,
)
elif isinstance(output_value, bool):
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
elif isinstance(output_value, int | float):
self._check_number(
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
)
elif isinstance(output_value, str):
self._check_string(
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
)
elif isinstance(output_value, list):
first_element = output_value[0] if len(output_value) > 0 else None
if first_element is not None:
if isinstance(first_element, int | float) and all(
value is None or isinstance(value, int | float) for value in output_value
):
for i, value in enumerate(output_value):
self._check_number(
value=value,
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
)
elif isinstance(first_element, str) and all(
value is None or isinstance(value, str) for value in output_value
):
for i, value in enumerate(output_value):
self._check_string(
value=value,
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
)
elif (
isinstance(first_element, dict)
and all(value is None or isinstance(value, dict) for value in output_value)
or isinstance(first_element, list)
and all(value is None or isinstance(value, list) for value in output_value)
):
for i, value in enumerate(output_value):
if value is not None:
self._transform_result(
result=value,
output_schema=None,
prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
depth=depth + 1,
)
else:
raise OutputValidationError(
f"Output {prefix}.{output_name} is not a valid array."
f" make sure all elements are of the same type."
)
elif output_value is None:
pass
else:
raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.")
return result
parameters_validated = {}
for output_name, output_config in output_schema.items():
dot = "." if prefix else ""
if output_name not in result:
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
if output_config.type == SegmentType.OBJECT:
# check if output is object
if not isinstance(result.get(output_name), dict):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an object,"
f" got {type(result.get(output_name))} instead."
)
else:
transformed_result[output_name] = self._transform_result(
result=result[output_name],
output_schema=output_config.children,
prefix=f"{prefix}.{output_name}",
depth=depth + 1,
)
elif output_config.type == SegmentType.NUMBER:
# check if number available
value = result.get(output_name)
if value is not None and not isinstance(value, (int, float)):
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not a number,"
f" got {type(result.get(output_name))} instead."
)
checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}")
# If the output is a boolean and the output schema specifies a NUMBER type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
transformed_result[output_name] = self._convert_boolean_to_int(checked)
elif output_config.type == SegmentType.STRING:
# check if string available
value = result.get(output_name)
if value is not None and not isinstance(value, str):
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} must be a string, got {type(value).__name__} instead"
)
transformed_result[output_name] = self._check_string(
value=value,
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == SegmentType.BOOLEAN:
transformed_result[output_name] = self._check_boolean(
value=result[output_name],
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == SegmentType.ARRAY_NUMBER:
# check if array of number available
value = result[output_name]
if not isinstance(value, list):
if value is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
)
else:
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
)
for i, inner_value in enumerate(value):
if not isinstance(inner_value, (int, float)):
raise OutputValidationError(
f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be"
f" a number."
)
_ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
transformed_result[output_name] = [
# If the element is a boolean and the output schema specifies a `array[number]` type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
self._convert_boolean_to_int(v)
for v in value
]
elif output_config.type == SegmentType.ARRAY_STRING:
# check if array of string available
if not isinstance(result[output_name], list):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
)
transformed_result[output_name] = [
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
elif output_config.type == SegmentType.ARRAY_OBJECT:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
)
for i, value in enumerate(result[output_name]):
if not isinstance(value, dict):
if value is None:
pass
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name}[{i}] is not an object,"
f" got {type(value)} instead at index {i}."
)
transformed_result[output_name] = [
None
if value is None
else self._transform_result(
result=value,
output_schema=output_config.children,
prefix=f"{prefix}{dot}{output_name}[{i}]",
depth=depth + 1,
)
for i, value in enumerate(result[output_name])
]
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
# check if array of object available
value = result[output_name]
if not isinstance(value, list):
if value is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
for i, inner_value in enumerate(value):
if inner_value is not None and not isinstance(inner_value, bool):
raise OutputValidationError(
f"Output {prefix}{dot}{output_name}[{i}] is not a boolean,"
f" got {type(inner_value)} instead."
)
_ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
transformed_result[output_name] = value
else:
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
parameters_validated[output_name] = True
# check if all output parameters are validated
if len(parameters_validated) != len(result):
raise CodeNodeError("Not all output parameters are validated.")
return transformed_result
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = CodeNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in typed_node_data.variables
}
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
@staticmethod
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
This ensures compatibility with existing workflows that may use
`True` and `False` as values for NUMBER type outputs.
"""
if value is None:
return None
if isinstance(value, bool):
return int(value)
return value

View File

@@ -0,0 +1,47 @@
from typing import Annotated, Literal, Self
from pydantic import AfterValidator, BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
# Segment types that a code node is allowed to declare as outputs.
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
    {
        SegmentType.STRING,
        SegmentType.NUMBER,
        SegmentType.OBJECT,
        SegmentType.BOOLEAN,
        SegmentType.ARRAY_STRING,
        SegmentType.ARRAY_NUMBER,
        SegmentType.ARRAY_OBJECT,
        SegmentType.ARRAY_BOOLEAN,
    }
)


def _validate_type(segment_type: SegmentType) -> SegmentType:
    """Pydantic AfterValidator: reject segment types a code node cannot emit."""
    if segment_type in _ALLOWED_OUTPUT_FROM_CODE:
        return segment_type
    raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
class CodeNodeData(BaseNodeData):
    """
    Code Node Data.

    Configuration model for a code node: the input variable selectors, the
    source code, its language, and the declared output schema.
    """

    class Output(BaseModel):
        # Declared segment type, restricted to types a code node may emit.
        type: Annotated[SegmentType, AfterValidator(_validate_type)]
        # Nested schema for object / array-of-object outputs.
        children: dict[str, Self] | None = None

    class Dependency(BaseModel):
        # Dependency name and version; presumably consumed by the code
        # executor sandbox — TODO confirm against the executor.
        name: str
        version: str

    # Inputs exposed to the code, resolved from the variable pool at run time.
    variables: list[VariableSelector]
    code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
    # The source code to execute.
    code: str
    # Output schema keyed by output variable name.
    outputs: dict[str, Output]
    dependencies: list[Dependency] | None = None

View File

@@ -0,0 +1,16 @@
class CodeNodeError(ValueError):
    """Base class for code node errors."""


class OutputValidationError(CodeNodeError):
    """Raised when there is an output validation error."""


class DepthLimitError(CodeNodeError):
    """Raised when the depth limit is reached."""

View File

@@ -0,0 +1,3 @@
from .datasource_node import DatasourceNode

# Public API of the datasource node package.
__all__ = ["DatasourceNode"]

View File

@@ -0,0 +1,502 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
DatasourceParameter,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDriveDownloadFileRequest,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError
class DatasourceNode(Node):
    """
    Datasource Node

    Root node that fetches content from a configured datasource provider
    (online document, online drive, website crawl, or a locally uploaded
    file) and streams the result into the workflow as events.
    """

    # Typed view of the raw node config; populated by init_node_data().
    _node_data: DatasourceNodeData
    node_type = NodeType.DATASOURCE
    # Datasource nodes can act as a workflow's entry point.
    execution_type = NodeExecutionType.ROOT

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        """Validate the raw config mapping into a typed ``DatasourceNodeData``."""
        self._node_data = DatasourceNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        """Return the configured error strategy, if any."""
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        """Return the node's retry configuration."""
        return self._node_data.retry_config

    def _get_title(self) -> str:
        """Return the node title."""
        return self._node_data.title

    def _get_description(self) -> str | None:
        """Return the optional node description."""
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        """Return the default output values used by the error strategy."""
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        """Expose the typed node data to the base-node machinery."""
        return self._node_data

    def _run(self) -> Generator:
        """
        Run the datasource node.

        Reads the datasource type/info from the system variables, resolves the
        datasource runtime and credentials, then dispatches on the provider
        type, yielding stream events and a final ``StreamCompletedEvent``.

        :raises DatasourceNodeError: when the system variables are missing or
            malformed (raised before the try block, i.e. not converted into a
            FAILED completion event).
        """
        node_data = self._node_data
        variable_pool = self.graph_runtime_state.variable_pool
        datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
        if not datasource_type_segement:
            raise DatasourceNodeError("Datasource type is not set")
        datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
        datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
        if not datasource_info_segement:
            raise DatasourceNodeError("Datasource info is not set")
        datasource_info_value = datasource_info_segement.value
        if not isinstance(datasource_info_value, dict):
            raise DatasourceNodeError("Invalid datasource info format")
        datasource_info: dict[str, Any] = datasource_info_value
        # get datasource runtime
        # NOTE(review): imported locally, presumably to avoid a circular import
        # at module load time — confirm.
        from core.datasource.datasource_manager import DatasourceManager

        if datasource_type is None:
            raise DatasourceNodeError("Datasource type is not set")
        datasource_runtime = DatasourceManager.get_datasource_runtime(
            provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
            datasource_name=node_data.datasource_name or "",
            tenant_id=self.tenant_id,
            datasource_type=DatasourceProviderType.value_of(datasource_type),
        )
        # Mutates the shared datasource_info dict; the icon also ends up in the
        # logged inputs below, since parameters_for_log aliases datasource_info.
        datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
        parameters_for_log = datasource_info
        try:
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=self.tenant_id,
                provider=node_data.provider_name,
                plugin_id=node_data.plugin_id,
                credential_id=datasource_info.get("credential_id", ""),
            )
            match datasource_type:
                case DatasourceProviderType.ONLINE_DOCUMENT:
                    # Fetch a single page's content from an online document provider.
                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
                    if credentials:
                        datasource_runtime.runtime.credentials = credentials
                    online_document_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.get_online_document_page_content(
                            user_id=self.user_id,
                            datasource_parameters=GetOnlineDocumentPageContentRequest(
                                workspace_id=datasource_info.get("workspace_id", ""),
                                page_id=datasource_info.get("page", {}).get("page_id", ""),
                                type=datasource_info.get("page", {}).get("type", ""),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    yield from self._transform_message(
                        messages=online_document_result,
                        parameters_for_log=parameters_for_log,
                        datasource_info=datasource_info,
                    )
                case DatasourceProviderType.ONLINE_DRIVE:
                    # Download a file from an online drive provider.
                    datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
                    if credentials:
                        datasource_runtime.runtime.credentials = credentials
                    online_drive_result: Generator[DatasourceMessage, None, None] = (
                        datasource_runtime.online_drive_download_file(
                            user_id=self.user_id,
                            request=OnlineDriveDownloadFileRequest(
                                id=datasource_info.get("id", ""),
                                bucket=datasource_info.get("bucket"),
                            ),
                            provider_type=datasource_type,
                        )
                    )
                    yield from self._transform_datasource_file_message(
                        messages=online_drive_result,
                        parameters_for_log=parameters_for_log,
                        datasource_info=datasource_info,
                        variable_pool=variable_pool,
                        datasource_type=datasource_type,
                    )
                case DatasourceProviderType.WEBSITE_CRAWL:
                    # Crawl results are already in datasource_info; pass them through.
                    yield StreamCompletedEvent(
                        node_run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
                            inputs=parameters_for_log,
                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                            outputs={
                                **datasource_info,
                                "datasource_type": datasource_type,
                            },
                        )
                    )
                case DatasourceProviderType.LOCAL_FILE:
                    # Wrap a previously uploaded file as a File output.
                    related_id = datasource_info.get("related_id")
                    if not related_id:
                        raise DatasourceNodeError("File is not exist")
                    upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
                    if not upload_file:
                        raise ValueError("Invalid upload file Info")
                    file_info = File(
                        id=upload_file.id,
                        filename=upload_file.name,
                        extension="." + upload_file.extension,
                        mime_type=upload_file.mime_type,
                        tenant_id=self.tenant_id,
                        type=FileType.CUSTOM,
                        transfer_method=FileTransferMethod.LOCAL_FILE,
                        remote_url=upload_file.source_url,
                        related_id=upload_file.id,
                        size=upload_file.size,
                        storage_key=upload_file.key,
                        url=upload_file.source_url,
                    )
                    variable_pool.add([self._node_id, "file"], file_info)
                    # variable_pool.add([self.node_id, "file"], file_info.to_dict())
                    yield StreamCompletedEvent(
                        node_run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
                            inputs=parameters_for_log,
                            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                            outputs={
                                "file": file_info,
                                "datasource_type": datasource_type,
                            },
                        )
                    )
                case _:
                    raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
        except PluginDaemonClientSideError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to transform datasource message: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
        except DatasourceNodeError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                    error=f"Failed to invoke datasource: {str(e)}",
                    error_type=type(e).__name__,
                )
            )

    def _generate_parameters(
        self,
        *,
        datasource_parameters: Sequence[DatasourceParameter],
        variable_pool: VariablePool,
        node_data: DatasourceNodeData,
        for_log: bool = False,
    ) -> dict[str, Any]:
        """
        Generate parameter values from the node config and variable pool.

        Args:
            datasource_parameters (Sequence[DatasourceParameter]): The declared datasource parameters.
            variable_pool (VariablePool): The variable pool containing the variables.
            node_data (DatasourceNodeData): The data associated with the datasource node.
            for_log (bool): When True, template values are rendered in their log form.

        Returns:
            dict[str, Any]: The resolved parameter values keyed by parameter name.

        Raises:
            DatasourceParameterError: when a referenced variable is missing or an
                input type is unknown.
        """
        datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
        result: dict[str, Any] = {}
        if node_data.datasource_parameters:
            for parameter_name in node_data.datasource_parameters:
                parameter = datasource_parameters_dictionary.get(parameter_name)
                if not parameter:
                    # Configured parameter that the provider no longer declares.
                    result[parameter_name] = None
                    continue
                datasource_input = node_data.datasource_parameters[parameter_name]
                if datasource_input.type == "variable":
                    variable = variable_pool.get(datasource_input.value)
                    if variable is None:
                        raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
                    parameter_value = variable.value
                elif datasource_input.type in {"mixed", "constant"}:
                    segment_group = variable_pool.convert_template(str(datasource_input.value))
                    parameter_value = segment_group.log if for_log else segment_group.text
                else:
                    raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
                result[parameter_name] = parameter_value
        return result

    def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
        """Return the files attached to the workflow run via ``sys.files``."""
        variable = variable_pool.get(["sys", SystemVariableKey.FILES])
        assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
        return list(variable.value) if variable else []

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config (unused here)
        :param node_id: node id, used to namespace the mapping keys
        :param node_data: raw node data mapping
        :return: mapping of ``<node_id>.<name>`` to value selector
        """
        typed_node_data = DatasourceNodeData.model_validate(node_data)
        result = {}
        if typed_node_data.datasource_parameters:
            for parameter_name in typed_node_data.datasource_parameters:
                input = typed_node_data.datasource_parameters[parameter_name]
                if input.type == "mixed":
                    # Template strings may reference several variables.
                    assert isinstance(input.value, str)
                    selectors = VariableTemplateParser(input.value).extract_variable_selectors()
                    for selector in selectors:
                        result[selector.variable] = selector.value_selector
                elif input.type == "variable":
                    result[parameter_name] = input.value
                elif input.type == "constant":
                    pass
        result = {node_id + "." + key: value for key, value in result.items()}
        return result

    def _transform_message(
        self,
        messages: Generator[DatasourceMessage, None, None],
        parameters_for_log: dict[str, Any],
        datasource_info: dict[str, Any],
    ) -> Generator:
        """
        Stream datasource messages as workflow events.

        Text/link messages are streamed as chunk events and accumulated; file
        messages are resolved to ``File`` objects; the final event reports the
        collected variables as outputs.

        NOTE(review): ``text``, ``files`` and ``json`` are accumulated but only
        ``variables`` is included in the completion outputs — confirm intended.
        """
        # transform message and handle file storage
        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )
        text = ""
        files: list[File] = []
        # Local name, not the stdlib module (not imported in this file).
        json: list[dict] = []
        variables: dict[str, Any] = {}
        for message in message_stream:
            if message.type in {
                DatasourceMessage.MessageType.IMAGE_LINK,
                DatasourceMessage.MessageType.BINARY_LINK,
                DatasourceMessage.MessageType.IMAGE,
            }:
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                url = message.message.text
                transfer_method = FileTransferMethod.TOOL_FILE
                # Tool file id is encoded in the URL's last path segment.
                datasource_file_id = str(url).split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
                    datasource_file = session.scalar(stmt)
                    if datasource_file is None:
                        raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
                mapping = {
                    "tool_file_id": datasource_file_id,
                    "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=self.tenant_id,
                )
                files.append(file)
            elif message.type == DatasourceMessage.MessageType.BLOB:
                # get tool file id
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                assert message.meta
                datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
                    datasource_file = session.scalar(stmt)
                    if datasource_file is None:
                        raise ToolFileError(f"datasource file {datasource_file_id} not exists")
                mapping = {
                    "tool_file_id": datasource_file_id,
                    "transfer_method": FileTransferMethod.TOOL_FILE,
                }
                files.append(
                    file_factory.build_from_mapping(
                        mapping=mapping,
                        tenant_id=self.tenant_id,
                    )
                )
            elif message.type == DatasourceMessage.MessageType.TEXT:
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                text += message.message.text
                yield StreamChunkEvent(
                    selector=[self._node_id, "text"],
                    chunk=message.message.text,
                    is_final=False,
                )
            elif message.type == DatasourceMessage.MessageType.JSON:
                assert isinstance(message.message, DatasourceMessage.JsonMessage)
                json.append(message.message.json_object)
            elif message.type == DatasourceMessage.MessageType.LINK:
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                stream_text = f"Link: {message.message.text}\n"
                text += stream_text
                yield StreamChunkEvent(
                    selector=[self._node_id, "text"],
                    chunk=stream_text,
                    is_final=False,
                )
            elif message.type == DatasourceMessage.MessageType.VARIABLE:
                assert isinstance(message.message, DatasourceMessage.VariableMessage)
                variable_name = message.message.variable_name
                variable_value = message.message.variable_value
                if message.message.stream:
                    # Streamed variables must arrive as string fragments.
                    if not isinstance(variable_value, str):
                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                    if variable_name not in variables:
                        variables[variable_name] = ""
                    variables[variable_name] += variable_value
                    yield StreamChunkEvent(
                        selector=[self._node_id, variable_name],
                        chunk=variable_value,
                        is_final=False,
                    )
                else:
                    variables[variable_name] = variable_value
            elif message.type == DatasourceMessage.MessageType.FILE:
                assert message.meta is not None
                files.append(message.meta["file"])
        # mark the end of the stream
        yield StreamChunkEvent(
            selector=[self._node_id, "text"],
            chunk="",
            is_final=True,
        )
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={**variables},
                metadata={
                    WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
                },
                inputs=parameters_for_log,
            )
        )

    @classmethod
    def version(cls) -> str:
        """Version tag of this node implementation."""
        return "1"

    def _transform_datasource_file_message(
        self,
        messages: Generator[DatasourceMessage, None, None],
        parameters_for_log: dict[str, Any],
        datasource_info: dict[str, Any],
        variable_pool: VariablePool,
        datasource_type: DatasourceProviderType,
    ) -> Generator:
        """
        Resolve a downloaded file (BINARY_LINK messages) into a ``File``,
        publish it to the variable pool as ``<node_id>.file``, and emit the
        completion event. Only the last BINARY_LINK message is kept.
        """
        # transform message and handle file storage
        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )
        file = None
        for message in message_stream:
            if message.type == DatasourceMessage.MessageType.BINARY_LINK:
                assert isinstance(message.message, DatasourceMessage.TextMessage)
                url = message.message.text
                transfer_method = FileTransferMethod.TOOL_FILE
                # Tool file id is encoded in the URL's last path segment.
                datasource_file_id = str(url).split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
                    datasource_file = session.scalar(stmt)
                    if datasource_file is None:
                        raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
                mapping = {
                    "tool_file_id": datasource_file_id,
                    "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=self.tenant_id,
                )
        if file:
            variable_pool.add([self._node_id, "file"], file)
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=parameters_for_log,
                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                outputs={
                    "file": file,
                    "datasource_type": datasource_type,
                },
            )
        )

View File

@@ -0,0 +1,41 @@
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.workflow.nodes.base.entities import BaseNodeData
class DatasourceEntity(BaseModel):
    """Identity of the datasource plugin/provider selected by a node."""

    plugin_id: str
    provider_name: str  # redundancy
    provider_type: str
    # Defaults to the built-in local-file datasource when not configured.
    datasource_name: str | None = "local_file"
    datasource_configurations: dict[str, Any] | None = None
    plugin_unique_identifier: str | None = None  # redundancy
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
    """Datasource node configuration: provider identity plus per-parameter inputs."""

    class DatasourceInput(BaseModel):
        # TODO: check this type
        # "variable" inputs carry a selector (list[str]); "mixed"/"constant"
        # inputs carry a template string or a literal.
        value: Union[Any, list[str]]
        type: Literal["mixed", "variable", "constant"] | None = None

        @field_validator("type", mode="before")
        @classmethod
        def check_type(cls, value, validation_info: ValidationInfo):
            # `value` (the argument) is the *type* field being validated; the
            # input's own value is read back from the already-validated data —
            # this relies on `value` being declared before `type` in the model.
            typ = value
            value = validation_info.data.get("value")
            if typ == "mixed" and not isinstance(value, str):
                raise ValueError("value must be a string")
            elif typ == "variable":
                if not isinstance(value, list):
                    raise ValueError("value must be a list")
                for val in value:
                    if not isinstance(val, str):
                        raise ValueError("value must be a list of strings")
            elif typ == "constant" and not isinstance(value, str | int | float | bool):
                raise ValueError("value must be a string, int, float, or bool")
            return typ

    datasource_parameters: dict[str, DatasourceInput] | None = None

View File

@@ -0,0 +1,16 @@
class DatasourceNodeError(ValueError):
    """Base exception for datasource node errors."""


class DatasourceParameterError(DatasourceNodeError):
    """Exception raised for errors in datasource parameters."""


class DatasourceFileError(DatasourceNodeError):
    """Exception raised for errors related to datasource files."""

View File

@@ -0,0 +1,4 @@
from .entities import DocumentExtractorNodeData
from .node import DocumentExtractorNode

# Public API of the document extractor node package.
__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"]

View File

@@ -0,0 +1,7 @@
from collections.abc import Sequence
from core.workflow.nodes.base import BaseNodeData
class DocumentExtractorNodeData(BaseNodeData):
    """Configuration for DocumentExtractorNode."""

    # Selector path of the file (or file-array) variable whose text is extracted.
    variable_selector: Sequence[str]

View File

@@ -0,0 +1,14 @@
# Exception hierarchy for DocumentExtractorNode. The base derives from ValueError,
# so `except ValueError` also catches every extractor-specific error below.
class DocumentExtractorError(ValueError):
    """Base exception for errors related to the DocumentExtractorNode."""


class FileDownloadError(DocumentExtractorError):
    """Exception raised when there's an error downloading a file."""


class UnsupportedFileTypeError(DocumentExtractorError):
    """Exception raised when trying to extract text from an unsupported file type."""


class TextExtractionError(DocumentExtractorError):
    """Exception raised when there's an error during text extraction from a file."""

View File

@@ -0,0 +1,693 @@
import csv
import io
import json
import logging
import os
import tempfile
from collections.abc import Mapping, Sequence
from typing import Any
import chardet
import docx
import pandas as pd
import pypandoc
import pypdfium2
import webvtt
import yaml
from docx.document import Document
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.table import Table
from docx.text.paragraph import Paragraph
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
logger = logging.getLogger(__name__)
class DocumentExtractorNode(Node):
    """
    Extracts text content from various file types.
    Supports plain text, PDF, and DOC/DOCX files.
    """

    node_type = NodeType.DOCUMENT_EXTRACTOR
    _node_data: DocumentExtractorNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw node config dict into typed node data."""
        self._node_data = DocumentExtractorNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self):
        """Resolve the configured file variable and extract its text.

        Accepts a single FileSegment (output: one string) or an ArrayFileSegment
        (output: array of strings); any DocumentExtractorError fails the node.
        """
        variable_selector = self._node_data.variable_selector
        variable = self.graph_runtime_state.variable_pool.get(variable_selector)
        if variable is None:
            error_message = f"File variable not found for selector: {variable_selector}"
            return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message)
        # NOTE(review): the type check is skipped for falsy values (e.g. empty arrays),
        # and the message mentions only ArrayFileSegment although FileSegment passes too.
        if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment):
            error_message = f"Variable {variable_selector} is not an ArrayFileSegment"
            return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message)
        value = variable.value
        inputs = {"variable_selector": variable_selector}
        # Normalize to a list so downstream consumers always see a document array.
        process_data = {"documents": value if isinstance(value, list) else [value]}
        try:
            if isinstance(value, list):
                extracted_text_list = list(map(_extract_text_from_file, value))
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=inputs,
                    process_data=process_data,
                    outputs={"text": ArrayStringSegment(value=extracted_text_list)},
                )
            elif isinstance(value, File):
                extracted_text = _extract_text_from_file(value)
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=inputs,
                    process_data=process_data,
                    outputs={"text": extracted_text},
                )
            else:
                raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
        except DocumentExtractorError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=inputs,
                process_data=process_data,
            )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Map this node's single file input ("<node_id>.files") to its selector."""
        # Create typed NodeData from dict
        typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
        return {node_id + ".files": typed_node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
    """Extract text from a file based on its MIME type.

    Raises UnsupportedFileTypeError when no extractor is registered for the type.
    """
    # Dispatch table: MIME type -> extractor. Several aliases share one handler.
    extractors = {
        "text/plain": _extract_text_from_plain_text,
        "text/html": _extract_text_from_plain_text,
        "text/htm": _extract_text_from_plain_text,
        "text/markdown": _extract_text_from_plain_text,
        "text/xml": _extract_text_from_plain_text,
        "application/pdf": _extract_text_from_pdf,
        "application/msword": _extract_text_from_doc,
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": _extract_text_from_docx,
        "text/csv": _extract_text_from_csv,
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": _extract_text_from_excel,
        "application/vnd.ms-excel": _extract_text_from_excel,
        "application/vnd.ms-powerpoint": _extract_text_from_ppt,
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": _extract_text_from_pptx,
        "application/epub+zip": _extract_text_from_epub,
        "message/rfc822": _extract_text_from_eml,
        "application/vnd.ms-outlook": _extract_text_from_msg,
        "application/json": _extract_text_from_json,
        "application/x-yaml": _extract_text_from_yaml,
        "text/yaml": _extract_text_from_yaml,
        "text/vtt": _extract_text_from_vtt,
        "text/properties": _extract_text_from_properties,
    }
    extractor = extractors.get(mime_type)
    if extractor is None:
        raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
    return extractor(file_content)
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
    """Extract text from a file based on its file extension.

    Raises UnsupportedFileTypeError when the extension has no registered extractor.
    """
    # Every extension treated as plain text (source code, markup, configs, logs).
    plain_text_extensions = {
        ".txt", ".markdown", ".md", ".mdx", ".html", ".htm", ".xml",
        ".c", ".h", ".cpp", ".hpp", ".cc", ".cxx", ".c++",
        ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".php", ".rb", ".go", ".rs",
        ".swift", ".kt", ".scala", ".sh", ".bash", ".bat", ".ps1", ".sql",
        ".r", ".m", ".pl", ".lua", ".vim", ".asm", ".s",
        ".css", ".scss", ".less", ".sass",
        ".ini", ".cfg", ".conf", ".toml", ".env", ".log", ".vtt",
    }
    if file_extension in plain_text_extensions:
        return _extract_text_from_plain_text(file_content)
    # Structured formats each get a dedicated extractor.
    dispatch = {
        ".json": _extract_text_from_json,
        ".yaml": _extract_text_from_yaml,
        ".yml": _extract_text_from_yaml,
        ".pdf": _extract_text_from_pdf,
        ".doc": _extract_text_from_doc,
        ".docx": _extract_text_from_docx,
        ".csv": _extract_text_from_csv,
        ".xls": _extract_text_from_excel,
        ".xlsx": _extract_text_from_excel,
        ".ppt": _extract_text_from_ppt,
        ".pptx": _extract_text_from_pptx,
        ".epub": _extract_text_from_epub,
        ".eml": _extract_text_from_eml,
        ".msg": _extract_text_from_msg,
        ".properties": _extract_text_from_properties,
    }
    extractor = dispatch.get(file_extension)
    if extractor is None:
        raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}")
    return extractor(file_content)
def _extract_text_from_plain_text(file_content: bytes) -> str:
    """Decode raw bytes to text using a chardet-detected codec, falling back to utf-8."""
    try:
        # chardet may report None for the encoding; substitute utf-8 in that case.
        codec = chardet.detect(file_content)["encoding"] or "utf-8"
        return file_content.decode(codec, errors="ignore")
    except (UnicodeDecodeError, LookupError) as e:
        # Unknown codec name (LookupError) or decode failure: retry with utf-8.
        try:
            return file_content.decode("utf-8", errors="ignore")
        except UnicodeDecodeError:
            raise TextExtractionError(f"Failed to decode plain text file: {e}") from e
def _extract_text_from_json(file_content: bytes) -> str:
    """Parse JSON bytes and return them pretty-printed (2-space indent, unicode kept)."""
    try:
        # chardet may report None for the encoding; substitute utf-8 in that case.
        codec = chardet.detect(file_content)["encoding"] or "utf-8"
        parsed = json.loads(file_content.decode(codec, errors="ignore"))
        return json.dumps(parsed, indent=2, ensure_ascii=False)
    except (UnicodeDecodeError, LookupError, json.JSONDecodeError) as e:
        # Retry once with utf-8 before giving up.
        try:
            parsed = json.loads(file_content.decode("utf-8", errors="ignore"))
            return json.dumps(parsed, indent=2, ensure_ascii=False)
        except (UnicodeDecodeError, json.JSONDecodeError):
            raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e
def _extract_text_from_yaml(file_content: bytes) -> str:
    """Parse YAML bytes (multi-document aware) and re-serialize them canonically."""
    try:
        # chardet may report None for the encoding; substitute utf-8 in that case.
        codec = chardet.detect(file_content)["encoding"] or "utf-8"
        documents = yaml.safe_load_all(file_content.decode(codec, errors="ignore"))
        return yaml.dump_all(documents, allow_unicode=True, sort_keys=False)
    except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
        # Retry once with utf-8 before giving up.
        try:
            documents = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
            return yaml.dump_all(documents, allow_unicode=True, sort_keys=False)
        except (UnicodeDecodeError, yaml.YAMLError):
            raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
def _extract_text_from_pdf(file_content: bytes) -> str:
    """Concatenate the text of every page of a PDF using pypdfium2."""
    try:
        document = pypdfium2.PdfDocument(io.BytesIO(file_content), autoclose=True)
        pieces = []
        for page in document:
            textpage = page.get_textpage()
            pieces.append(textpage.get_text_range())
            # Release native resources eagerly rather than waiting for GC.
            textpage.close()
            page.close()
        return "".join(pieces)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e
def _extract_text_from_doc(file_content: bytes) -> str:
    """
    Extract text from a DOC file via the Unstructured partition API.

    Requires UNSTRUCTURED_API_URL to be configured. Raises TextExtractionError when
    the URL is missing or the remote partitioning fails.
    """
    from unstructured.partition.api import partition_via_api

    if not dify_config.UNSTRUCTURED_API_URL:
        raise TextExtractionError("UNSTRUCTURED_API_URL must be set")
    try:
        # delete=False so the file can be reopened by name (required on Windows).
        with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
            temp_file.write(file_content)
            temp_file.flush()
        try:
            with open(temp_file.name, "rb") as file:
                elements = partition_via_api(
                    file=file,
                    metadata_filename=temp_file.name,
                    api_url=dify_config.UNSTRUCTURED_API_URL,
                    api_key=dify_config.UNSTRUCTURED_API_KEY,  # type: ignore
                )
        finally:
            # Always remove the temp file; previously it leaked when the API call raised.
            os.unlink(temp_file.name)
        return "\n".join([getattr(element, "text", "") for element in elements])
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e
def parser_docx_part(block, doc: Document, content_items, i):
    """Record a raw docx body element as (index, kind, wrapper) in *content_items*.

    Only paragraphs and tables are captured; any other element kind is ignored.
    """
    if isinstance(block, CT_Tbl):
        content_items.append((i, "table", Table(block, doc)))
    elif isinstance(block, CT_P):
        content_items.append((i, "paragraph", Paragraph(block, doc)))
def _extract_text_from_docx(file_content: bytes) -> str:
    """
    Extract text from a DOCX file.

    Paragraphs are appended as plain text; tables are rendered as Markdown tables
    (first row treated as the header). Other element kinds are skipped.
    """
    try:
        doc_file = io.BytesIO(file_content)
        doc = docx.Document(doc_file)
        text = []
        # Keep track of paragraph and table positions
        content_items: list[tuple[int, str, Table | Paragraph]] = []
        it = iter(doc.element.body)
        part = next(it, None)
        i = 0
        # Walk the raw document body in order; parser_docx_part records only
        # paragraph (CT_P) and table (CT_Tbl) elements.
        while part is not None:
            parser_docx_part(part, doc, content_items, i)
            i = i + 1
            part = next(it, None)
        # Process sorted content
        for _, item_type, item in content_items:
            if item_type == "paragraph":
                # isinstance guard narrows the Table | Paragraph union for the checker.
                if isinstance(item, Table):
                    continue
                text.append(item.text)
            elif item_type == "table":
                # Process tables
                if not isinstance(item, Table):
                    continue
                try:
                    # Check if any cell in the table has text
                    has_content = False
                    for row in item.rows:
                        if any(cell.text.strip() for cell in row.cells):
                            has_content = True
                            break
                    if has_content:
                        # First table row becomes the Markdown header row.
                        cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
                        markdown_table = f"| {' | '.join(cell_texts)} |\n"
                        markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
                        for row in item.rows[1:]:
                            # Replace newlines with <br> in each cell
                            row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
                            markdown_table += "| " + " | ".join(row_cells) + " |\n"
                        text.append(markdown_table)
                except Exception as e:
                    # A malformed table should not abort extraction of the rest.
                    logger.warning("Failed to extract table from DOC: %s", e)
                    continue
        return "\n".join(text)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e
def _download_file_content(file: File) -> bytes:
    """Fetch the raw bytes of *file*, honoring its transfer method.

    Remote URLs go through the SSRF-protected HTTP client; everything else is
    resolved via the file manager. Any failure is wrapped in FileDownloadError.
    """
    try:
        if file.transfer_method != FileTransferMethod.REMOTE_URL:
            return file_manager.download(file)
        if file.remote_url is None:
            raise FileDownloadError("Missing URL for remote file")
        response = ssrf_proxy.get(file.remote_url)
        response.raise_for_status()
        return response.content
    except Exception as e:
        raise FileDownloadError(f"Error downloading file: {str(e)}") from e
def _extract_text_from_file(file: File):
    """Download *file* and extract its text, preferring extension over MIME type."""
    content = _download_file_content(file)
    if file.extension:
        return _extract_text_by_file_extension(file_content=content, file_extension=file.extension)
    if file.mime_type:
        return _extract_text_by_mime_type(file_content=content, mime_type=file.mime_type)
    raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing")
def _extract_text_from_csv(file_content: bytes) -> str:
    """Render CSV content as a Markdown table; the first row becomes the header."""
    try:
        # chardet may report None for the encoding; substitute utf-8 in that case.
        codec = chardet.detect(file_content)["encoding"] or "utf-8"
        try:
            decoded = file_content.decode(codec, errors="ignore")
        except (UnicodeDecodeError, LookupError):
            # If decoding fails, try with utf-8 as last resort
            decoded = file_content.decode("utf-8", errors="ignore")
        rows = list(csv.reader(io.StringIO(decoded)))
        if not rows:
            return ""

        def _single_line(cells):
            # Collapse embedded line breaks so each cell stays on one table row.
            return [cell.replace("\n", " ").replace("\r", "") for cell in cells]

        lines = ["| " + " | ".join(_single_line(rows[0])) + " |"]
        lines.append("| " + " | ".join("-" * len(col) for col in rows[0]) + " |")
        for row in rows[1:]:
            lines.append("| " + " | ".join(_single_line(row)) + " |")
        # Trailing newline matches the row-by-row concatenation of the original format.
        return "\n".join(lines) + "\n"
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e
def _extract_text_from_excel(file_content: bytes) -> str:
    """Extract text from an Excel file using pandas.

    Each sheet is rendered as a Markdown table; sheets that fail to parse are
    skipped silently so the remaining sheets are still extracted.
    """

    def _construct_markdown_table(df: pd.DataFrame) -> str:
        """Manually construct a Markdown table from a DataFrame."""
        # Construct the header row
        header_row = "| " + " | ".join(df.columns) + " |"
        # Construct the separator row
        separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |"
        # Construct the data rows
        data_rows = []
        for _, row in df.iterrows():
            data_row = "| " + " | ".join(map(str, row)) + " |"
            data_rows.append(data_row)
        # Combine all rows into a single string
        markdown_table = "\n".join([header_row, separator_row] + data_rows)
        return markdown_table

    try:
        excel_file = pd.ExcelFile(io.BytesIO(file_content))
        markdown_table = ""
        for sheet_name in excel_file.sheet_names:
            try:
                df = excel_file.parse(sheet_name=sheet_name)
                # Drop rows that are entirely empty before rendering.
                df.dropna(how="all", inplace=True)
                # Combine multi-line text in each cell into a single line
                # (DataFrame.map requires pandas >= 2.1; older versions used applymap)
                df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x)
                # Combine multi-line text in column names into a single line
                df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
                # Manually construct the Markdown table
                markdown_table += _construct_markdown_table(df) + "\n\n"
            except Exception:
                # Best-effort per sheet: skip unreadable sheets, keep the rest.
                continue
        return markdown_table
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
def _extract_text_from_ppt(file_content: bytes) -> str:
    """
    Extract text from a legacy PPT file.

    Uses the Unstructured partition API when UNSTRUCTURED_API_URL is configured,
    otherwise partitions locally. Raises TextExtractionError on any failure.
    """
    from unstructured.partition.api import partition_via_api
    from unstructured.partition.ppt import partition_ppt

    try:
        if dify_config.UNSTRUCTURED_API_URL:
            # delete=False so the file can be reopened by name (required on Windows).
            with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file:
                temp_file.write(file_content)
                temp_file.flush()
            try:
                with open(temp_file.name, "rb") as file:
                    elements = partition_via_api(
                        file=file,
                        metadata_filename=temp_file.name,
                        api_url=dify_config.UNSTRUCTURED_API_URL,
                        api_key=dify_config.UNSTRUCTURED_API_KEY,  # type: ignore
                    )
            finally:
                # Always clean up; previously the temp file leaked when the API call raised.
                os.unlink(temp_file.name)
        else:
            with io.BytesIO(file_content) as file:
                elements = partition_ppt(file=file)
        return "\n".join([getattr(element, "text", "") for element in elements])
    except Exception as e:
        # Fixed: the message previously said "PPTX" although this is the .ppt path.
        raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e
def _extract_text_from_pptx(file_content: bytes) -> str:
    """
    Extract text from a PPTX file.

    Uses the Unstructured partition API when UNSTRUCTURED_API_URL is configured,
    otherwise partitions locally. Raises TextExtractionError on any failure.
    """
    from unstructured.partition.api import partition_via_api
    from unstructured.partition.pptx import partition_pptx

    try:
        if dify_config.UNSTRUCTURED_API_URL:
            # delete=False so the file can be reopened by name (required on Windows).
            with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
                temp_file.write(file_content)
                temp_file.flush()
            try:
                with open(temp_file.name, "rb") as file:
                    elements = partition_via_api(
                        file=file,
                        metadata_filename=temp_file.name,
                        api_url=dify_config.UNSTRUCTURED_API_URL,
                        api_key=dify_config.UNSTRUCTURED_API_KEY,  # type: ignore
                    )
            finally:
                # Always clean up; previously the temp file leaked when the API call raised.
                os.unlink(temp_file.name)
        else:
            with io.BytesIO(file_content) as file:
                elements = partition_pptx(file=file)
        return "\n".join([getattr(element, "text", "") for element in elements])
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
def _extract_text_from_epub(file_content: bytes) -> str:
    """
    Extract text from an EPUB file.

    Uses the Unstructured partition API when UNSTRUCTURED_API_URL is configured,
    otherwise partitions locally (downloading pandoc on first use).
    Raises TextExtractionError on any failure.
    """
    from unstructured.partition.api import partition_via_api
    from unstructured.partition.epub import partition_epub

    try:
        if dify_config.UNSTRUCTURED_API_URL:
            # delete=False so the file can be reopened by name (required on Windows).
            with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file:
                temp_file.write(file_content)
                temp_file.flush()
            try:
                with open(temp_file.name, "rb") as file:
                    elements = partition_via_api(
                        file=file,
                        metadata_filename=temp_file.name,
                        api_url=dify_config.UNSTRUCTURED_API_URL,
                        api_key=dify_config.UNSTRUCTURED_API_KEY,  # type: ignore
                    )
            finally:
                # Always clean up; previously the temp file leaked when the API call raised.
                os.unlink(temp_file.name)
        else:
            # Local partitioning needs pandoc; this may hit the network on first use.
            pypandoc.download_pandoc()
            with io.BytesIO(file_content) as file:
                elements = partition_epub(file=file)
        return "\n".join([str(element) for element in elements])
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e
def _extract_text_from_eml(file_content: bytes) -> str:
    """Extract the textual elements of an RFC 822 e-mail (.eml) message."""
    from unstructured.partition.email import partition_email

    try:
        with io.BytesIO(file_content) as buffer:
            parts = partition_email(file=buffer)
        return "\n".join(str(part) for part in parts)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e
def _extract_text_from_msg(file_content: bytes) -> str:
    """Extract the textual elements of an Outlook (.msg) message."""
    from unstructured.partition.msg import partition_msg

    try:
        with io.BytesIO(file_content) as buffer:
            parts = partition_msg(file=buffer)
        return "\n".join(str(part) for part in parts)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e
def _extract_text_from_vtt(vtt_bytes: bytes) -> str:
    """Convert WebVTT captions to 'Speaker "text"' lines, merging consecutive
    utterances by the same speaker."""
    text = _extract_text_from_plain_text(vtt_bytes)
    # remove bom
    text = text.lstrip("\ufeff")
    raw_results = []
    for caption in webvtt.from_string(text):
        raw_results.append((caption.voice, caption.text))
    # Merge consecutive utterances by the same speaker
    merged_results = []
    if raw_results:
        current_speaker, current_text = raw_results[0]
        for i in range(1, len(raw_results)):
            spk, txt = raw_results[i]
            if spk is None:
                # NOTE(review): this emits the *accumulated* current_text (not txt) with a
                # None speaker and leaves the accumulator unchanged, so current_text is
                # emitted again later. Looks like it should append txt — confirm intent.
                merged_results.append((None, current_text))
                continue
            if spk == current_speaker:
                # If it is the same speaker, merge the utterances (joined by space)
                current_text += " " + txt
            else:
                # If the speaker changes, register the utterance so far and move on
                merged_results.append((current_speaker, current_text))
                current_speaker, current_text = spk, txt
        # Add the last element
        merged_results.append((current_speaker, current_text))
    else:
        merged_results = raw_results
    # Return the result in the specified format: Speaker "text" style
    formatted = [f'{spk or ""} "{txt}"' for spk, txt in merged_results]
    return "\n".join(formatted)
def _extract_text_from_properties(file_content: bytes) -> str:
    """Render a Java .properties file as "key: value" lines.

    Comments ('#' or '!') and blank lines are preserved; '=' takes precedence
    over ':' as the key/value separator, and lines with neither become bare keys.
    """
    try:
        output = []
        for raw_line in _extract_text_from_plain_text(file_content).splitlines():
            stripped = raw_line.strip()
            # Preserve comments and empty lines verbatim (stripped).
            if not stripped or stripped[0] in "#!":
                output.append(stripped)
                continue
            for sep in ("=", ":"):
                if sep in stripped:
                    key, _, value = stripped.partition(sep)
                    break
            else:
                key, value = stripped, ""
            output.append(f"{key.strip()}: {value.strip()}")
        return "\n".join(output)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e

View File

@@ -0,0 +1,74 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.entities import EndNodeData
class EndNode(Node):
    """Terminal workflow node: collects its configured output variables from the pool."""

    node_type = NodeType.END
    # RESPONSE execution type: this node participates in response streaming.
    execution_type = NodeExecutionType.RESPONSE
    _node_data: EndNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw node config dict into typed node data."""
        self._node_data = EndNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        """
        Run node - collect all outputs at once.
        This method runs after streaming is complete (if streaming was enabled).
        It collects all output variables and returns them.
        """
        output_variables = self._node_data.outputs
        outputs = {}
        for variable_selector in output_variables:
            variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
            # Missing variables resolve to None rather than failing the node.
            value = variable.to_object() if variable is not None else None
            outputs[variable_selector.variable] = value
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            # NOTE(review): inputs deliberately mirrors outputs here — presumably for
            # logging/traceability; confirm against NodeRunResult consumers.
            inputs=outputs,
            outputs=outputs,
        )

    def get_streaming_template(self) -> Template:
        """
        Get the template for streaming.
        Returns:
            Template instance for this End node
        """
        outputs_config = [
            {"variable": output.variable, "value_selector": output.value_selector} for output in self._node_data.outputs
        ]
        return Template.from_end_outputs(outputs_config)

View File

@@ -0,0 +1,25 @@
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
class EndNodeData(BaseNodeData):
    """
    END Node Data.

    `outputs` lists the variables (name + value selector) that the workflow
    exposes when execution reaches this node.
    """

    outputs: list[VariableSelector]


class EndStreamParam(BaseModel):
    """
    EndStreamParam entity

    Precomputed streaming metadata keyed by end-node id.
    """

    end_dependencies: dict[str, list[str]] = Field(
        ..., description="end dependencies (end node id -> dependent node ids)"
    )
    end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
        ..., description="end stream variable selector mapping (end node id -> stream variable selectors)"
    )

View File

@@ -0,0 +1,4 @@
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
from .node import HttpRequestNode
__all__ = ["BodyData", "HttpRequestNode", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "HttpRequestNodeData"]

View File

@@ -0,0 +1,192 @@
import mimetypes
from collections.abc import Sequence
from email.message import Message
from typing import Any, Literal
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from configs import dify_config
from core.workflow.nodes.base import BaseNodeData
class HttpRequestNodeAuthorizationConfig(BaseModel):
    """API-key authorization details; `header` names the header for `custom` type."""

    type: Literal["basic", "bearer", "custom"]
    api_key: str
    header: str = ""


class HttpRequestNodeAuthorization(BaseModel):
    """Authorization settings: either no auth, or an api-key config."""

    type: Literal["no-auth", "api-key"]
    config: HttpRequestNodeAuthorizationConfig | None = None

    @field_validator("config", mode="before")
    @classmethod
    def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo):
        """
        Check config, if type is no-auth, config should be None, otherwise it should be a dict.
        """
        if values.data["type"] == "no-auth":
            return None
        else:
            if not v or not isinstance(v, dict):
                raise ValueError("config should be a dict")
            return v


class BodyData(BaseModel):
    """One entry of a request body (a text field or a file reference)."""

    key: str = ""
    type: Literal["file", "text"]
    value: str = ""
    # File selector path when type == "file".
    file: Sequence[str] = Field(default_factory=list)


class HttpRequestNodeBody(BaseModel):
    """Request body: a content type plus its data entries."""

    type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"]
    data: Sequence[BodyData] = Field(default_factory=list)

    @field_validator("data", mode="before")
    @classmethod
    def check_data(cls, v: Any):
        """For compatibility, if body is not set, return empty list."""
        if not v:
            return []
        if isinstance(v, str):
            # Legacy format stored the body as a bare string; wrap it as one text entry.
            return [BodyData(key="", type="text", value=v)]
        return v


class HttpRequestNodeTimeout(BaseModel):
    """Per-request timeouts in seconds; defaults come from global config."""

    connect: int = dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT
    read: int = dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT
    write: int = dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT
class HttpRequestNodeData(BaseNodeData):
    """
    HTTP Request Node Data.

    (The original docstring said "Code Node Data", which was a copy-paste slip.)
    """

    # NOTE(review): both lower- and upper-case methods are accepted — presumably for
    # backward compatibility with older exported configs; confirm before narrowing.
    method: Literal[
        "get",
        "post",
        "put",
        "patch",
        "delete",
        "head",
        "options",
        "GET",
        "POST",
        "PUT",
        "PATCH",
        "DELETE",
        "HEAD",
        "OPTIONS",
    ]
    url: str
    authorization: HttpRequestNodeAuthorization
    # Raw multi-line "key: value" text, parsed at execution time.
    headers: str
    params: str
    body: HttpRequestNodeBody | None = None
    timeout: HttpRequestNodeTimeout | None = None
    ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
class Response:
    """Thin wrapper over httpx.Response adding file-detection and size helpers."""

    headers: dict[str, str]
    response: httpx.Response

    def __init__(self, response: httpx.Response):
        self.response = response
        self.headers = dict(response.headers)

    @property
    def is_file(self) -> bool:
        """
        Determine if the response contains a file by checking:
        1. Content-Disposition header (RFC 6266)
        2. Content characteristics
        3. MIME type analysis
        """
        content_type = self.content_type.split(";")[0].strip().lower()
        parsed_content_disposition = self.parsed_content_disposition
        # Check if it's explicitly marked as an attachment
        if parsed_content_disposition:
            disp_type = parsed_content_disposition.get_content_disposition()  # Returns 'attachment', 'inline', or None
            filename = parsed_content_disposition.get_filename()  # Returns filename if present, None otherwise
            if disp_type == "attachment" or filename is not None:
                return True
        # For 'text/' types, only 'csv' should be downloaded as file
        if content_type.startswith("text/") and "csv" not in content_type:
            return False
        # For application types, try to detect if it's a text-based format
        if content_type.startswith("application/"):
            # Common text-based application types
            if any(
                text_type in content_type
                for text_type in ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql")
            ):
                return False
            # Try to detect if content is text-based by sampling first few bytes
            try:
                # Sample first 1024 bytes for text detection
                content_sample = self.response.content[:1024]
                content_sample.decode("utf-8")
                # If we can decode as UTF-8 and find common text patterns, likely not a file
                text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
                if any(marker in content_sample for marker in text_markers):
                    return False
            except UnicodeDecodeError:
                # If we can't decode as UTF-8, likely a binary file
                return True
        # For other types, use MIME type analysis
        main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
        if main_type:
            return main_type.split("/")[0] in ("application", "image", "audio", "video")
        # For unknown types, check if it's a media type
        return any(media_type in content_type for media_type in ("image/", "audio/", "video/"))

    @property
    def content_type(self) -> str:
        # Empty string when the header is absent.
        return self.headers.get("content-type", "")

    @property
    def text(self) -> str:
        return self.response.text

    @property
    def content(self) -> bytes:
        return self.response.content

    @property
    def status_code(self) -> int:
        return self.response.status_code

    @property
    def size(self) -> int:
        return len(self.content)

    @property
    def readable_size(self) -> str:
        # Human-readable size: bytes, KB, or MB with two decimals.
        if self.size < 1024:
            return f"{self.size} bytes"
        elif self.size < 1024 * 1024:
            return f"{(self.size / 1024):.2f} KB"
        else:
            return f"{(self.size / 1024 / 1024):.2f} MB"

    @property
    def parsed_content_disposition(self) -> Message | None:
        # Parse via email.message.Message to reuse stdlib RFC 6266-style parsing.
        content_disposition = self.headers.get("content-disposition", "")
        if content_disposition:
            msg = Message()
            msg["content-disposition"] = content_disposition
            return msg
        return None

View File

@@ -0,0 +1,26 @@
# Exception hierarchy for the HTTP request node. The base derives from ValueError,
# so `except ValueError` also catches every node-specific error below.
class HttpRequestNodeError(ValueError):
    """Custom error for HTTP request node."""


class AuthorizationConfigError(HttpRequestNodeError):
    """Raised when authorization config is missing or invalid."""


class FileFetchError(HttpRequestNodeError):
    """Raised when a file cannot be fetched."""


class InvalidHttpMethodError(HttpRequestNodeError):
    """Raised when an invalid HTTP method is used."""


class ResponseSizeError(HttpRequestNodeError):
    """Raised when the response size exceeds the allowed threshold."""


class RequestBodyError(HttpRequestNodeError):
    """Raised when the request body is invalid."""


class InvalidURLError(HttpRequestNodeError):
    """Raised when the URL is invalid."""

View File

@@ -0,0 +1,463 @@
import base64
import json
import secrets
import string
from collections.abc import Mapping
from copy import deepcopy
from typing import Any, Literal
from urllib.parse import urlencode, urlparse
import httpx
from json_repair import repair_json
from configs import dify_config
from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
from .entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
HttpRequestNodeTimeout,
Response,
)
from .exc import (
AuthorizationConfigError,
FileFetchError,
HttpRequestNodeError,
InvalidHttpMethodError,
InvalidURLError,
RequestBodyError,
ResponseSizeError,
)
# Default Content-Type header value for each HttpRequestNodeBody.type that
# carries one; body types not listed here (e.g. "none", "binary", "form-data"
# with boundary handling) are handled separately.
BODY_TYPE_TO_CONTENT_TYPE = {
    "json": "application/json",
    "x-www-form-urlencoded": "application/x-www-form-urlencoded",
    "form-data": "multipart/form-data",
    "raw-text": "text/plain",
}
class Executor:
    """Assembles and executes a single HTTP request described by an HttpRequestNodeData.

    On construction it resolves workflow variable templates in the URL, query
    params, headers and body, then `invoke()` applies the configured
    authorization, sends the request through the SSRF-safe proxy and enforces
    a response-size limit. `to_log()` renders a redacted raw-HTTP transcript
    for process logging.
    """

    method: Literal[
        "get",
        "head",
        "post",
        "put",
        "delete",
        "patch",
        "options",
        "GET",
        "POST",
        "PUT",
        "PATCH",
        "DELETE",
        "HEAD",
        "OPTIONS",
    ]
    url: str
    # Query params kept as (key, value) tuples so repeated keys survive ('aa=1&aa=2').
    params: list[tuple[str, str]] | None
    content: str | bytes | None
    data: Mapping[str, Any] | None
    # httpx-style files payload: (field_name, (filename, content_bytes, mime_type)).
    files: list[tuple[str, tuple[str | None, bytes, str]]] | None
    json: Any
    headers: dict[str, str]
    auth: HttpRequestNodeAuthorization
    timeout: HttpRequestNodeTimeout
    max_retries: int
    # NOTE(review): annotated but never assigned on the instance; to_log()
    # builds its own local boundary string instead.
    boundary: str

    def __init__(
        self,
        *,
        node_data: HttpRequestNodeData,
        timeout: HttpRequestNodeTimeout,
        variable_pool: VariablePool,
        max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
    ):
        """Capture node config and immediately resolve all templates via _initialize()."""
        # If authorization API key is present, convert the API key using the variable pool
        if node_data.authorization.type == "api-key":
            if node_data.authorization.config is None:
                raise AuthorizationConfigError("authorization config is required")
            node_data.authorization.config.api_key = variable_pool.convert_template(
                node_data.authorization.config.api_key
            ).text
        self.url = node_data.url
        self.method = node_data.method
        self.auth = node_data.authorization
        self.timeout = timeout
        self.ssl_verify = node_data.ssl_verify
        self.params = None
        self.headers = {}
        self.content = None
        self.files = None
        self.data = None
        self.json = None
        self.max_retries = max_retries
        # init template
        self.variable_pool = variable_pool
        self.node_data = node_data
        self._initialize()

    def _initialize(self):
        """Resolve URL, params, headers and body templates, in that order."""
        self._init_url()
        self._init_params()
        self._init_headers()
        self._init_body()

    def _init_url(self):
        """Render the URL template and validate its scheme."""
        self.url = self.variable_pool.convert_template(self.node_data.url).text
        # check if url is a valid URL
        if not self.url:
            raise InvalidURLError("url is required")
        if not self.url.startswith(("http://", "https://")):
            raise InvalidURLError("url should start with http:// or https://")

    def _init_params(self):
        """
        Almost same as _init_headers(), difference:
        1. response a list tuple to support same key, like 'aa=1&aa=2'
        2. param value may have '\n', we need to splitlines then extract the variable value.
        """
        result = []
        for line in self.node_data.params.splitlines():
            if not (line := line.strip()):
                continue
            key, *value = line.split(":", 1)
            if not (key := key.strip()):
                continue
            value_str = value[0].strip() if value else ""
            result.append(
                (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
            )
        if result:
            self.params = result

    def _init_headers(self):
        """
        Convert the header string of frontend to a dictionary.

        Each line in the header string represents a key-value pair.
        Keys and values are separated by ':'.
        Empty values are allowed.

        Examples:
            'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'}
            'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'}
            'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'}
        """
        headers = self.variable_pool.convert_template(self.node_data.headers).text
        self.headers = {
            key.strip(): (value[0].strip() if value else "")
            for line in headers.splitlines()
            if line.strip()
            for key, *value in [line.split(":", 1)]
        }

    def _init_body(self):
        """Populate exactly one of content/json/data/files according to the body type."""
        body = self.node_data.body
        if body is not None:
            data = body.data
            match body.type:
                case "none":
                    self.content = ""
                case "raw-text":
                    if len(data) != 1:
                        raise RequestBodyError("raw-text body type should have exactly one item")
                    self.content = self.variable_pool.convert_template(data[0].value).text
                case "json":
                    if len(data) != 1:
                        raise RequestBodyError("json body type should have exactly one item")
                    json_string = self.variable_pool.convert_template(data[0].value).text
                    # repair_json tolerates common LLM-produced JSON glitches before parsing.
                    try:
                        repaired = repair_json(json_string)
                        json_object = json.loads(repaired, strict=False)
                    except json.JSONDecodeError as e:
                        raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
                    self.json = json_object
                    # self.json = self._parse_object_contains_variables(json_object)
                case "binary":
                    if len(data) != 1:
                        raise RequestBodyError("binary body type should have exactly one item")
                    file_selector = data[0].file
                    file_variable = self.variable_pool.get_file(file_selector)
                    if file_variable is None:
                        raise FileFetchError(f"cannot fetch file with selector {file_selector}")
                    file = file_variable.value
                    self.content = file_manager.download(file)
                case "x-www-form-urlencoded":
                    form_data = {
                        self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
                            item.value
                        ).text
                    for item in data
                    }
                    self.data = form_data
                case "form-data":
                    form_data = {
                        self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
                            item.value
                        ).text
                        for item in filter(lambda item: item.type == "text", data)
                    }
                    file_selectors = {
                        self.variable_pool.convert_template(item.key).text: item.file
                        for item in filter(lambda item: item.type == "file", data)
                    }
                    # get files from file_selectors, add support for array file variables
                    files_list = []
                    for key, selector in file_selectors.items():
                        segment = self.variable_pool.get(selector)
                        if isinstance(segment, FileSegment):
                            files_list.append((key, [segment.value]))
                        elif isinstance(segment, ArrayFileSegment):
                            files_list.append((key, list(segment.value)))
                    # get files from file_manager
                    files: dict[str, list[tuple[str | None, bytes, str]]] = {}
                    for key, files_in_segment in files_list:
                        for file in files_in_segment:
                            if file.related_id is not None or (
                                file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None
                            ):
                                file_tuple = (
                                    file.filename,
                                    file_manager.download(file),
                                    file.mime_type or "application/octet-stream",
                                )
                                if key not in files:
                                    files[key] = []
                                files[key].append(file_tuple)
                    # convert files to list for httpx request
                    # If there are no actual files, we still need to force httpx to use `multipart/form-data`.
                    # This is achieved by inserting a harmless placeholder file that will be ignored by the server.
                    if not files:
                        self.files = [("__multipart_placeholder__", ("", b"", "application/octet-stream"))]
                    if files:
                        self.files = []
                        for key, file_tuples in files.items():
                            for file_tuple in file_tuples:
                                self.files.append((key, file_tuple))
                    self.data = form_data

    def _assembling_headers(self) -> dict[str, Any]:
        """Return request headers with authorization and Content-Type applied.

        Works on deep copies so the executor's stored auth/headers are not mutated.
        """
        authorization = deepcopy(self.auth)
        headers = deepcopy(self.headers) or {}
        if self.auth.type == "api-key":
            if self.auth.config is None:
                raise AuthorizationConfigError("self.authorization config is required")
            if authorization.config is None:
                raise AuthorizationConfigError("authorization config is required")
            if not authorization.config.header:
                authorization.config.header = "Authorization"
            if self.auth.config.type == "bearer" and authorization.config.api_key:
                headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
            elif self.auth.config.type == "basic" and authorization.config.api_key:
                credentials = authorization.config.api_key
                # Base64-encode only "user:password" style credentials; otherwise the
                # value is assumed to be pre-encoded and passed through as-is.
                if ":" in credentials:
                    encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
                else:
                    encoded_credentials = credentials
                headers[authorization.config.header] = f"Basic {encoded_credentials}"
            elif self.auth.config.type == "custom":
                if authorization.config.header and authorization.config.api_key:
                    headers[authorization.config.header] = authorization.config.api_key
        # Handle Content-Type for multipart/form-data requests
        # Fix for issue #23829: Missing boundary when using multipart/form-data
        body = self.node_data.body
        if body and body.type == "form-data":
            # For multipart/form-data, if any files are present (including placeholder files),
            # we must remove any manually set Content-Type header. This is because httpx needs to
            # automatically set the Content-Type and boundary for multipart encoding whenever files
            # are included, even if they are placeholders, to avoid boundary issues and ensure correct
            # file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the
            # boundary, resulting in invalid requests.
            if self.files:
                # Remove Content-Type if it was manually set to avoid boundary issues
                headers = {k: v for k, v in headers.items() if k.lower() != "content-type"}
            else:
                # No files at all, set Content-Type manually
                if "content-type" not in (k.lower() for k in headers):
                    headers["Content-Type"] = "multipart/form-data"
        elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE:
            # Set Content-Type for other body types
            if "content-type" not in (k.lower() for k in headers):
                headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
        return headers

    def _validate_and_parse_response(self, response: httpx.Response) -> Response:
        """Wrap the httpx response, rejecting payloads above the configured size limit."""
        executor_response = Response(response)
        # File responses and text responses have independent size thresholds.
        threshold_size = (
            dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
            if executor_response.is_file
            else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
        )
        if executor_response.size > threshold_size:
            raise ResponseSizeError(
                f"{'File' if executor_response.is_file else 'Text'} size is too large,"
                f" max size is {threshold_size / 1024 / 1024:.2f} MB,"
                f" but current size is {executor_response.readable_size}."
            )
        return executor_response

    def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
        """
        do http request depending on api bundle
        """
        _METHOD_MAP = {
            "get": ssrf_proxy.get,
            "head": ssrf_proxy.head,
            "post": ssrf_proxy.post,
            "put": ssrf_proxy.put,
            "delete": ssrf_proxy.delete,
            "patch": ssrf_proxy.patch,
        }
        method_lc = self.method.lower()
        if method_lc not in _METHOD_MAP:
            raise InvalidHttpMethodError(f"Invalid http method {self.method}")
        request_args = {
            "url": self.url,
            "data": self.data,
            "files": self.files,
            "json": self.json,
            "content": self.content,
            "headers": headers,
            "params": self.params,
            "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
            "ssl_verify": self.ssl_verify,
            "follow_redirects": True,
        }
        # request_args = {k: v for k, v in request_args.items() if v is not None}
        try:
            response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
        except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
            raise HttpRequestNodeError(str(e)) from e
        # FIXME: fix type ignore, this maybe httpx type issue
        return response

    def invoke(self) -> Response:
        """Send the request and return the validated, wrapped response."""
        # assemble headers
        headers = self._assembling_headers()
        # do http request
        response = self._do_http_request(headers)
        # validate response
        return self._validate_and_parse_response(response)

    def to_log(self):
        """Render a raw HTTP/1.1 transcript of the request, masking the API-key header."""
        url_parts = urlparse(self.url)
        path = url_parts.path or "/"
        # Add query parameters
        if self.params:
            query_string = urlencode(self.params)
            path += f"?{query_string}"
        elif url_parts.query:
            path += f"?{url_parts.query}"
        raw = f"{self.method.upper()} {path} HTTP/1.1\r\n"
        raw += f"Host: {url_parts.netloc}\r\n"
        headers = self._assembling_headers()
        body = self.node_data.body
        # Synthetic boundary for log display only; the real request boundary is set by httpx.
        boundary = f"----WebKitFormBoundary{_generate_random_string(16)}"
        if body:
            if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE:
                headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
            if body.type == "form-data":
                headers["Content-Type"] = f"multipart/form-data; boundary={boundary}"
        for k, v in headers.items():
            if self.auth.type == "api-key":
                authorization_header = "Authorization"
                if self.auth.config and self.auth.config.header:
                    authorization_header = self.auth.config.header
                # Mask the credential value, keeping only its length visible.
                if k.lower() == authorization_header.lower():
                    raw += f"{k}: {'*' * len(v)}\r\n"
                    continue
            raw += f"{k}: {v}\r\n"
        body_string = ""
        # Only log actual files if present.
        # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file.
        # This prevents logging meaningless placeholder entries.
        if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
            for file_entry in self.files:
                # file_entry should be (key, (filename, content, mime_type)), but handle edge cases
                if len(file_entry) != 2 or len(file_entry[1]) < 2:
                    continue  # skip malformed entries
                key = file_entry[0]
                content = file_entry[1][1]
                body_string += f"--{boundary}\r\n"
                body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
                # decode content safely
                try:
                    body_string += content.decode("utf-8")
                except UnicodeDecodeError:
                    body_string += content.decode("utf-8", errors="replace")
                body_string += "\r\n"
            body_string += f"--{boundary}--\r\n"
        elif self.node_data.body:
            if self.content:
                if isinstance(self.content, bytes):
                    body_string = self.content.decode("utf-8", errors="replace")
                else:
                    body_string = self.content
            elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
                body_string = urlencode(self.data)
            elif self.data and self.node_data.body.type == "form-data":
                for key, value in self.data.items():
                    body_string += f"--{boundary}\r\n"
                    body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
                    body_string += f"{value}\r\n"
                body_string += f"--{boundary}--\r\n"
            elif self.json:
                body_string = json.dumps(self.json)
            elif self.node_data.body.type == "raw-text":
                if len(self.node_data.body.data) != 1:
                    raise RequestBodyError("raw-text body type should have exactly one item")
                body_string = self.node_data.body.data[0].value
        # NOTE(review): Content-Length here is len() of the str, i.e. character
        # count, not encoded byte count — acceptable for a log transcript.
        if body_string:
            raw += f"Content-Length: {len(body_string)}\r\n"
        raw += "\r\n"  # Empty line between headers and body
        raw += body_string
        return raw
def _generate_random_string(n: int) -> str:
"""
Generate a random string of lowercase ASCII letters.
Args:
n (int): The length of the random string to generate.
Returns:
str: A random string of lowercase ASCII letters with length n.
Example:
>>> _generate_random_string(5)
'abcde'
"""
return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n))

View File

@@ -0,0 +1,249 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any
from configs import dify_config
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from factories import file_factory
from .entities import (
HttpRequestNodeData,
HttpRequestNodeTimeout,
Response,
)
from .exc import HttpRequestNodeError, RequestBodyError
# Default per-phase timeouts used when the node config omits them; each value
# falls back to the configured maximum for that phase.
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
    connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
    read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
    write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
)

logger = logging.getLogger(__name__)
class HttpRequestNode(Node):
    """Workflow node that performs an HTTP request via the Executor helper.

    Outputs status_code, body (empty when file content was extracted), headers
    and an ArrayFileSegment of files extracted from file-like responses.
    """

    node_type = NodeType.HTTP_REQUEST

    _node_data: HttpRequestNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        """Validate raw node config into typed HttpRequestNodeData."""
        self._node_data = HttpRequestNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """Return the default node configuration shown when the node is created."""
        return {
            "type": "http-request",
            "config": {
                "method": "get",
                "authorization": {
                    "type": "no-auth",
                },
                "body": {"type": "none"},
                "timeout": {
                    **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
                    "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
                    "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
                    "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
                },
                "ssl_verify": dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
            },
            "retry_config": {
                "max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
                "retry_interval": 0.5 * (2**2),
                "retry_enabled": True,
            },
        }

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        """Build the request, execute it, and map the response to node outputs.

        Retries are delegated to the node-level retry mechanism (Executor is
        constructed with max_retries=0).
        """
        process_data = {}
        try:
            http_executor = Executor(
                node_data=self._node_data,
                timeout=self._get_request_timeout(self._node_data),
                variable_pool=self.graph_runtime_state.variable_pool,
                max_retries=0,
            )
            process_data["request"] = http_executor.to_log()
            response = http_executor.invoke()
            files = self.extract_files(url=http_executor.url, response=response)
            # A non-2xx status is treated as node failure only when the node has an
            # error strategy or retry enabled; otherwise it succeeds with the
            # status code exposed in outputs for downstream branching.
            if not response.response.is_success and (self.error_strategy or self.retry):
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    outputs={
                        "status_code": response.status_code,
                        "body": response.text if not files.value else "",
                        "headers": response.headers,
                        "files": files,
                    },
                    process_data={
                        "request": http_executor.to_log(),
                    },
                    error=f"Request failed with status code {response.status_code}",
                    error_type="HTTPResponseCodeError",
                )
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={
                    "status_code": response.status_code,
                    "body": response.text if not files.value else "",
                    "headers": response.headers,
                    "files": files,
                },
                process_data={
                    "request": http_executor.to_log(),
                },
            )
        except HttpRequestNodeError as e:
            logger.warning("http request node %s failed to run: %s", self._node_id, e)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                process_data=process_data,
                error_type=type(e).__name__,
            )

    @staticmethod
    def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout:
        """Return the node's timeout, filling unset phases with the defaults."""
        timeout = node_data.timeout
        if timeout is None:
            return HTTP_REQUEST_DEFAULT_TIMEOUT
        timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect
        timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read
        timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write
        return timeout

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Collect every variable selector referenced by URL, headers, params and body."""
        # Create typed NodeData from dict
        typed_node_data = HttpRequestNodeData.model_validate(node_data)
        selectors: list[VariableSelector] = []
        selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
        selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
        selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
        if typed_node_data.body:
            body_type = typed_node_data.body.type
            data = typed_node_data.body.data
            match body_type:
                case "none":
                    pass
                case "binary":
                    if len(data) != 1:
                        raise RequestBodyError("invalid body data, should have only one item")
                    selector = data[0].file
                    selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector))
                case "json" | "raw-text":
                    if len(data) != 1:
                        raise RequestBodyError("invalid body data, should have only one item")
                    selectors += variable_template_parser.extract_selectors_from_template(data[0].key)
                    selectors += variable_template_parser.extract_selectors_from_template(data[0].value)
                case "x-www-form-urlencoded":
                    for item in data:
                        selectors += variable_template_parser.extract_selectors_from_template(item.key)
                        selectors += variable_template_parser.extract_selectors_from_template(item.value)
                case "form-data":
                    for item in data:
                        selectors += variable_template_parser.extract_selectors_from_template(item.key)
                        if item.type == "text":
                            selectors += variable_template_parser.extract_selectors_from_template(item.value)
                        elif item.type == "file":
                            selectors.append(
                                VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file)
                            )
        mapping = {}
        for selector_iter in selectors:
            mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector
        return mapping

    def extract_files(self, url: str, response: Response) -> ArrayFileSegment:
        """
        Extract files from response by checking both Content-Type header and URL
        """
        files: list[File] = []
        is_file = response.is_file
        content_type = response.content_type
        content = response.content
        parsed_content_disposition = response.parsed_content_disposition
        content_disposition_type = None
        if not is_file:
            return ArrayFileSegment(value=[])
        if parsed_content_disposition:
            content_disposition_filename = parsed_content_disposition.get_filename()
            if content_disposition_filename:
                # If filename is available from content-disposition, use it to guess the content type
                content_disposition_type = mimetypes.guess_type(content_disposition_filename)[0]
        # Guess file extension from URL or Content-Type header
        filename = url.split("?")[0].split("/")[-1] or ""
        mime_type = (
            content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
        )
        # Persist the binary response as a tool file, then wrap it as a workflow File.
        tool_file_manager = ToolFileManager()
        tool_file = tool_file_manager.create_file_by_raw(
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
            file_binary=content,
            mimetype=mime_type,
        )
        mapping = {
            "tool_file_id": tool_file.id,
            "transfer_method": FileTransferMethod.TOOL_FILE,
        }
        file = file_factory.build_from_mapping(
            mapping=mapping,
            tenant_id=self.tenant_id,
        )
        files.append(file)
        return ArrayFileSegment(value=files)

    @property
    def retry(self) -> bool:
        # True when node-level retries are enabled in the retry config.
        return self._node_data.retry_config.retry_enabled

View File

@@ -0,0 +1,3 @@
from .human_input_node import HumanInputNode
__all__ = ["HumanInputNode"]

View File

@@ -0,0 +1,10 @@
from pydantic import Field
from core.workflow.nodes.base import BaseNodeData
class HumanInputNodeData(BaseNodeData):
    """Configuration schema for the HumanInput node."""

    # Variable selector strings (expected form "node_id.variable") that must all
    # exist in the variable pool before the node can complete instead of pausing.
    required_variables: list[str] = Field(default_factory=list)
    # Optional reason recorded for the pause — presumably surfaced to the
    # requesting user; confirm against the pause-event consumer.
    pause_reason: str | None = Field(default=None)

View File

@@ -0,0 +1,133 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import HumanInputNodeData
class HumanInputNode(Node):
    """Branch node that pauses the workflow until required human input arrives.

    On each run it either completes (when every required variable is present,
    selecting a branch handle) or yields a PauseRequestedEvent to suspend the
    workflow.
    """

    node_type = NodeType.HUMAN_INPUT
    execution_type = NodeExecutionType.BRANCH

    # Variable-pool / default-value keys probed (in order) for a human-selected
    # branch handle; covers snake_case and camelCase spellings.
    _BRANCH_SELECTION_KEYS: tuple[str, ...] = (
        "edge_source_handle",
        "edgeSourceHandle",
        "source_handle",
        "selected_branch",
        "selectedBranch",
        "branch",
        "branch_id",
        "branchId",
        "handle",
    )

    _node_data: HumanInputNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = HumanInputNodeData(**data)

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def _run(self):  # type: ignore[override]
        """Return a succeeded NodeRunResult when inputs are ready, else a generator
        that yields a pause request."""
        if self._is_completion_ready():
            branch_handle = self._resolve_branch_selection()
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={},
                # Fall back to the default "source" handle when no branch was chosen.
                edge_source_handle=branch_handle or "source",
            )
        return self._pause_generator()

    def _pause_generator(self):
        # Generator form so the engine receives the pause as a streamed event.
        yield PauseRequestedEvent(reason=HumanInputRequired())

    def _is_completion_ready(self) -> bool:
        """Determine whether all required inputs are satisfied."""
        # With no required variables configured the node always pauses.
        if not self._node_data.required_variables:
            return False
        variable_pool = self.graph_runtime_state.variable_pool
        for selector_str in self._node_data.required_variables:
            parts = selector_str.split(".")
            # Only two-part selectors ("node_id.variable") are accepted.
            if len(parts) != 2:
                return False
            segment = variable_pool.get(parts)
            if segment is None:
                return False
        return True

    def _resolve_branch_selection(self) -> str | None:
        """Determine the branch handle selected by human input if available."""
        variable_pool = self.graph_runtime_state.variable_pool
        # Prefer values written into the variable pool under this node's id...
        for key in self._BRANCH_SELECTION_KEYS:
            handle = self._extract_branch_handle(variable_pool.get((self.id, key)))
            if handle:
                return handle
        # ...then fall back to statically configured default values.
        default_values = self._node_data.default_value_dict
        for key in self._BRANCH_SELECTION_KEYS:
            handle = self._normalize_branch_value(default_values.get(key))
            if handle:
                return handle
        return None

    @staticmethod
    def _extract_branch_handle(segment: Any) -> str | None:
        """Unwrap a variable-pool segment and normalize it to a branch handle string."""
        if segment is None:
            return None
        # Segments may expose either a to_object() method or a .value attribute.
        candidate = getattr(segment, "to_object", None)
        raw_value = candidate() if callable(candidate) else getattr(segment, "value", None)
        if raw_value is None:
            return None
        return HumanInputNode._normalize_branch_value(raw_value)

    @staticmethod
    def _normalize_branch_value(value: Any) -> str | None:
        """Coerce a raw string or mapping into a non-empty handle string, else None."""
        if value is None:
            return None
        if isinstance(value, str):
            stripped = value.strip()
            return stripped or None
        if isinstance(value, Mapping):
            for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
                candidate = value.get(key)
                if isinstance(candidate, str) and candidate:
                    return candidate
        return None

View File

@@ -0,0 +1,3 @@
from .if_else_node import IfElseNode
__all__ = ["IfElseNode"]

View File

@@ -0,0 +1,26 @@
from typing import Literal
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
from core.workflow.utils.condition.entities import Condition
class IfElseNodeData(BaseNodeData):
    """
    If Else Node Data.
    """

    class Case(BaseModel):
        """
        Case entity representing a single logical condition group
        """

        case_id: str
        logical_operator: Literal["and", "or"]
        conditions: list[Condition]

    # Legacy single-group fields, superseded by `cases`; kept for old workflow configs.
    logical_operator: Literal["and", "or"] | None = "and"
    conditions: list[Condition] | None = Field(default=None, deprecated=True)
    # Ordered condition groups; the first group that evaluates true selects the branch.
    cases: list[Case] | None = None

View File

@@ -0,0 +1,150 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing_extensions import deprecated
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.runtime import VariablePool
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(Node):
    """Branch node that evaluates condition groups and selects an outgoing edge.

    Supports the current multi-case structure (first passing case wins) and
    falls back to the deprecated single condition group, whose result maps to
    the "true"/"false" handles.
    """

    node_type = NodeType.IF_ELSE
    execution_type = NodeExecutionType.BRANCH

    _node_data: IfElseNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = IfElseNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        """
        Run node
        :return:
        """
        node_inputs: dict[str, Sequence[Mapping[str, Any]]] = {"conditions": []}
        process_data: dict[str, list] = {"condition_results": []}
        input_conditions: Sequence[Mapping[str, Any]] = []
        final_result = False
        # Default branch handle when no case passes.
        selected_case_id = "false"
        condition_processor = ConditionProcessor()
        try:
            # Check if the new cases structure is used
            if self._node_data.cases:
                for case in self._node_data.cases:
                    input_conditions, group_result, final_result = condition_processor.process_conditions(
                        variable_pool=self.graph_runtime_state.variable_pool,
                        conditions=case.conditions,
                        operator=case.logical_operator,
                    )
                    process_data["condition_results"].append(
                        {
                            "group": case.model_dump(),
                            "results": group_result,
                            "final_result": final_result,
                        }
                    )
                    # Break if a case passes (logical short-circuit)
                    if final_result:
                        selected_case_id = case.case_id  # Capture the ID of the passing case
                        break
            else:
                # TODO: Update database then remove this
                # Fallback to old structure if cases are not defined
                input_conditions, group_result, final_result = _should_not_use_old_function(  # pyright: ignore [reportDeprecated]
                    condition_processor=condition_processor,
                    variable_pool=self.graph_runtime_state.variable_pool,
                    conditions=self._node_data.conditions or [],
                    operator=self._node_data.logical_operator or "and",
                )
                selected_case_id = "true" if final_result else "false"
                process_data["condition_results"].append(
                    {"group": "default", "results": group_result, "final_result": final_result}
                )
            node_inputs["conditions"] = input_conditions
        except Exception as e:
            # Condition evaluation errors fail the node with the inputs gathered so far.
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_data, error=str(e)
            )
        outputs = {"result": final_result, "selected_case_id": selected_case_id}
        data = NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            process_data=process_data,
            edge_source_handle=selected_case_id or "false",  # Use case ID or 'default'
            outputs=outputs,
        )
        return data

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Map every condition's variable selector, keyed by node id.

        Note: only the new `cases` structure is inspected; the deprecated
        flat `conditions` list is not included here.
        """
        # Create typed NodeData from dict
        typed_node_data = IfElseNodeData.model_validate(node_data)
        var_mapping: dict[str, list[str]] = {}
        for case in typed_node_data.cases or []:
            for condition in case.conditions:
                key = f"{node_id}.#{'.'.join(condition.variable_selector)}#"
                var_mapping[key] = condition.variable_selector
        return var_mapping
@deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function(
    *,
    condition_processor: ConditionProcessor,
    variable_pool: VariablePool,
    conditions: list[Condition],
    operator: Literal["and", "or"],
):
    """Evaluate the legacy flat condition list; thin wrapper kept so the
    deprecated call site is flagged by type checkers."""
    return condition_processor.process_conditions(
        variable_pool=variable_pool,
        conditions=conditions,
        operator=operator,
    )

View File

@@ -0,0 +1,5 @@
from .entities import IterationNodeData
from .iteration_node import IterationNode
from .iteration_start_node import IterationStartNode
__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"]

View File

@@ -0,0 +1,64 @@
from enum import StrEnum
from typing import Any
from pydantic import Field
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
class ErrorHandleMode(StrEnum):
    """How the iteration node handles a failed iteration."""

    TERMINATED = "terminated"
    CONTINUE_ON_ERROR = "continue-on-error"
    REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"
class IterationNodeData(BaseIterationNodeData):
    """
    Iteration Node Data.
    """

    parent_loop_id: str | None = None  # redundant field, not used currently
    iterator_selector: list[str]  # variable selector
    output_selector: list[str]  # output selector
    is_parallel: bool = False  # open the parallel mode or not
    parallel_nums: int = 10  # the numbers of parallel
    error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED  # how to handle the error
    flatten_output: bool = True  # whether to flatten the output array if all elements are lists
class IterationStartNodeData(BaseNodeData):
    """
    Iteration Start Node Data.

    Marker node with no configuration of its own.
    """

    pass
class IterationState(BaseIterationState):
    """
    Iteration State.
    """

    # Outputs collected from completed iterations, in iteration order.
    outputs: list[Any] = Field(default_factory=list)
    # Output of the iteration currently in progress, if any.
    current_output: Any = None

    class MetaData(BaseIterationState.MetaData):
        """
        Data.
        """

        # Total number of items the iteration will process.
        iterator_length: int

    def get_last_output(self) -> Any:
        """
        Get last output.

        Returns None when no iteration has completed yet.
        """
        if self.outputs:
            return self.outputs[-1]
        return None

    def get_current_output(self) -> Any:
        """
        Get current output.
        """
        return self.current_output

View File

@@ -0,0 +1,22 @@
class IterationNodeError(ValueError):
    """Base class for iteration node errors."""


class IteratorVariableNotFoundError(IterationNodeError):
    """Raised when the iterator variable is not found."""


class InvalidIteratorValueError(IterationNodeError):
    """Raised when the iterator value is invalid."""


class StartNodeIdNotFoundError(IterationNodeError):
    """Raised when the start node ID is not found."""


class IterationGraphNotFoundError(IterationNodeError):
    """Raised when the iteration graph is not found."""


class IterationIndexNotFoundError(IterationNodeError):
    """Raised when the iteration index is not found."""

View File

@@ -0,0 +1,667 @@
import contextvars
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
)
from core.workflow.node_events import (
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
IterationSucceededEvent,
NodeEventBase,
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
IterationGraphNotFoundError,
IterationIndexNotFoundError,
IterationNodeError,
IteratorVariableNotFoundError,
StartNodeIdNotFoundError,
)
if TYPE_CHECKING:
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
# Marker type for an ArraySegment known to be empty; exists only so type
# narrowing (TypeIs) can express "None or empty array" in annotations.
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(LLMUsageTrackingMixin, Node):
    """
    Iteration Node.

    Container node that runs an embedded sub-graph once per element of an
    input array and collects one output value per iteration. Supports
    sequential and parallel (thread pool) execution, configurable error
    handling (terminate / continue / drop abnormal outputs), optional
    flattening of list outputs, and LLM usage accumulation across iterations.
    """

    node_type = NodeType.ITERATION
    execution_type = NodeExecutionType.CONTAINER

    _node_data: IterationNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw config mapping into typed node data."""
        self._node_data = IterationNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """Return the default configuration used when creating a new iteration node."""
        return {
            "type": "iteration",
            "config": {
                "is_parallel": False,
                "parallel_nums": 10,
                "error_handle_mode": ErrorHandleMode.TERMINATED,
                "flatten_output": True,
            },
        }

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:  # type: ignore
        """
        Drive the whole iteration.

        Emits IterationStartedEvent, delegates to sequential/parallel
        execution, then emits success or failure events (including a final
        StreamCompletedEvent with the collected outputs and LLM usage).
        """
        variable = self._get_iterator_variable()
        if self._is_empty_iteration(variable):
            yield from self._handle_empty_iteration(variable)
            return
        iterator_list_value = self._validate_and_get_iterator_list(variable)
        inputs = {"iterator_selector": iterator_list_value}
        self._validate_start_node()
        started_at = naive_utc_now()
        iter_run_map: dict[str, float] = {}
        outputs: list[object] = []
        # Single-element list so helpers can replace the accumulated usage in place.
        usage_accumulator = [LLMUsage.empty_usage()]
        yield IterationStartedEvent(
            start_at=started_at,
            inputs=inputs,
            metadata={"iteration_length": len(iterator_list_value)},
        )
        try:
            yield from self._execute_iterations(
                iterator_list_value=iterator_list_value,
                outputs=outputs,
                iter_run_map=iter_run_map,
                usage_accumulator=usage_accumulator,
            )
            self._accumulate_usage(usage_accumulator[0])
            yield from self._handle_iteration_success(
                started_at=started_at,
                inputs=inputs,
                outputs=outputs,
                iterator_list_value=iterator_list_value,
                iter_run_map=iter_run_map,
                usage=usage_accumulator[0],
            )
        except IterationNodeError as e:
            # Usage gathered before the failure is still accounted for.
            self._accumulate_usage(usage_accumulator[0])
            yield from self._handle_iteration_failure(
                started_at=started_at,
                inputs=inputs,
                outputs=outputs,
                iterator_list_value=iterator_list_value,
                iter_run_map=iter_run_map,
                usage=usage_accumulator[0],
                error=e,
            )

    def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
        """Fetch the configured iterator variable; raise if missing or not array/None."""
        variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
        if not variable:
            raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
        if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
            raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
        return variable

    def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
        """True when there is nothing to iterate (None segment or empty array)."""
        return isinstance(variable, NoneSegment) or len(variable.value) == 0

    def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
        """Short-circuit with an empty output array, preserving element type when known."""
        # Try our best to preserve the type information.
        if isinstance(variable, ArraySegment):
            output = variable.model_copy(update={"value": []})
        else:
            output = ArrayAnySegment(value=[])
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                # TODO(QuantumGhost): is it possible to compute the type of `output`
                # from graph definition?
                outputs={"output": output},
            )
        )

    def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
        """Convert the iterator variable to a plain Python list, raising on non-lists."""
        iterator_list_value = variable.to_object()
        if not isinstance(iterator_list_value, list):
            raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
        return cast(list[object], iterator_list_value)

    def _validate_start_node(self) -> None:
        """Ensure the iteration sub-graph has a configured entry node id."""
        if not self._node_data.start_node_id:
            raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")

    def _execute_iterations(
        self,
        iterator_list_value: Sequence[object],
        outputs: list[object],
        iter_run_map: dict[str, float],
        usage_accumulator: list[LLMUsage],
    ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
        """Dispatch to parallel or sequential execution based on node config.

        `outputs`, `iter_run_map`, and `usage_accumulator` are mutated in
        place so the caller observes partial results even on failure.
        """
        if self._node_data.is_parallel:
            # Parallel mode execution
            yield from self._execute_parallel_iterations(
                iterator_list_value=iterator_list_value,
                outputs=outputs,
                iter_run_map=iter_run_map,
                usage_accumulator=usage_accumulator,
            )
        else:
            # Sequential mode execution
            for index, item in enumerate(iterator_list_value):
                iter_start_at = datetime.now(UTC).replace(tzinfo=None)
                yield IterationNextEvent(index=index)
                graph_engine = self._create_graph_engine(index, item)
                # Run the iteration
                yield from self._run_single_iter(
                    variable_pool=graph_engine.graph_runtime_state.variable_pool,
                    outputs=outputs,
                    graph_engine=graph_engine,
                )
                # Sync conversation variables after each iteration completes
                self._sync_conversation_variables_from_snapshot(
                    self._extract_conversation_variable_snapshot(
                        variable_pool=graph_engine.graph_runtime_state.variable_pool
                    )
                )
                # Accumulate usage from this iteration
                usage_accumulator[0] = self._merge_usage(
                    usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
                )
                iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()

    def _execute_parallel_iterations(
        self,
        iterator_list_value: Sequence[object],
        outputs: list[object],
        iter_run_map: dict[str, float],
        usage_accumulator: list[LLMUsage],
    ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
        """Run all iterations on a thread pool, writing results by index to keep order."""
        # Initialize outputs list with None values to maintain order
        outputs.extend([None] * len(iterator_list_value))
        # Determine the number of parallel workers
        max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all iteration tasks
            future_to_index: dict[
                Future[
                    tuple[
                        datetime,
                        list[GraphNodeEventBase],
                        object | None,
                        dict[str, VariableUnion],
                        LLMUsage,
                    ]
                ],
                int,
            ] = {}
            for index, item in enumerate(iterator_list_value):
                yield IterationNextEvent(index=index)
                future = executor.submit(
                    self._execute_single_iteration_parallel,
                    index=index,
                    item=item,
                    # Workers need the real app object and a copy of the
                    # current context vars to re-enter Flask context.
                    flask_app=current_app._get_current_object(),  # type: ignore
                    context_vars=contextvars.copy_context(),
                )
                future_to_index[future] = index
            # Process completed iterations as they finish
            for future in as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    result = future.result()
                    (
                        iter_start_at,
                        events,
                        output_value,
                        conversation_snapshot,
                        iteration_usage,
                    ) = result
                    # Update outputs at the correct index
                    outputs[index] = output_value
                    # Yield all events from this iteration
                    yield from events
                    # Update tokens and timing
                    iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
                    usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
                    # Sync conversation variables after iteration completion
                    self._sync_conversation_variables_from_snapshot(conversation_snapshot)
                except Exception as e:
                    # Handle errors based on error_handle_mode
                    match self._node_data.error_handle_mode:
                        case ErrorHandleMode.TERMINATED:
                            # Cancel remaining futures and re-raise
                            for f in future_to_index:
                                if f != future:
                                    f.cancel()
                            raise IterationNodeError(str(e))
                        case ErrorHandleMode.CONTINUE_ON_ERROR:
                            outputs[index] = None
                        case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
                            outputs[index] = None  # Will be filtered later
        # Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
        if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
            outputs[:] = [output for output in outputs if output is not None]

    def _execute_single_iteration_parallel(
        self,
        index: int,
        item: object,
        flask_app: Flask,
        context_vars: contextvars.Context,
    ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
        """Execute a single iteration in parallel mode and return results.

        Runs inside a worker thread, so Flask app/context vars are restored
        explicitly and events are buffered instead of yielded.
        """
        with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
            iter_start_at = datetime.now(UTC).replace(tzinfo=None)
            events: list[GraphNodeEventBase] = []
            outputs_temp: list[object] = []
            graph_engine = self._create_graph_engine(index, item)
            # Collect events instead of yielding them directly
            for event in self._run_single_iter(
                variable_pool=graph_engine.graph_runtime_state.variable_pool,
                outputs=outputs_temp,
                graph_engine=graph_engine,
            ):
                events.append(event)
            # Get the output value from the temporary outputs list
            output_value = outputs_temp[0] if outputs_temp else None
            conversation_snapshot = self._extract_conversation_variable_snapshot(
                variable_pool=graph_engine.graph_runtime_state.variable_pool
            )
            return (
                iter_start_at,
                events,
                output_value,
                conversation_snapshot,
                graph_engine.graph_runtime_state.llm_usage,
            )

    def _handle_iteration_success(
        self,
        started_at: datetime,
        inputs: dict[str, Sequence[object]],
        outputs: list[object],
        iterator_list_value: Sequence[object],
        iter_run_map: dict[str, float],
        *,
        usage: LLMUsage,
    ) -> Generator[NodeEventBase, None, None]:
        """Emit the succeeded event plus the final StreamCompletedEvent."""
        # Flatten the list of lists if all outputs are lists
        flattened_outputs = self._flatten_outputs_if_needed(outputs)
        yield IterationSucceededEvent(
            start_at=started_at,
            inputs=inputs,
            outputs={"output": flattened_outputs},
            steps=len(iterator_list_value),
            metadata={
                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
            },
        )
        # Yield final success event
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={"output": flattened_outputs},
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                },
                llm_usage=usage,
            )
        )

    def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]:
        """
        Flatten the outputs list if all elements are lists.
        This maintains backward compatibility with version 1.8.1 behavior.
        If flatten_output is False, returns outputs as-is (nested structure).
        If flatten_output is True (default), flattens the list if all elements are lists.
        """
        # If flatten_output is disabled, return outputs as-is
        if not self._node_data.flatten_output:
            return outputs
        if not outputs:
            return outputs
        # Check if all non-None outputs are lists
        non_none_outputs = [output for output in outputs if output is not None]
        if not non_none_outputs:
            return outputs
        if all(isinstance(output, list) for output in non_none_outputs):
            # Flatten the list of lists
            flattened: list[Any] = []
            for output in outputs:
                if isinstance(output, list):
                    flattened.extend(output)
                elif output is not None:
                    # This shouldn't happen based on our check, but handle it gracefully
                    flattened.append(output)
            return flattened
        return outputs

    def _handle_iteration_failure(
        self,
        started_at: datetime,
        inputs: dict[str, Sequence[object]],
        outputs: list[object],
        iterator_list_value: Sequence[object],
        iter_run_map: dict[str, float],
        *,
        usage: LLMUsage,
        error: IterationNodeError,
    ) -> Generator[NodeEventBase, None, None]:
        """Emit the failed event plus a FAILED StreamCompletedEvent, keeping partial outputs."""
        # Flatten the list of lists if all outputs are lists (even in failure case)
        flattened_outputs = self._flatten_outputs_if_needed(outputs)
        yield IterationFailedEvent(
            start_at=started_at,
            inputs=inputs,
            outputs={"output": flattened_outputs},
            steps=len(iterator_list_value),
            metadata={
                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
            },
            error=str(error),
        )
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(error),
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                },
                llm_usage=usage,
            )
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Build the variable-selector mapping for this node and its children.

        Includes the iterator selector plus each contained node's mapping
        (prefixed with the child id), excluding selectors that resolve to
        variables produced inside the iteration itself.
        """
        # Create typed NodeData from dict
        typed_node_data = IterationNodeData.model_validate(node_data)
        variable_mapping: dict[str, Sequence[str]] = {
            f"{node_id}.input_selector": typed_node_data.iterator_selector,
        }
        iteration_node_ids = set()
        # Find all nodes that belong to this loop
        nodes = graph_config.get("nodes", [])
        for node in nodes:
            node_data = node.get("data", {})
            if node_data.get("iteration_id") == node_id:
                in_iteration_node_id = node.get("id")
                if in_iteration_node_id:
                    iteration_node_ids.add(in_iteration_node_id)
        # Get node configs from graph_config instead of non-existent node_id_config_mapping
        node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
        for sub_node_id, sub_node_config in node_configs.items():
            if sub_node_config.get("data", {}).get("iteration_id") != node_id:
                continue
            # variable selector to variable mapping
            try:
                # Get node class
                from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING

                node_type = NodeType(sub_node_config.get("data", {}).get("type"))
                if node_type not in NODE_TYPE_CLASSES_MAPPING:
                    continue
                node_version = sub_node_config.get("data", {}).get("version", "1")
                node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
                sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
                    graph_config=graph_config, config=sub_node_config
                )
                sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
            except NotImplementedError:
                sub_node_variable_mapping = {}
            # remove iteration variables
            sub_node_variable_mapping = {
                sub_node_id + "." + key: value
                for key, value in sub_node_variable_mapping.items()
                if value[0] != node_id
            }
            variable_mapping.update(sub_node_variable_mapping)
        # remove variable out from iteration
        variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
        return variable_mapping

    def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
        """Deep-copy the conversation variables from a (sub-graph) variable pool."""
        conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
        return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}

    def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
        """Replace the parent pool's conversation variables with the snapshot (add/update/remove)."""
        parent_pool = self.graph_runtime_state.variable_pool
        parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
        current_keys = set(parent_conversations.keys())
        snapshot_keys = set(snapshot.keys())
        for removed_key in current_keys - snapshot_keys:
            parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key))
        for name, variable in snapshot.items():
            parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable)

    def _append_iteration_info_to_event(
        self,
        event: GraphNodeEventBase,
        iter_run_index: int,
    ):
        """Tag a sub-graph event with this iteration's id and index metadata (idempotent)."""
        event.in_iteration_id = self._node_id
        iter_metadata = {
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id,
            WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index,
        }
        current_metadata = event.node_run_result.metadata
        if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata:
            event.node_run_result.metadata = {**current_metadata, **iter_metadata}

    def _run_single_iter(
        self,
        *,
        variable_pool: VariablePool,
        outputs: list[object],
        graph_engine: "GraphEngine",
    ) -> Generator[GraphNodeEventBase, None, None]:
        """
        Run one iteration's sub-graph, re-yielding its node events and
        appending the selected output (or None) to `outputs` on completion.
        Raises IterationNodeError on sub-graph failure in TERMINATED mode.
        """
        rst = graph_engine.run()
        # get current iteration index
        index_variable = variable_pool.get([self._node_id, "index"])
        if not isinstance(index_variable, IntegerVariable):
            raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found")
        current_index = index_variable.value
        for event in rst:
            # The internal iteration-start node is an implementation detail; hide it.
            if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START:
                continue
            if isinstance(event, GraphNodeEventBase):
                self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
                yield event
            elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
                result = variable_pool.get(self._node_data.output_selector)
                if result is None:
                    outputs.append(None)
                else:
                    outputs.append(result.to_object())
                return
            elif isinstance(event, GraphRunFailedEvent):
                match self._node_data.error_handle_mode:
                    case ErrorHandleMode.TERMINATED:
                        raise IterationNodeError(event.error)
                    case ErrorHandleMode.CONTINUE_ON_ERROR:
                        outputs.append(None)
                        return
                    case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
                        return

    def _create_graph_engine(self, index: int, item: object):
        """
        Build an isolated GraphEngine for one iteration.

        Deep-copies the variable pool, injects the per-iteration `index` and
        `item` variables, and wires a fresh runtime state/node factory so
        iterations cannot interfere with each other.
        """
        # Import dependencies
        from core.workflow.entities import GraphInitParams
        from core.workflow.graph import Graph
        from core.workflow.graph_engine import GraphEngine
        from core.workflow.graph_engine.command_channels import InMemoryChannel
        from core.workflow.nodes.node_factory import DifyNodeFactory
        from core.workflow.runtime import GraphRuntimeState

        # Create GraphInitParams from node attributes
        graph_init_params = GraphInitParams(
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            workflow_id=self.workflow_id,
            graph_config=self.graph_config,
            user_id=self.user_id,
            user_from=self.user_from.value,
            invoke_from=self.invoke_from.value,
            call_depth=self.workflow_call_depth,
        )
        # Create a deep copy of the variable pool for each iteration
        variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True)
        # append iteration variable (item, index) to variable pool
        variable_pool_copy.add([self._node_id, "index"], index)
        variable_pool_copy.add([self._node_id, "item"], item)
        # Create a new GraphRuntimeState for this iteration
        graph_runtime_state_copy = GraphRuntimeState(
            variable_pool=variable_pool_copy,
            start_at=self.graph_runtime_state.start_at,
            total_tokens=0,
            node_run_steps=0,
        )
        # Create a new node factory with the new GraphRuntimeState
        node_factory = DifyNodeFactory(
            graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
        )
        # Initialize the iteration graph with the new node factory
        iteration_graph = Graph.init(
            graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
        )
        if not iteration_graph:
            raise IterationGraphNotFoundError("iteration graph not found")
        # Create a new GraphEngine for this iteration
        graph_engine = GraphEngine(
            workflow_id=self.workflow_id,
            graph=iteration_graph,
            graph_runtime_state=graph_runtime_state_copy,
            command_channel=InMemoryChannel(),  # Use InMemoryChannel for sub-graphs
        )
        return graph_engine

View File

@@ -0,0 +1,49 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import IterationStartNodeData
class IterationStartNode(Node):
    """
    Iteration Start Node.

    Entry point of an iteration sub-graph; it carries no logic of its own
    and simply succeeds so the contained nodes can run.
    """

    node_type = NodeType.ITERATION_START

    _node_data: IterationStartNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        validated = IterationStartNodeData.model_validate(data)
        self._node_data = validated

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        """Succeed immediately; the start node performs no work."""
        result = NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
        return result

View File

@@ -0,0 +1,3 @@
from .knowledge_index_node import KnowledgeIndexNode
__all__ = ["KnowledgeIndexNode"]

View File

@@ -0,0 +1,160 @@
from typing import Literal, Union
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.nodes.base import BaseNodeData
class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.
    """

    reranking_provider_name: str  # provider id of the reranking model
    reranking_model_name: str  # model id within that provider


class VectorSetting(BaseModel):
    """
    Vector Setting.
    """

    vector_weight: float  # weight of the vector score in weighted reranking
    embedding_provider_name: str
    embedding_model_name: str


class KeywordSetting(BaseModel):
    """
    Keyword Setting.
    """

    keyword_weight: float  # weight of the keyword score in weighted reranking


class WeightedScoreConfig(BaseModel):
    """
    Weighted score Config.

    Combines the vector and keyword settings used by weighted-score reranking.
    """

    vector_setting: VectorSetting
    keyword_setting: KeywordSetting
class EmbeddingSetting(BaseModel):
    """
    Embedding Setting.
    """

    embedding_provider_name: str
    embedding_model_name: str


class EconomySetting(BaseModel):
    """
    Economy Setting.
    """

    keyword_number: int  # number of keywords extracted per chunk in economy mode


class RetrievalSetting(BaseModel):
    """
    Retrieval Setting.
    """

    search_method: RetrievalMethod
    top_k: int
    # NOTE(review): presumably only applied when score_threshold_enabled is True — confirm in retrieval service.
    score_threshold: float | None = 0.5
    score_threshold_enabled: bool = False
    reranking_mode: str = "reranking_model"
    reranking_enable: bool = True
    reranking_model: RerankingModelConfig | None = None  # used when reranking_mode == "reranking_model"
    weights: WeightedScoreConfig | None = None  # used when reranking_mode is weighted-score
class IndexMethod(BaseModel):
    """
    Knowledge Index Setting.
    """

    indexing_technique: Literal["high_quality", "economy"]
    embedding_setting: EmbeddingSetting  # used for high_quality indexing
    economy_setting: EconomySetting  # used for economy indexing


class FileInfo(BaseModel):
    """
    File Info.
    """

    file_id: str


class OnlineDocumentIcon(BaseModel):
    """
    Document Icon.
    """

    icon_url: str
    icon_type: str
    icon_emoji: str


class OnlineDocumentInfo(BaseModel):
    """
    Online document info.
    """

    provider: str  # e.g. the online-document datasource provider id
    workspace_id: str | None = None
    page_id: str
    page_type: str
    icon: OnlineDocumentIcon | None = None


class WebsiteInfo(BaseModel):
    """
    website import info.
    """

    provider: str
    url: str
class GeneralStructureChunk(BaseModel):
    """
    General Structure Chunk.

    Flat list of chunk texts plus the originating data-source descriptor.
    """

    general_chunks: list[str]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class ParentChildChunk(BaseModel):
    """
    Parent Child Chunk.
    """

    parent_content: str  # the parent passage
    child_contents: list[str]  # child chunks derived from the parent


class ParentChildStructureChunk(BaseModel):
    """
    Parent Child Structure Chunk.
    """

    parent_child_chunks: list[ParentChildChunk]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class KnowledgeIndexNodeData(BaseNodeData):
    """
    Knowledge index Node Data.
    """

    type: str = "knowledge-index"
    chunk_structure: str  # selects the index processor (e.g. general vs parent-child)
    index_chunk_variable_selector: list[str]  # selector of the variable holding the chunks to index

View File

@@ -0,0 +1,22 @@
class KnowledgeIndexNodeError(ValueError):
    """Base class for KnowledgeIndexNode errors; catch this to handle any of them."""


class ModelNotExistError(KnowledgeIndexNodeError):
    """Raised when the model does not exist."""


class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
    """Raised when the model credentials are not initialized."""


class ModelNotSupportedError(KnowledgeIndexNodeError):
    """Raised when the model is not supported."""


class ModelQuotaExceededError(KnowledgeIndexNodeError):
    """Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeIndexNodeError):
    """Raised when the model is not a Large Language Model."""

View File

@@ -0,0 +1,214 @@
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from .entities import KnowledgeIndexNodeData
from .exc import (
KnowledgeIndexNodeError,
)
logger = logging.getLogger(__name__)
# Fallback retrieval configuration used when a dataset carries no retrieval
# model of its own: semantic search, no reranking, top-2, no score threshold.
default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}
class KnowledgeIndexNode(Node):
    """
    Workflow node that writes prepared chunks into a dataset's index.

    In preview (debugger) runs it only formats the chunks for display; in
    real runs it indexes them via the dataset's index processor, cleans up
    segments of the original document when re-indexing, and updates the
    document/segment statuses in the database.
    """

    _node_data: KnowledgeIndexNodeData
    node_type = NodeType.KNOWLEDGE_INDEX
    execution_type = NodeExecutionType.RESPONSE

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        """Validate the raw config mapping into typed node data."""
        self._node_data = KnowledgeIndexNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    def _run(self) -> NodeRunResult:  # type: ignore
        """
        Resolve the dataset and chunks from the variable pool, then either
        produce a preview or perform the actual indexing.

        Raises KnowledgeIndexNodeError for missing prerequisites discovered
        before the try block; failures during indexing are converted into a
        FAILED NodeRunResult.
        """
        node_data = self._node_data
        variable_pool = self.graph_runtime_state.variable_pool
        dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
        if not dataset_id:
            raise KnowledgeIndexNodeError("Dataset ID is required.")
        dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
        if not dataset:
            raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
        # extract variables
        variable = variable_pool.get(node_data.index_chunk_variable_selector)
        if not variable:
            raise KnowledgeIndexNodeError("Index chunk variable is required.")
        # Debugger invocations get a formatted preview instead of a real index write.
        invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
        if invoke_from:
            is_preview = invoke_from.value == InvokeFrom.DEBUGGER
        else:
            is_preview = False
        chunks = variable.value
        variables = {"chunks": chunks}
        if not chunks:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
            )
        # index knowledge
        try:
            if is_preview:
                outputs = self._get_preview_output(node_data.chunk_structure, chunks)
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs=variables,
                    outputs=outputs,
                )
            results = self._invoke_knowledge_index(
                dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
            )
            return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results)
        except KnowledgeIndexNodeError as e:
            # Use logger.exception so the traceback is preserved in the logs
            # (plain warning previously discarded the exception details).
            logger.exception("Error when running knowledge index node")
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
            )
        # Temporary handle all exceptions from DatasetRetrieval class here.
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
            )

    def _invoke_knowledge_index(
        self,
        dataset: Dataset,
        node_data: KnowledgeIndexNodeData,
        chunks: Mapping[str, Any],
        variable_pool: VariablePool,
    ) -> Any:
        """
        Index `chunks` into `dataset` for the document named in the variable
        pool, cleaning old segments when re-indexing an existing document,
        then mark the document and its segments completed.

        Returns a summary mapping (dataset/document ids and names, batch,
        creation timestamp, display status).
        """
        document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
        if not document_id:
            raise KnowledgeIndexNodeError("Document ID is required.")
        original_document_id = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
        batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
        if not batch:
            raise KnowledgeIndexNodeError("Batch is required.")
        document = db.session.query(Document).filter_by(id=document_id.value).first()
        if not document:
            raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
        # Capture attributes before subsequent session operations.
        doc_id_value = document.id
        ds_id_value = dataset.id
        dataset_name_value = dataset.name
        document_name_value = document.name
        created_at_value = document.created_at
        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
        if original_document_id:
            # Re-indexing: drop the previous document's segments from both
            # the vector index and the database before writing new ones.
            segments = db.session.scalars(
                select(DocumentSegment).where(DocumentSegment.document_id == original_document_id.value)
            ).all()
            if segments:
                index_node_ids = [segment.index_node_id for segment in segments]
                # delete from vector index
                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
                for segment in segments:
                    db.session.delete(segment)
                db.session.commit()
        index_processor.index(dataset, document, chunks)
        indexing_end_at = time.perf_counter()
        document.indexing_latency = indexing_end_at - indexing_start_at
        # update document status
        document.indexing_status = "completed"
        document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
        document.word_count = (
            db.session.query(func.sum(DocumentSegment.word_count))
            .where(
                DocumentSegment.document_id == doc_id_value,
                DocumentSegment.dataset_id == ds_id_value,
            )
            .scalar()
        )
        db.session.add(document)
        # update document segment status
        db.session.query(DocumentSegment).where(
            DocumentSegment.document_id == doc_id_value,
            DocumentSegment.dataset_id == ds_id_value,
        ).update(
            {
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            }
        )
        db.session.commit()
        return {
            "dataset_id": ds_id_value,
            "dataset_name": dataset_name_value,
            "batch": batch.value,
            "document_id": doc_id_value,
            "document_name": document_name_value,
            "created_at": created_at_value.timestamp(),
            "display_status": "completed",
        }

    def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
        """Format the chunks for display without touching the index."""
        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
        return index_processor.format_preview(chunks)

    @classmethod
    def version(cls) -> str:
        return "1"

    def get_streaming_template(self) -> Template:
        """
        Get the template for streaming.
        Returns:
            Template instance for this knowledge index node
        """
        # This node streams nothing itself, hence an empty template.
        return Template(segments=[])

View File

@@ -0,0 +1,3 @@
from .knowledge_retrieval_node import KnowledgeRetrievalNode
__all__ = ["KnowledgeRetrievalNode"]

View File

@@ -0,0 +1,134 @@
from collections.abc import Sequence
from typing import Literal
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.
    """

    # Model provider identifier.
    provider: str
    # Reranking model name under that provider.
    model: str


class VectorSetting(BaseModel):
    """
    Vector Setting.
    """

    # Weight of the semantic (vector) score in weighted-score reranking.
    vector_weight: float
    embedding_provider_name: str
    embedding_model_name: str


class KeywordSetting(BaseModel):
    """
    Keyword Setting.
    """

    # Weight of the keyword (full-text) score in weighted-score reranking.
    keyword_weight: float


class WeightedScoreConfig(BaseModel):
    """
    Weighted score Config.
    """

    vector_setting: VectorSetting
    keyword_setting: KeywordSetting


class MultipleRetrievalConfig(BaseModel):
    """
    Multiple Retrieval Config.
    """

    # Maximum number of chunks to return.
    top_k: int
    # Minimum acceptable score; None means no threshold filtering.
    score_threshold: float | None = None
    # Either "reranking_model" or "weighted_score".
    reranking_mode: str = "reranking_model"
    reranking_enable: bool = True
    reranking_model: RerankingModelConfig | None = None
    weights: WeightedScoreConfig | None = None


class SingleRetrievalConfig(BaseModel):
    """
    Single Retrieval Config.
    """

    # LLM used to route the query to a single dataset.
    model: ModelConfig
# Comparison operators accepted in metadata filtering conditions.
# NOTE: the Unicode operators "≠", "≥", "≤" are part of the node-config
# contract (the frontend sends them verbatim); they were previously garbled
# to empty strings, which also made the Literal contain duplicate values.
SupportedComparisonOperator = Literal[
    # for string or array
    "contains",
    "not contains",
    "start with",
    "end with",
    "is",
    "is not",
    "empty",
    "not empty",
    "in",
    "not in",
    # for number
    "=",
    "≠",
    ">",
    "<",
    "≥",
    "≤",
    # for time
    "before",
    "after",
]
class Condition(BaseModel):
    """
    Condition detail
    """

    # Metadata field name to compare against.
    name: str
    comparison_operator: SupportedComparisonOperator
    # Expected value; None is only meaningful for "empty" / "not empty".
    value: str | Sequence[str] | None | int | float = None


class MetadataFilteringCondition(BaseModel):
    """
    Metadata Filtering Condition.
    """

    # How individual conditions combine; defaults to AND.
    logical_operator: Literal["and", "or"] | None = "and"
    conditions: list[Condition] | None = Field(default=None, deprecated=True)


class KnowledgeRetrievalNodeData(BaseNodeData):
    """
    Knowledge retrieval Node Data.
    """

    type: str = "knowledge-retrieval"
    # Selector of the variable holding the query string.
    query_variable_selector: list[str]
    dataset_ids: list[str]
    retrieval_mode: Literal["single", "multiple"]
    multiple_retrieval_config: MultipleRetrievalConfig | None = None
    single_retrieval_config: SingleRetrievalConfig | None = None
    metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
    # LLM used to derive metadata filters in "automatic" mode.
    metadata_model_config: ModelConfig | None = None
    metadata_filtering_conditions: MetadataFilteringCondition | None = None
    vision: VisionConfig = Field(default_factory=VisionConfig)

    @property
    def structured_output_enabled(self) -> bool:
        # NOTE(QuantumGhost): Temporary workaround for issue #20725
        # (https://github.com/langgenius/dify/issues/20725).
        #
        # The proper fix would be to make `KnowledgeRetrievalNode` inherit
        # from `BaseNode` instead of `LLMNode`.
        return False

View File

@@ -0,0 +1,22 @@
# Exception hierarchy for the knowledge-retrieval node. All node failures
# derive from one base so callers can catch KnowledgeRetrievalNodeError alone.
class KnowledgeRetrievalNodeError(ValueError):
    """Base class for KnowledgeRetrievalNode errors."""


class ModelNotExistError(KnowledgeRetrievalNodeError):
    """Raised when the model does not exist."""


class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError):
    """Raised when the model credentials are not initialized."""


class ModelNotSupportedError(KnowledgeRetrievalNodeError):
    """Raised when the model is not supported."""


class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
    """Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeRetrievalNodeError):
    """Raised when the model is not a Large Language Model."""

View File

@@ -0,0 +1,792 @@
import json
import logging
import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import (
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
ModelNotSupportedError,
ModelQuotaExceededError,
)
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)

# Fallback retrieval configuration used when a dataset does not define its
# own retrieval model: plain semantic search, no reranking, top 4 results.
default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 4,
    "score_threshold_enabled": False,
}
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
    """Workflow node that retrieves relevant chunks from knowledge datasets."""

    node_type = NodeType.KNOWLEDGE_RETRIEVAL

    # Typed node configuration, populated by init_node_data().
    _node_data: KnowledgeRetrievalNodeData

    # Instance attributes specific to LLMNode.
    # Output variable for file
    _file_outputs: list["File"]
    _llm_file_saver: LLMFileSaver

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        *,
        llm_file_saver: LLMFileSaver | None = None,
    ):
        """Initialize the node; `llm_file_saver` may be injected for tests."""
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        # LLM file outputs, used for MultiModal outputs.
        self._file_outputs = []

        # Default saver persists generated files under this tenant/user.
        if llm_file_saver is None:
            llm_file_saver = FileSaverImpl(
                user_id=graph_init_params.user_id,
                tenant_id=graph_init_params.tenant_id,
            )
        self._llm_file_saver = llm_file_saver
    def init_node_data(self, data: Mapping[str, Any]):
        # Validate the raw node config into the typed pydantic model.
        self._node_data = KnowledgeRetrievalNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        # Error-handling strategy configured for this node, if any.
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls):
        # Node implementation version, used for workflow schema compatibility.
        return "1"
    def _run(self) -> NodeRunResult:
        """Execute retrieval: validate the query, enforce the tenant knowledge
        rate limit, then fetch matching chunks from the configured datasets.

        Returns a NodeRunResult whose outputs contain "result" (an array of
        retrieval-resource dicts) on success, or a FAILED result with an error
        message otherwise.
        """
        # extract variables
        variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
        if not isinstance(variable, StringSegment):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs={},
                error="Query variable is not string type.",
            )
        query = variable.value
        variables = {"query": query}
        if not query:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
            )
        # TODO(-LAN-): Move this check outside.
        # check rate limit
        knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
        if knowledge_rate_limit.enabled:
            # Sliding one-minute window via a Redis sorted set: record this
            # request, drop entries older than 60s, then count what remains.
            current_time = int(time.time() * 1000)
            key = f"rate_limit_{self.tenant_id}"
            redis_client.zadd(key, {current_time: current_time})
            redis_client.zremrangebyscore(key, 0, current_time - 60000)
            request_count = redis_client.zcard(key)
            if request_count > knowledge_rate_limit.limit:
                with sessionmaker(db.engine).begin() as session:
                    # add ratelimit record
                    rate_limit_log = RateLimitLog(
                        tenant_id=self.tenant_id,
                        subscription_plan=knowledge_rate_limit.subscription_plan,
                        operation="knowledge",
                    )
                    session.add(rate_limit_log)
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=variables,
                    error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
                    error_type="RateLimitExceeded",
                )

        # retrieve knowledge
        usage = LLMUsage.empty_usage()
        try:
            results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
            outputs = {"result": ArrayObjectSegment(value=results)}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                process_data={"usage": jsonable_encoder(usage)},
                outputs=outputs,  # type: ignore
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                },
                llm_usage=usage,
            )
        except KnowledgeRetrievalNodeError as e:
            logger.warning("Error when running knowledge retrieval node")
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        # Temporary handle all exceptions from DatasetRetrieval class here.
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        finally:
            # Always release the scoped DB session, success or failure.
            db.session.close()
    def _fetch_dataset_retriever(
        self, node_data: KnowledgeRetrievalNodeData, query: str
    ) -> tuple[list[dict[str, Any]], LLMUsage]:
        """Retrieve chunks for ``query`` from the node's datasets.

        Runs either single-route or multiple retrieval depending on config,
        then formats internal ("dify") and external documents into resource
        dicts sorted by score descending.

        Returns:
            (retrieval resource dicts, accumulated LLM usage from metadata
            filtering and query routing).
        """
        usage = LLMUsage.empty_usage()
        available_datasets = []
        dataset_ids = node_data.dataset_ids

        # Subquery: Count the number of available documents for each dataset
        subquery = (
            db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
            .where(
                Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
                Document.dataset_id.in_(dataset_ids),
            )
            .group_by(Document.dataset_id)
            .having(func.count(Document.id) > 0)
            .subquery()
        )
        # Keep datasets that either have indexed documents or are external.
        results = (
            db.session.query(Dataset)
            .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
            .where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
            .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
            .all()
        )
        # avoid blocking at retrieval
        db.session.close()
        for dataset in results:
            # pass if dataset is not available
            if not dataset:
                continue
            available_datasets.append(dataset)

        # Optional metadata-based pre-filtering (may itself invoke an LLM).
        metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
            [dataset.id for dataset in available_datasets], query, node_data
        )
        usage = self._merge_usage(usage, metadata_usage)
        all_documents = []
        dataset_retrieval = DatasetRetrieval()
        if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            # fetch model config
            if node_data.single_retrieval_config is None:
                raise ValueError("single_retrieval_config is required")
            model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
            # check model is support tool calling
            model_type_instance = model_config.provider_model_bundle.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)
            # get model schema
            model_schema = model_type_instance.get_model_schema(
                model=model_config.model, credentials=model_config.credentials
            )
            if model_schema:
                # Prefer function-call routing when the model supports tools;
                # otherwise fall back to ReAct-style routing.
                planning_strategy = PlanningStrategy.REACT_ROUTER
                features = model_schema.features
                if features:
                    if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
                        planning_strategy = PlanningStrategy.ROUTER
                all_documents = dataset_retrieval.single_retrieve(
                    available_datasets=available_datasets,
                    tenant_id=self.tenant_id,
                    user_id=self.user_id,
                    app_id=self.app_id,
                    user_from=self.user_from.value,
                    query=query,
                    model_config=model_config,
                    model_instance=model_instance,
                    planning_strategy=planning_strategy,
                    metadata_filter_document_ids=metadata_filter_document_ids,
                    metadata_condition=metadata_condition,
                )
        elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            if node_data.multiple_retrieval_config is None:
                raise ValueError("multiple_retrieval_config is required")
            # Resolve reranking settings: a dedicated rerank model or a
            # weighted combination of vector and keyword scores.
            if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
                if node_data.multiple_retrieval_config.reranking_model:
                    reranking_model = {
                        "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
                        "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
                    }
                else:
                    reranking_model = None
                weights = None
            elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
                if node_data.multiple_retrieval_config.weights is None:
                    raise ValueError("weights is required")
                reranking_model = None
                vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
                weights = {
                    "vector_setting": {
                        "vector_weight": vector_setting.vector_weight,
                        "embedding_provider_name": vector_setting.embedding_provider_name,
                        "embedding_model_name": vector_setting.embedding_model_name,
                    },
                    "keyword_setting": {
                        "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
                    },
                }
            else:
                reranking_model = None
                weights = None
            all_documents = dataset_retrieval.multiple_retrieve(
                app_id=self.app_id,
                tenant_id=self.tenant_id,
                user_id=self.user_id,
                user_from=self.user_from.value,
                available_datasets=available_datasets,
                query=query,
                top_k=node_data.multiple_retrieval_config.top_k,
                score_threshold=node_data.multiple_retrieval_config.score_threshold
                if node_data.multiple_retrieval_config.score_threshold is not None
                else 0.0,
                reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
                reranking_model=reranking_model,
                weights=weights,
                reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
                metadata_filter_document_ids=metadata_filter_document_ids,
                metadata_condition=metadata_condition,
            )
        usage = self._merge_usage(usage, dataset_retrieval.llm_usage)

        # Split results by provider: internal ("dify") vs external knowledge.
        dify_documents = [item for item in all_documents if item.provider == "dify"]
        external_documents = [item for item in all_documents if item.provider == "external"]
        retrieval_resource_list = []
        # deal with external documents
        for item in external_documents:
            source = {
                "metadata": {
                    "_source": "knowledge",
                    "dataset_id": item.metadata.get("dataset_id"),
                    "dataset_name": item.metadata.get("dataset_name"),
                    "document_id": item.metadata.get("document_id") or item.metadata.get("title"),
                    "document_name": item.metadata.get("title"),
                    "data_source_type": "external",
                    "retriever_from": "workflow",
                    "score": item.metadata.get("score"),
                    "doc_metadata": item.metadata,
                },
                "title": item.metadata.get("title"),
                "content": item.page_content,
            }
            retrieval_resource_list.append(source)
        # deal with dify documents
        if dify_documents:
            records = RetrievalService.format_retrieval_documents(dify_documents)
            if records:
                for record in records:
                    segment = record.segment
                    dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()  # type: ignore
                    stmt = select(Document).where(
                        Document.id == segment.document_id,
                        Document.enabled == True,
                        Document.archived == False,
                    )
                    document = db.session.scalar(stmt)
                    # Skip hits whose dataset/document has been disabled or
                    # removed since indexing.
                    if dataset and document:
                        source = {
                            "metadata": {
                                "_source": "knowledge",
                                "dataset_id": dataset.id,
                                "dataset_name": dataset.name,
                                "document_id": document.id,
                                "document_name": document.name,
                                "data_source_type": document.data_source_type,
                                "segment_id": segment.id,
                                "retriever_from": "workflow",
                                "score": record.score or 0.0,
                                "child_chunks": [
                                    {
                                        "id": str(getattr(chunk, "id", "")),
                                        "content": str(getattr(chunk, "content", "")),
                                        "position": int(getattr(chunk, "position", 0)),
                                        "score": float(getattr(chunk, "score", 0.0)),
                                    }
                                    for chunk in (record.child_chunks or [])
                                ],
                                "segment_hit_count": segment.hit_count,
                                "segment_word_count": segment.word_count,
                                "segment_position": segment.position,
                                "segment_index_node_hash": segment.index_node_hash,
                                "doc_metadata": document.doc_metadata,
                            },
                            "title": document.name,
                        }
                        # Q&A segments render as a question/answer pair.
                        if segment.answer:
                            source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
                        else:
                            source["content"] = segment.get_sign_content()
                        retrieval_resource_list.append(source)
        if retrieval_resource_list:
            # Sort by score (missing scores count as 0.0) and assign 1-based
            # display positions.
            retrieval_resource_list = sorted(
                retrieval_resource_list,
                key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
                reverse=True,
            )
            for position, item in enumerate(retrieval_resource_list, start=1):
                item["metadata"]["position"] = position
        return retrieval_resource_list, usage
    def _get_metadata_filter_condition(
        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
    ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
        """Resolve metadata filtering into a document-id allowlist.

        Depending on node config, conditions come from an LLM ("automatic"),
        from user-defined conditions ("manual"), or not at all ("disabled").

        Returns:
            (dataset_id -> matching document ids, or None when no documents
            matched / filtering disabled; the MetadataCondition used, or None;
            accumulated LLM usage).
        """
        usage = LLMUsage.empty_usage()
        document_query = db.session.query(Document).where(
            Document.dataset_id.in_(dataset_ids),
            Document.indexing_status == "completed",
            Document.enabled == True,
            Document.archived == False,
        )
        filters: list[Any] = []
        metadata_condition = None
        if node_data.metadata_filtering_mode == "disabled":
            return None, None, usage
        elif node_data.metadata_filtering_mode == "automatic":
            automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
                dataset_ids, query, node_data
            )
            usage = self._merge_usage(usage, automatic_usage)
            if automatic_metadata_filters:
                conditions = []
                # NOTE: `filter` shadows the builtin here; it holds one
                # LLM-extracted filter dict per iteration.
                for sequence, filter in enumerate(automatic_metadata_filters):
                    self._process_metadata_filter_func(
                        sequence,
                        filter.get("condition", ""),
                        filter.get("metadata_name", ""),
                        filter.get("value"),
                        filters,
                    )
                    conditions.append(
                        Condition(
                            name=filter.get("metadata_name"),  # type: ignore
                            comparison_operator=filter.get("condition"),  # type: ignore
                            value=filter.get("value"),
                        )
                    )
                metadata_condition = MetadataCondition(
                    logical_operator=node_data.metadata_filtering_conditions.logical_operator
                    if node_data.metadata_filtering_conditions
                    else "or",
                    conditions=conditions,
                )
        elif node_data.metadata_filtering_mode == "manual":
            if node_data.metadata_filtering_conditions:
                conditions = []
                for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions):  # type: ignore
                    metadata_name = condition.name
                    expected_value = condition.value
                    if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
                        if isinstance(expected_value, str):
                            # Resolve workflow-variable templates, then coerce
                            # to a plain number or whitespace-normalized string.
                            expected_value = self.graph_runtime_state.variable_pool.convert_template(
                                expected_value
                            ).value[0]
                            if expected_value.value_type in {"number", "integer", "float"}:
                                expected_value = expected_value.value
                            elif expected_value.value_type == "string":
                                expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
                            else:
                                raise ValueError("Invalid expected metadata value type")
                    conditions.append(
                        Condition(
                            name=metadata_name,
                            comparison_operator=condition.comparison_operator,
                            value=expected_value,
                        )
                    )
                    filters = self._process_metadata_filter_func(
                        sequence,
                        condition.comparison_operator,
                        metadata_name,
                        expected_value,
                        filters,
                    )
                metadata_condition = MetadataCondition(
                    logical_operator=node_data.metadata_filtering_conditions.logical_operator,
                    conditions=conditions,
                )
        else:
            raise ValueError("Invalid metadata filtering mode")
        if filters:
            # Combine individual clauses per the configured logical operator.
            if (
                node_data.metadata_filtering_conditions
                and node_data.metadata_filtering_conditions.logical_operator == "and"
            ):
                document_query = document_query.where(and_(*filters))
            else:
                document_query = document_query.where(or_(*filters))
        documents = document_query.all()
        # group by dataset_id
        metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
        for document in documents:
            metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
        return metadata_filter_document_ids, metadata_condition, usage
    def _automatic_metadata_filter_func(
        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
    ) -> tuple[list[dict[str, Any]], LLMUsage]:
        """Ask the configured metadata LLM to extract filters from the query.

        Returns:
            (list of {"metadata_name", "value", "condition"} dicts limited to
            fields that actually exist on the datasets, LLM usage). Any
            failure during invocation or parsing yields an empty list
            (best-effort by design).
        """
        usage = LLMUsage.empty_usage()
        # get all metadata field
        stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
        metadata_fields = db.session.scalars(stmt).all()
        all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
        if node_data.metadata_model_config is None:
            raise ValueError("metadata_model_config is required")
        # get metadata model instance and fetch model config
        model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
        # fetch prompt messages
        prompt_template = self._get_prompt_template(
            node_data=node_data,
            metadata_fields=all_metadata_fields,
            query=query or "",
        )
        prompt_messages, stop = LLMNode.fetch_prompt_messages(
            prompt_template=prompt_template,
            sys_query=query,
            memory=None,
            model_config=model_config,
            sys_files=[],
            vision_enabled=node_data.vision.enabled,
            vision_detail=node_data.vision.configs.detail,
            variable_pool=self.graph_runtime_state.variable_pool,
            jinja2_variables=[],
            tenant_id=self.tenant_id,
        )
        result_text = ""
        try:
            # handle invoke result
            generator = LLMNode.invoke_llm(
                node_data_model=node_data.metadata_model_config,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                user_id=self.user_id,
                structured_output_enabled=self._node_data.structured_output_enabled,
                structured_output=None,
                file_saver=self._llm_file_saver,
                file_outputs=self._file_outputs,
                node_id=self._node_id,
                node_type=self.node_type,
            )
            # Only the completion event matters; intermediate chunks are skipped.
            for event in generator:
                if isinstance(event, ModelInvokeCompletedEvent):
                    result_text = event.text
                    usage = self._merge_usage(usage, event.usage)
                    break

            result_text_json = parse_and_check_json_markdown(result_text, [])
            automatic_metadata_filters = []
            if "metadata_map" in result_text_json:
                metadata_map = result_text_json["metadata_map"]
                for item in metadata_map:
                    # Drop hallucinated fields the datasets don't define.
                    if item.get("metadata_field_name") in all_metadata_fields:
                        automatic_metadata_filters.append(
                            {
                                "metadata_name": item.get("metadata_field_name"),
                                "value": item.get("metadata_field_value"),
                                "condition": item.get("comparison_operator"),
                            }
                        )
        except Exception:
            # Deliberate best-effort: a failed/unparsable LLM response simply
            # disables automatic filtering rather than failing retrieval.
            return [], usage
        return automatic_metadata_filters, usage
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
) -> list[Any]:
if value is None and condition not in ("empty", "not empty"):
return filters
json_field = Document.doc_metadata[metadata_name].as_string()
match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
case "start with":
filters.append(json_field.like(f"{value}%"))
case "end with":
filters.append(json_field.like(f"%{value}"))
case "in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(False))
else:
filters.append(json_field.in_(value_list))
case "not in":
if isinstance(value, str):
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(True))
else:
filters.append(json_field.notin_(value_list))
case "is" | "=":
if isinstance(value, str):
filters.append(json_field == value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
case "is not" | "":
if isinstance(value, str):
filters.append(json_field != value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
case "after" | ">":
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
case "" | "<=":
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=":
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
case _:
pass
return filters
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
variable_mapping = {}
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
return variable_mapping
    def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """Resolve ``model`` into a usable instance plus credentialed config.

        Raises:
            ModelNotExistError: model or its schema is not registered.
            ModelCredentialsNotInitializedError: provider credentials missing.
            ModelNotSupportedError: model not permitted for this tenant.
            ModelQuotaExceededError: provider quota exhausted.
        """
        model_name = model.name
        provider_name = model.provider

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
        )

        provider_model_bundle = model_instance.provider_model_bundle
        model_type_instance = model_instance.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_credentials = model_instance.credentials

        # check model
        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=model_name, model_type=ModelType.LLM
        )

        if provider_model is None:
            raise ModelNotExistError(f"Model {model_name} not exist.")

        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # model config
        completion_params = model.completion_params
        stop = []
        # "stop" is passed separately to the runtime, not as a sampling param.
        if "stop" in completion_params:
            stop = completion_params["stop"]
            del completion_params["stop"]

        # get model mode
        model_mode = model.mode
        if not model_mode:
            raise ModelNotExistError("LLM mode is required.")

        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)

        if not model_schema:
            raise ModelNotExistError(f"Model {model_name} not exist.")

        return model_instance, ModelConfigWithCredentialsEntity(
            provider=provider_name,
            model=model_name,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )
    def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
        """Build the metadata-extraction prompt for the configured model.

        Chat models get a few-shot message list; completion models get a
        single prompt template. Raises InvalidModelTypeError otherwise.
        """
        model_mode = ModelMode(node_data.metadata_model_config.mode)  # type: ignore
        input_text = query

        prompt_messages: list[LLMNodeChatModelMessage] = []
        if model_mode == ModelMode.CHAT:
            system_prompt_messages = LLMNodeChatModelMessage(
                role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
            )
            prompt_messages.append(system_prompt_messages)
            # Two canned user/assistant exchanges serve as few-shot examples.
            user_prompt_message_1 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
            )
            prompt_messages.append(user_prompt_message_1)
            assistant_prompt_message_1 = LLMNodeChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
            )
            prompt_messages.append(assistant_prompt_message_1)
            user_prompt_message_2 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
            )
            prompt_messages.append(user_prompt_message_2)
            assistant_prompt_message_2 = LLMNodeChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
            )
            prompt_messages.append(assistant_prompt_message_2)
            # The real query and available fields go into the final message.
            user_prompt_message_3 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER,
                text=METADATA_FILTER_USER_PROMPT_3.format(
                    input_text=input_text,
                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
                ),
            )
            prompt_messages.append(user_prompt_message_3)
            return prompt_messages
        elif model_mode == ModelMode.COMPLETION:
            return LLMNodeCompletionModelPromptTemplate(
                text=METADATA_FILTER_COMPLETION_PROMPT.format(
                    input_text=input_text,
                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
                )
            )
        else:
            raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

View File

@@ -0,0 +1,66 @@
# Few-shot prompt templates used by the knowledge-retrieval node to have an
# LLM extract metadata filter conditions from the user's query.
# BUGFIXES vs. the garbled version: restored the Unicode operators
# "≠", "≥", "≤" in the system prompt's operator list, restored lost
# apostrophes ("company's"), and rebalanced a single "}" in the completion
# template that made str.format() raise ValueError.

METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
"""  # noqa: E501

METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which company's email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""

METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
    {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""

METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""

METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
    {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
    {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""

METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""

METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which company's email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
"""  # noqa: E501

View File

@@ -0,0 +1,3 @@
from .node import ListOperatorNode
__all__ = ["ListOperatorNode"]

View File

@@ -0,0 +1,69 @@
from collections.abc import Sequence
from enum import StrEnum
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
class FilterOperator(StrEnum):
# string conditions
CONTAINS = "contains"
START_WITH = "start with"
END_WITH = "end with"
IS = "is"
IN = "in"
EMPTY = "empty"
NOT_CONTAINS = "not contains"
IS_NOT = "is not"
NOT_IN = "not in"
NOT_EMPTY = "not empty"
# number conditions
EQUAL = "="
NOT_EQUAL = ""
LESS_THAN = "<"
GREATER_THAN = ">"
GREATER_THAN_OR_EQUAL = ""
LESS_THAN_OR_EQUAL = ""
class Order(StrEnum):
    """Sort direction for the list-operator order-by step."""

    ASC = "asc"
    DESC = "desc"


class FilterCondition(BaseModel):
    """One filter clause: ``item[key] <comparison_operator> value``."""

    # Attribute name of the item being compared (only meaningful for files).
    key: str = ""
    comparison_operator: FilterOperator = FilterOperator.CONTAINS
    # the value is bool if the filter operator is comparing with
    # a boolean constant.
    value: str | Sequence[str] | bool = ""


class FilterBy(BaseModel):
    """Filter step config; conditions are applied in order when enabled."""

    enabled: bool = False
    conditions: Sequence[FilterCondition] = Field(default_factory=list)


class OrderByConfig(BaseModel):
    """Order-by step config; ``key`` names the file attribute to sort on."""

    enabled: bool = False
    key: str = ""
    value: Order = Order.ASC


class Limit(BaseModel):
    """Slice step config; keeps the first ``size`` items when enabled.

    NOTE(review): ``size`` defaults to -1 and the node slices ``[:size]``, so
    an enabled limit with the default would drop the last element —
    presumably the UI always sets a positive size; confirm.
    """

    enabled: bool = False
    size: int = -1


class ExtractConfig(BaseModel):
    """Extract step config; ``serial`` is a 1-based index template string."""

    enabled: bool = False
    serial: str = "1"


class ListOperatorNodeData(BaseNodeData):
    """Configuration payload for the list-operator workflow node."""

    # Variable selector pointing at the array variable to operate on.
    variable: Sequence[str] = Field(default_factory=list)
    filter_by: FilterBy
    order_by: OrderByConfig
    limit: Limit
    extract_by: ExtractConfig = Field(default_factory=ExtractConfig)

View File

@@ -0,0 +1,16 @@
class ListOperatorError(ValueError):
    """Base class for all ListOperator errors."""

    pass


class InvalidFilterValueError(ListOperatorError):
    """Raised when a condition value has the wrong type for the array."""

    pass


class InvalidKeyError(ListOperatorError):
    """Raised for an unknown file attribute key or out-of-range index."""

    pass


class InvalidConditionError(ListOperatorError):
    """Raised when a comparison operator is not supported for the type."""

    pass

View File

@@ -0,0 +1,370 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
# Array segment types the list-operator node can process (isinstance checks).
_SUPPORTED_TYPES_TUPLE = (
    ArrayFileSegment,
    ArrayNumberSegment,
    ArrayStringSegment,
    ArrayBooleanSegment,
)
# Union alias mirroring _SUPPORTED_TYPES_TUPLE, for type annotations.
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
_T = TypeVar("_T")
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
"""Returns the negation of a given filter function. If the original filter
returns `True` for a value, the negated filter will return `False`, and vice versa.
"""
def wrapper(value: _T) -> bool:
return not filter_(value)
return wrapper
class ListOperatorNode(Node):
    """Workflow node that filters, extracts, orders and slices an array variable.

    Pipeline order in :meth:`_run` is fixed: filter → extract → order → limit.
    """

    node_type = NodeType.LIST_OPERATOR

    # Parsed node configuration; populated by init_node_data().
    _node_data: ListOperatorNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = ListOperatorNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        """Resolve the configured variable and run the enabled list operations."""
        inputs: dict[str, Sequence[object]] = {}
        process_data: dict[str, Sequence[object]] = {}
        outputs: dict[str, Any] = {}
        variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
        if variable is None:
            error_message = f"Variable not found for selector: {self._node_data.variable}"
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
            )
        if not variable.value:
            # Empty input short-circuits with an empty result, keeping the
            # input's segment type when it is an array segment.
            inputs = {"variable": []}
            process_data = {"variable": []}
            if isinstance(variable, ArraySegment):
                result = variable.model_copy(update={"value": []})
            else:
                result = ArrayAnySegment(value=[])
            outputs = {"result": result, "first_record": None, "last_record": None}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
            )
        if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
            error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
            )
        if isinstance(variable, ArrayFileSegment):
            # Files are serialized to dicts for inputs/process_data logging.
            inputs = {"variable": [item.to_dict() for item in variable.value]}
            process_data["variable"] = [item.to_dict() for item in variable.value]
        else:
            inputs = {"variable": variable.value}
            process_data["variable"] = variable.value
        try:
            # Filter
            if self._node_data.filter_by.enabled:
                variable = self._apply_filter(variable)
            # Extract
            if self._node_data.extract_by.enabled:
                variable = self._extract_slice(variable)
            # Order
            if self._node_data.order_by.enabled:
                variable = self._apply_order(variable)
            # Slice
            if self._node_data.limit.enabled:
                variable = self._apply_slice(variable)
            outputs = {
                "result": variable,
                "first_record": variable.value[0] if variable.value else None,
                "last_record": variable.value[-1] if variable.value else None,
            }
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
            )
        except ListOperatorError as e:
            # NOTE(review): several helpers below raise plain ValueError
            # instead of a ListOperatorError subclass; those escape this
            # handler and crash the node instead of failing it gracefully —
            # confirm whether that is intended.
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
            )

    def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Apply every configured filter condition, narrowing the array each pass."""
        filter_func: Callable[[Any], bool]
        result: list[Any] = []
        for condition in self._node_data.filter_by.conditions:
            if isinstance(variable, ArrayStringSegment):
                if not isinstance(condition.value, str):
                    raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
                # Condition values may contain variable templates; resolve first.
                value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
                result = list(filter(filter_func, variable.value))
                variable = variable.model_copy(update={"value": result})
            elif isinstance(variable, ArrayNumberSegment):
                if not isinstance(condition.value, str):
                    raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
                value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
                result = list(filter(filter_func, variable.value))
                variable = variable.model_copy(update={"value": result})
            elif isinstance(variable, ArrayFileSegment):
                if isinstance(condition.value, str):
                    value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
                elif isinstance(condition.value, bool):
                    raise ValueError(f"File filter expects a string value, got {type(condition.value)}")
                else:
                    value = condition.value
                filter_func = _get_file_filter_func(
                    key=condition.key,
                    condition=condition.comparison_operator,
                    value=value,
                )
                result = list(filter(filter_func, variable.value))
                variable = variable.model_copy(update={"value": result})
            else:
                # Remaining supported type: ArrayBooleanSegment.
                if not isinstance(condition.value, bool):
                    raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}")
                filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
                result = list(filter(filter_func, variable.value))
                variable = variable.model_copy(update={"value": result})
        return variable

    def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Sort scalars natively; sort files by the configured attribute key."""
        if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
            result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC)
            variable = variable.model_copy(update={"value": result})
        else:
            result = _order_file(
                order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
            )
            variable = variable.model_copy(update={"value": result})
        return variable

    def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Keep only the first ``limit.size`` items."""
        result = variable.value[: self._node_data.limit.size]
        return variable.model_copy(update={"value": result})

    def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
        """Extract a single element by 1-based index, returning it as a
        one-element array of the same segment type."""
        value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
        if value < 1:
            # NOTE(review): plain ValueError here is not caught by _run's
            # ListOperatorError handler (unlike the InvalidKeyError below).
            raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
        if value > len(variable.value):
            raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}")
        value -= 1
        result = variable.value[value]
        return variable.model_copy(update={"value": [result]})
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
    """Return an accessor mapping a File to the numeric attribute *key*.

    Only "size" is supported; any other key raises InvalidKeyError.
    """
    if key == "size":
        return lambda f: f.size
    raise InvalidKeyError(f"Invalid key: {key}")
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
    """Return an accessor mapping a File to the string attribute *key*.

    Optional attributes are normalized to "" when missing.
    """
    accessors: dict[str, Callable[[File], str]] = {
        "name": lambda f: f.filename or "",
        "type": lambda f: f.type,
        "extension": lambda f: f.extension or "",
        "mime_type": lambda f: f.mime_type or "",
        "transfer_method": lambda f: f.transfer_method,
        "url": lambda f: f.remote_url or "",
        "related_id": lambda f: f.related_id or "",
    }
    accessor = accessors.get(key)
    if accessor is None:
        raise InvalidKeyError(f"Invalid key: {key}")
    return accessor
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
    """Build a string predicate for *condition*, bound to *value*.

    Raises InvalidConditionError for operators not applicable to strings.
    """
    builders: dict[str, Callable[[], Callable[[str], bool]]] = {
        "contains": lambda: _contains(value),
        "start with": lambda: _startswith(value),
        "end with": lambda: _endswith(value),
        "is": lambda: _is(value),
        "in": lambda: _in(value),
        "empty": lambda: (lambda x: x == ""),
        "not contains": lambda: _negation(_contains(value)),
        "is not": lambda: _negation(_is(value)),
        "not in": lambda: _negation(_in(value)),
        "not empty": lambda: (lambda x: x != ""),
    }
    builder = builders.get(condition)
    if builder is None:
        raise InvalidConditionError(f"Invalid condition: {condition}")
    return builder()
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
    """Build a membership predicate; only "in" / "not in" apply to sequences."""
    if condition == "in":
        return _in(value)
    if condition == "not in":
        return _negation(_in(value))
    raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
match condition:
case "=":
return _eq(value)
case "":
return _ne(value)
case "<":
return _lt(value)
case "":
return _le(value)
case ">":
return _gt(value)
case "":
return _ge(value)
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
    """Build a boolean predicate; only IS / IS_NOT are supported."""
    if condition == FilterOperator.IS:
        return _is(value)
    if condition == FilterOperator.IS_NOT:
        return _negation(_is(value))
    raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
    """Compose a File predicate: extract attribute *key*, compare with *value*.

    The comparison builder is invoked lazily on each predicate call (mirroring
    the original behavior, including when invalid conditions raise).
    """
    if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
        extract = _get_file_extract_string_func(key=key)
        return lambda f: _get_string_filter_func(condition=condition, value=value)(extract(f))
    if key in {"type", "transfer_method"}:
        extract = _get_file_extract_string_func(key=key)
        return lambda f: _get_sequence_filter_func(condition=condition, value=value)(extract(f))
    if key == "size" and isinstance(value, str):
        extract = _get_file_extract_number_func(key=key)
        return lambda f: _get_number_filter_func(condition=condition, value=float(value))(extract(f))
    raise InvalidKeyError(f"Invalid key: {key}")
def _contains(value: str) -> Callable[[str], bool]:
    """Predicate factory: candidate contains *value*."""

    def check(x: str) -> bool:
        return value in x

    return check


def _startswith(value: str) -> Callable[[str], bool]:
    """Predicate factory: candidate starts with *value*."""

    def check(x: str) -> bool:
        return x.startswith(value)

    return check


def _endswith(value: str) -> Callable[[str], bool]:
    """Predicate factory: candidate ends with *value*."""

    def check(x: str) -> bool:
        return x.endswith(value)

    return check


def _is(value: _T) -> Callable[[_T], bool]:
    """Predicate factory: candidate equals *value*."""

    def check(x: _T) -> bool:
        return x == value

    return check


def _in(value: str | Sequence[str]) -> Callable[[str], bool]:
    """Predicate factory: candidate is a member (or substring) of *value*."""

    def check(x: str) -> bool:
        return x in value

    return check


def _eq(value: int | float) -> Callable[[int | float], bool]:
    """Predicate factory: candidate == *value*."""

    def check(x: int | float) -> bool:
        return x == value

    return check


def _ne(value: int | float) -> Callable[[int | float], bool]:
    """Predicate factory: candidate != *value*."""

    def check(x: int | float) -> bool:
        return x != value

    return check


def _lt(value: int | float) -> Callable[[int | float], bool]:
    """Predicate factory: candidate < *value*."""

    def check(x: int | float) -> bool:
        return x < value

    return check


def _le(value: int | float) -> Callable[[int | float], bool]:
    """Predicate factory: candidate <= *value*."""

    def check(x: int | float) -> bool:
        return x <= value

    return check


def _gt(value: int | float) -> Callable[[int | float], bool]:
    """Predicate factory: candidate > *value*."""

    def check(x: int | float) -> bool:
        return x > value

    return check


def _ge(value: int | float) -> Callable[[int | float], bool]:
    """Predicate factory: candidate >= *value*."""

    def check(x: int | float) -> bool:
        return x >= value

    return check
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
    """Sort *array* of files by the attribute named *order_by*.

    String attributes sort lexically, "size" numerically; DESC reverses.
    """
    descending = order == Order.DESC
    if order_by == "size":
        key_func: Callable[[File], Any] = _get_file_extract_number_func(key=order_by)
    elif order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}:
        key_func = _get_file_extract_string_func(key=order_by)
    else:
        raise InvalidKeyError(f"Invalid order key: {order_by}")
    return sorted(array, key=key_func, reverse=descending)

View File

@@ -0,0 +1,17 @@
from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from .node import LLMNode
__all__ = [
"LLMNode",
"LLMNodeChatModelMessage",
"LLMNodeCompletionModelPromptTemplate",
"LLMNodeData",
"ModelConfig",
"VisionConfig",
]

View File

@@ -0,0 +1,98 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
class ModelConfig(BaseModel):
    """Provider/model selection plus completion parameters for the LLM node."""

    provider: str
    name: str
    mode: LLMMode
    completion_params: dict[str, Any] = Field(default_factory=dict)


class ContextConfig(BaseModel):
    """Optional retrieval-context injection config."""

    enabled: bool
    # Selector of the variable holding the context, when enabled.
    variable_selector: list[str] | None = None


class VisionConfigOptions(BaseModel):
    """Where to read image inputs from, and at what detail level."""

    variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
    detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH


class VisionConfig(BaseModel):
    """Vision (multimodal input) switch for the LLM node."""

    enabled: bool = False
    configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)

    @field_validator("configs", mode="before")
    @classmethod
    def convert_none_configs(cls, v: Any):
        # Normalize an explicit null in stored data to default options.
        if v is None:
            return VisionConfigOptions()
        return v


class PromptConfig(BaseModel):
    """Extra prompt options: variables available to Jinja2 templates."""

    jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)

    @field_validator("jinja2_variables", mode="before")
    @classmethod
    def convert_none_jinja2_variables(cls, v: Any):
        # Normalize an explicit null in stored data to an empty list.
        if v is None:
            return []
        return v


class LLMNodeChatModelMessage(ChatModelMessage):
    """Chat message template with an optional Jinja2 variant of the text."""

    text: str = ""
    jinja2_text: str | None = None


class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
    """Completion prompt template with an optional Jinja2 variant."""

    jinja2_text: str | None = None
class LLMNodeData(BaseNodeData):
    """Full configuration payload for the LLM workflow node."""

    model: ModelConfig
    # Chat models take a sequence of messages; completion models a single template.
    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
    prompt_config: PromptConfig = Field(default_factory=PromptConfig)
    memory: MemoryConfig | None = None
    context: ContextConfig
    vision: VisionConfig = Field(default_factory=VisionConfig)
    # JSON schema for structured output, when enabled.
    structured_output: Mapping[str, Any] | None = None
    # We used 'structured_output_enabled' in the past, but it's not a good name.
    structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
    reasoning_format: Literal["separated", "tagged"] = Field(
        # Keep tagged as default for backward compatibility
        default="tagged",
        description=(
            """
Strategy for handling model reasoning output.
separated: Return clean text (without <think> tags) + reasoning_content field.
Recommended for new workflows. Enables safe downstream parsing and
workflow variable access: {{#node_id.reasoning_content#}}
tagged : Return original text (with <think> tags) + reasoning_content field.
Maintains full backward compatibility while still providing reasoning_content
for workflow automation. Frontend thinking panels work as before.
"""
        ),
    )

    @field_validator("prompt_config", mode="before")
    @classmethod
    def convert_none_prompt_config(cls, v: Any):
        # Normalize an explicit null in stored data to defaults.
        if v is None:
            return PromptConfig()
        return v

    @property
    def structured_output_enabled(self) -> bool:
        # Enabled only when the switch is on AND a schema is actually present.
        return self.structured_output_switch_on and self.structured_output is not None

View File

@@ -0,0 +1,45 @@
class LLMNodeError(ValueError):
    """Base class for LLM Node errors."""


class VariableNotFoundError(LLMNodeError):
    """Raised when a required variable is not found."""


class InvalidContextStructureError(LLMNodeError):
    """Raised when the context structure is invalid."""


class InvalidVariableTypeError(LLMNodeError):
    """Raised when the variable type is invalid."""


class ModelNotExistError(LLMNodeError):
    """Raised when the specified model does not exist."""


class LLMModeRequiredError(LLMNodeError):
    """Raised when LLM mode is required but not provided."""


class NoPromptFoundError(LLMNodeError):
    """Raised when no prompt is found in the LLM configuration."""


class TemplateTypeNotSupportError(LLMNodeError):
    """Raised for unsupported prompt template types."""

    def __init__(self, *, type_name: str):
        super().__init__(f"Prompt type {type_name} is not supported.")


class MemoryRolePrefixRequiredError(LLMNodeError):
    """Raised when memory role prefix is required for completion model."""


class FileTypeNotSupportError(LLMNodeError):
    """Raised when a file type is not accepted by the selected model."""

    def __init__(self, *, type_name: str):
        super().__init__(f"{type_name} type is not supported by this model")


class UnsupportedPromptContentTypeError(LLMNodeError):
    """Raised for prompt content types the node cannot render."""

    def __init__(self, *, type_name: str):
        super().__init__(f"Prompt content type {type_name} is not supported.")

View File

@@ -0,0 +1,157 @@
import mimetypes
import typing as tp
from sqlalchemy import Engine
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.file import File, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db as global_db
class LLMFileSaver(tp.Protocol):
    """LLMFileSaver is responsible for save multimodal output returned by
    LLM.
    """

    def save_binary_string(
        self,
        data: bytes,
        mime_type: str,
        file_type: FileType,
        extension_override: str | None = None,
    ) -> File:
        """save_binary_string saves the inline file data returned by LLM.

        Currently (2025-04-30), only some of Google Gemini models will return
        multimodal output as inline data.

        :param data: the contents of the file
        :param mime_type: the media type of the file, specified by rfc6838
            (https://datatracker.ietf.org/doc/html/rfc6838)
        :param file_type: The file type of the inline file.
        :param extension_override: Override the auto-detected file extension while saving this file.

            The default value is `None`, which means do not override the file extension and guessing it
            from the `mime_type` attribute while saving the file.

            Setting it to values other than `None` means override the file's extension, and
            will bypass the extension guessing saving the file.

            Specially, setting it to empty string (`""`) will leave the file extension empty.

            When it is not `None` or empty string (`""`), it should be a string beginning with a
            dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
            and `tar.gz` are not.
        """
        raise NotImplementedError()

    def save_remote_url(self, url: str, file_type: FileType) -> File:
        """save_remote_url saves the file from a remote url returned by LLM.

        Currently (2025-04-30), no model returns multimodel output as a url.

        :param url: the url of the file.
        :param file_type: the file type of the file, check `FileType` enum for reference.
        """
        raise NotImplementedError()


# Factory producing a SQLAlchemy Engine on demand (lets tests inject one).
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]


class FileSaverImpl(LLMFileSaver):
    """Default LLMFileSaver backed by ToolFileManager and the app database."""

    _engine_factory: EngineFactory
    _tenant_id: str
    _user_id: str

    def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
        if engine_factory is None:
            # Fall back to the process-global engine when none is injected.
            def _factory():
                return global_db.engine

            engine_factory = _factory
        self._engine_factory = engine_factory
        self._user_id = user_id
        self._tenant_id = tenant_id

    def _get_tool_file_manager(self):
        return ToolFileManager(engine=self._engine_factory())

    def save_remote_url(self, url: str, file_type: FileType) -> File:
        """Download *url* (through the SSRF proxy) and persist it as a tool file."""
        http_response = ssrf_proxy.get(url)
        http_response.raise_for_status()
        data = http_response.content
        mime_type_from_header = http_response.headers.get("Content-Type")
        mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
        return self.save_binary_string(data, mime_type, file_type, extension_override=extension)

    def save_binary_string(
        self,
        data: bytes,
        mime_type: str,
        file_type: FileType,
        extension_override: str | None = None,
    ) -> File:
        """Persist raw bytes via ToolFileManager and return a signed File record."""
        tool_file_manager = self._get_tool_file_manager()
        tool_file = tool_file_manager.create_file_by_raw(
            user_id=self._user_id,
            tenant_id=self._tenant_id,
            # TODO(QuantumGhost): what is conversation id?
            conversation_id=None,
            file_binary=data,
            mimetype=mime_type,
        )
        extension_override = _validate_extension_override(extension_override)
        extension = _get_extension(mime_type, extension_override)
        url = sign_tool_file(tool_file.id, extension)
        return File(
            tenant_id=self._tenant_id,
            type=file_type,
            transfer_method=FileTransferMethod.TOOL_FILE,
            filename=tool_file.name,
            extension=extension,
            mime_type=mime_type,
            size=len(data),
            related_id=tool_file.id,
            url=url,
            storage_key=tool_file.file_key,
        )
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
"""get_extension return the extension of file.
If the `extension_override` parameter is set, this function should honor it and
return its value.
"""
if extension_override is not None:
return extension_override
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
"""_extract_content_type_and_extension tries to
guess content type of file from url and `Content-Type` header in response.
"""
if content_type_header:
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
return content_type_header, extension
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
return content_type, extension
def _validate_extension_override(extension_override: str | None) -> str | None:
# `extension_override` is allow to be `None or `""`.
if extension_override is None:
return None
if extension_override == "":
return ""
if not extension_override.startswith("."):
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
return extension_override

View File

@@ -0,0 +1,156 @@
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
def fetch_model_config(
    tenant_id: str, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    """Resolve the node's model config into a live ModelInstance plus a
    credentials-bearing config entity.

    :raises LLMModeRequiredError: when the config has no LLM mode.
    :raises ModelNotExistError: when the provider does not expose the model
        or no schema can be loaded for it.
    """
    if not node_data_model.mode:
        raise LLMModeRequiredError("LLM mode is required.")
    model = ModelManager().get_model_instance(
        tenant_id=tenant_id,
        model_type=ModelType.LLM,
        provider=node_data_model.provider,
        model=node_data_model.name,
    )
    model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
    # check model
    provider_model = model.provider_model_bundle.configuration.get_provider_model(
        model=node_data_model.name, model_type=ModelType.LLM
    )
    if provider_model is None:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
    provider_model.raise_for_status()
    # model config
    stop: list[str] = []
    if "stop" in node_data_model.completion_params:
        # "stop" is passed separately; note this mutates completion_params.
        stop = node_data_model.completion_params.pop("stop")
    model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
    if not model_schema:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
    return model, ModelConfigWithCredentialsEntity(
        provider=node_data_model.provider,
        model=node_data_model.name,
        model_schema=model_schema,
        mode=node_data_model.mode,
        provider_model_bundle=model.provider_model_bundle,
        credentials=model.credentials,
        parameters=node_data_model.completion_params,
        stop=stop,
    )
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
    """Resolve *selector* to a list of files.

    Missing variables and none/any-array segments yield an empty list; a
    single FileSegment is wrapped in a one-element list.
    """
    segment = variable_pool.get(selector)
    if segment is None:
        return []
    if isinstance(segment, FileSegment):
        return [segment.value]
    if isinstance(segment, ArrayFileSegment):
        return segment.value
    if isinstance(segment, NoneSegment | ArrayAnySegment):
        return []
    raise InvalidVariableTypeError(f"Invalid variable type: {type(segment)}")
def fetch_memory(
    variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
    """Build a TokenBufferMemory for the current conversation.

    Returns None when memory is not configured, when sys.conversation_id is
    not a string variable, or when the conversation row cannot be found.
    """
    if not node_data_memory:
        return None
    # get conversation id
    conversation_id_segment = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
    if not isinstance(conversation_id_segment, StringSegment):
        return None
    with Session(db.engine, expire_on_commit=False) as session:
        stmt = select(Conversation).where(
            Conversation.app_id == app_id,
            Conversation.id == conversation_id_segment.value,
        )
        conversation = session.scalar(stmt)
    if not conversation:
        return None
    return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
    """Deduct the system-provider quota consumed by one LLM call.

    No-op unless the tenant is on a system (hosted) provider.  The amount
    charged depends on the quota unit: tokens, per-model credits, or a flat
    1 per call.  A quota_limit of -1 means unlimited and skips deduction.
    """
    provider_model_bundle = model_instance.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration
    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return
    system_configuration = provider_configuration.system_configuration
    # Find the quota unit for the currently active quota type.
    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type == system_configuration.current_quota_type:
            quota_unit = quota_configuration.quota_unit
            if quota_configuration.quota_limit == -1:
                # Unlimited quota: nothing to deduct.
                return
            break
    used_quota = None
    if quota_unit:
        if quota_unit == QuotaUnit.TOKENS:
            used_quota = usage.total_tokens
        elif quota_unit == QuotaUnit.CREDITS:
            used_quota = dify_config.get_model_credits(model_instance.model)
        else:
            used_quota = 1
    if used_quota is not None and system_configuration.current_quota_type is not None:
        with Session(db.engine) as session:
            # Guarded UPDATE: only deducts while quota remains, avoiding
            # negative balances under concurrent calls.
            stmt = (
                update(Provider)
                .where(
                    Provider.tenant_id == tenant_id,
                    # TODO: Use provider name with prefix after the data migration.
                    Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
                    Provider.provider_type == ProviderType.SYSTEM,
                    Provider.quota_type == system_configuration.current_quota_type.value,
                    Provider.quota_limit > Provider.quota_used,
                )
                .values(
                    quota_used=Provider.quota_used + used_quota,
                    last_used=naive_utc_now(),
                )
            )
            session.execute(stmt)
            session.commit()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
from .entities import LoopNodeData
from .loop_end_node import LoopEndNode
from .loop_node import LoopNode
from .loop_start_node import LoopStartNode
__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"]

View File

@@ -0,0 +1,98 @@
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
# Segment types a loop variable is allowed to use.
_VALID_VAR_TYPE = frozenset(
    [
        SegmentType.STRING,
        SegmentType.NUMBER,
        SegmentType.OBJECT,
        SegmentType.BOOLEAN,
        SegmentType.ARRAY_STRING,
        SegmentType.ARRAY_NUMBER,
        SegmentType.ARRAY_OBJECT,
        SegmentType.ARRAY_BOOLEAN,
    ]
)


def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
    """Pydantic AfterValidator: reject segment types loop variables may not use.

    Fix: previously raised ``ValueError(...)`` with a literal Ellipsis,
    producing the useless message "Ellipsis"; now names the offending type.
    """
    if seg_type not in _VALID_VAR_TYPE:
        raise ValueError(f"Invalid loop variable type: {seg_type}")
    return seg_type
class LoopVariableData(BaseModel):
    """
    Loop Variable Data.
    """

    label: str
    # Validated against _VALID_VAR_TYPE; invalid types are rejected at parse time.
    var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
    # "variable": value is a selector (list[str]); "constant": value is a literal.
    value_type: Literal["variable", "constant"]
    value: Any | list[str] | None = None


class LoopNodeData(BaseLoopNodeData):
    """Configuration payload for the loop node."""

    loop_count: int  # Maximum number of loops
    break_conditions: list[Condition]  # Conditions to break the loop
    logical_operator: Literal["and", "or"]
    loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
    outputs: dict[str, Any] = Field(default_factory=dict)

    @field_validator("outputs", mode="before")
    @classmethod
    def validate_outputs(cls, v):
        # Normalize an explicit null in stored data to an empty dict.
        if v is None:
            return {}
        return v


class LoopStartNodeData(BaseNodeData):
    """
    Loop Start Node Data.
    """

    pass


class LoopEndNodeData(BaseNodeData):
    """
    Loop End Node Data.
    """

    pass


class LoopState(BaseLoopState):
    """
    Loop State.
    """

    # Output collected from each completed iteration.
    outputs: list[Any] = Field(default_factory=list)
    # Output of the iteration currently in progress.
    current_output: Any = None

    class MetaData(BaseLoopState.MetaData):
        """
        Data.
        """

        loop_length: int

    def get_last_output(self) -> Any:
        """
        Get last output.
        """
        if self.outputs:
            return self.outputs[-1]
        return None

    def get_current_output(self) -> Any:
        """
        Get current output.
        """
        return self.current_output

View File

@@ -0,0 +1,49 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(Node):
    """
    Loop End Node.

    Marker node terminating a loop body; it carries no logic of its own and
    always reports success.
    """

    node_type = NodeType.LOOP_END

    _node_data: LoopEndNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = LoopEndNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        """
        Run the node.
        """
        return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)

View File

@@ -0,0 +1,463 @@
import contextlib
import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import (
LoopFailedEvent,
LoopNextEvent,
LoopStartedEvent,
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now
if TYPE_CHECKING:
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
class LoopNode(LLMUsageTrackingMixin, Node):
    """
    Loop Node.

    Container node that executes its sub-graph up to ``loop_count`` times,
    stopping early when the configured break conditions evaluate to true or
    when a LOOP_END node succeeds inside the sub-graph. Loop variables live in
    the shared variable pool under ``[node_id, label]`` selectors so they
    persist across iterations.
    """

    node_type = NodeType.LOOP
    _node_data: LoopNodeData
    execution_type = NodeExecutionType.CONTAINER

    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw config mapping and store it as typed node data."""
        self._node_data = LoopNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> Generator:
        """Run the node: seed loop variables, iterate the sub-graph, emit loop events."""
        # Get inputs
        loop_count = self._node_data.loop_count
        break_conditions = self._node_data.break_conditions
        logical_operator = self._node_data.logical_operator
        inputs = {"loop_count": loop_count}
        if not self._node_data.start_node_id:
            raise ValueError(f"field start_node_id in loop {self._node_id} not found")
        root_node_id = self._node_data.start_node_id
        # Initialize loop variables in the original variable pool
        loop_variable_selectors = {}
        if self._node_data.loop_variables:
            # "constant" values are parsed into segments; "variable" values are
            # selectors resolved against the shared variable pool.
            value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
                "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
                "variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
                if isinstance(var.value, list)
                else None,
            }
            for loop_variable in self._node_data.loop_variables:
                if loop_variable.value_type not in value_processor:
                    raise ValueError(
                        f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
                    )
                processed_segment = value_processor[loop_variable.value_type](loop_variable)
                if not processed_segment:
                    raise ValueError(f"Invalid value for loop variable {loop_variable.label}")
                variable_selector = [self._node_id, loop_variable.label]
                variable = segment_to_variable(segment=processed_segment, selector=variable_selector)
                self.graph_runtime_state.variable_pool.add(variable_selector, variable.value)
                loop_variable_selectors[loop_variable.label] = variable_selector
                inputs[loop_variable.label] = processed_segment.value
        start_at = naive_utc_now()
        condition_processor = ConditionProcessor()
        loop_duration_map: dict[str, float] = {}
        single_loop_variable_map: dict[str, dict[str, Any]] = {}  # single loop variable output
        loop_usage = LLMUsage.empty_usage()
        # Start Loop event
        yield LoopStartedEvent(
            start_at=start_at,
            inputs=inputs,
            metadata={"loop_length": loop_count},
        )
        try:
            # Pre-check: if break conditions already hold before the first
            # iteration, skip the loop body entirely. Condition evaluation
            # errors here are deliberately suppressed (treated as "not met").
            reach_break_condition = False
            if break_conditions:
                with contextlib.suppress(ValueError):
                    _, _, reach_break_condition = condition_processor.process_conditions(
                        variable_pool=self.graph_runtime_state.variable_pool,
                        conditions=break_conditions,
                        operator=logical_operator,
                    )
            if reach_break_condition:
                loop_count = 0
            for i in range(loop_count):
                # Each iteration runs on a fresh sub-engine that shares the
                # parent's variable pool (see _create_graph_engine).
                graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
                loop_start_time = naive_utc_now()
                reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i)
                # Track loop duration
                loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds()
                # Accumulate outputs from the sub-graph's response nodes
                for key, value in graph_engine.graph_runtime_state.outputs.items():
                    if key == "answer":
                        # Concatenate answer outputs with newline
                        existing_answer = self.graph_runtime_state.get_output("answer", "")
                        if existing_answer:
                            self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}")
                        else:
                            self.graph_runtime_state.set_output("answer", value)
                    else:
                        # For other outputs, just update
                        self.graph_runtime_state.set_output(key, value)
                # Accumulate usage from the sub-graph execution
                loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
                # Collect loop variable values after iteration
                single_loop_variable = {}
                for key, selector in loop_variable_selectors.items():
                    segment = self.graph_runtime_state.variable_pool.get(selector)
                    single_loop_variable[key] = segment.value if segment else None
                single_loop_variable_map[str(i)] = single_loop_variable
                if reach_break_node:
                    break
                if break_conditions:
                    _, _, reach_break_condition = condition_processor.process_conditions(
                        variable_pool=self.graph_runtime_state.variable_pool,
                        conditions=break_conditions,
                        operator=logical_operator,
                    )
                    if reach_break_condition:
                        break
                yield LoopNextEvent(
                    index=i + 1,
                    pre_loop_output=self._node_data.outputs,
                )
            self._accumulate_usage(loop_usage)
            # Loop completed successfully
            # NOTE: `steps` reports the configured loop_count even when the
            # loop broke out early (loop_count is only reset by the pre-check).
            yield LoopSucceededEvent(
                start_at=start_at,
                inputs=inputs,
                outputs=self._node_data.outputs,
                steps=loop_count,
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
                    "completed_reason": "loop_break" if reach_break_condition else "loop_completed",
                    WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
                    WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
                },
            )
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    metadata={
                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
                        WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
                        WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
                        WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
                        WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
                    },
                    outputs=self._node_data.outputs,
                    inputs=inputs,
                    llm_usage=loop_usage,
                )
            )
        except Exception as e:
            # Usage consumed before the failure still counts toward totals.
            self._accumulate_usage(loop_usage)
            yield LoopFailedEvent(
                start_at=start_at,
                inputs=inputs,
                steps=loop_count,
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
                    "completed_reason": "error",
                    WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
                    WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
                },
                error=str(e),
            )
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    error=str(e),
                    metadata={
                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
                        WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
                        WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
                        WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
                        WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
                    },
                    llm_usage=loop_usage,
                )
            )

    def _run_single_loop(
        self,
        *,
        graph_engine: "GraphEngine",
        current_index: int,
    ) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]:
        """Run one iteration of the sub-graph.

        Forwards node events (tagged with loop metadata, LOOP_START events
        suppressed), then snapshots loop variables into ``outputs``.
        Returns True if a LOOP_END node succeeded (early break).
        """
        reach_break_node = False
        for event in graph_engine.run():
            if isinstance(event, GraphNodeEventBase):
                self._append_loop_info_to_event(event=event, loop_run_index=current_index)
            if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.LOOP_START:
                continue
            if isinstance(event, GraphNodeEventBase):
                yield event
            if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END:
                reach_break_node = True
            if isinstance(event, GraphRunFailedEvent):
                raise Exception(event.error)
        # Expose the post-iteration value of each loop variable, plus the
        # 1-based round counter, as this node's outputs.
        for loop_var in self._node_data.loop_variables or []:
            key, sel = loop_var.label, [self._node_id, loop_var.label]
            segment = self.graph_runtime_state.variable_pool.get(sel)
            self._node_data.outputs[key] = segment.value if segment else None
        self._node_data.outputs["loop_round"] = current_index + 1
        return reach_break_node

    def _append_loop_info_to_event(
        self,
        event: GraphNodeEventBase,
        loop_run_index: int,
    ):
        """Tag a sub-graph event with this loop's id and iteration index (only once)."""
        event.in_loop_id = self._node_id
        loop_metadata = {
            WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id,
            WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index,
        }
        current_metadata = event.node_run_result.metadata
        if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
            event.node_run_result.metadata = {**current_metadata, **loop_metadata}

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Statically collect variable selectors referenced by the loop's sub-nodes
        and its own loop variables, excluding selectors produced inside the loop."""
        # Create typed NodeData from dict
        typed_node_data = LoopNodeData.model_validate(node_data)
        variable_mapping = {}
        # Extract loop node IDs statically from graph_config
        loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
        # Get node configs from graph_config
        node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
        for sub_node_id, sub_node_config in node_configs.items():
            if sub_node_config.get("data", {}).get("loop_id") != node_id:
                continue
            # variable selector to variable mapping
            try:
                # Get node class
                from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING

                node_type = NodeType(sub_node_config.get("data", {}).get("type"))
                if node_type not in NODE_TYPE_CLASSES_MAPPING:
                    continue
                node_version = sub_node_config.get("data", {}).get("version", "1")
                node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
                sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
                    graph_config=graph_config, config=sub_node_config
                )
                sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
            except NotImplementedError:
                sub_node_variable_mapping = {}
            # remove loop variables
            sub_node_variable_mapping = {
                sub_node_id + "." + key: value
                for key, value in sub_node_variable_mapping.items()
                if value[0] != node_id
            }
            variable_mapping.update(sub_node_variable_mapping)
        for loop_variable in typed_node_data.loop_variables or []:
            if loop_variable.value_type == "variable":
                assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
                # add loop variable to variable mapping
                selector = loop_variable.value
                variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
        # remove variable out from loop
        variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
        return variable_mapping

    @classmethod
    def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
        """
        Extract node IDs that belong to a specific loop from graph configuration.

        This method statically analyzes the graph configuration to find all nodes
        that are part of the specified loop, without creating actual node instances.

        :param graph_config: the complete graph configuration
        :param loop_node_id: the ID of the loop node
        :return: set of node IDs that belong to the loop
        """
        loop_node_ids = set()
        # Find all nodes that belong to this loop
        nodes = graph_config.get("nodes", [])
        for node in nodes:
            node_data = node.get("data", {})
            if node_data.get("loop_id") == loop_node_id:
                node_id = node.get("id")
                if node_id:
                    loop_node_ids.add(node_id)
        return loop_node_ids

    @staticmethod
    def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
        """Get the appropriate segment type for a constant value."""
        # TODO: Refactor for maintainability:
        # 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
        # 2. Consider moving this method to LoopVariableData class for better encapsulation
        if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN:
            value = original_value
        elif var_type in [
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_STRING,
        ]:
            if original_value and isinstance(original_value, str):
                value = json.loads(original_value)
            else:
                # Fix: the two log arguments were swapped relative to the
                # format placeholders (the value was logged as value_type).
                logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", var_type, original_value)
                value = []
        else:
            raise AssertionError("this statement should be unreachable.")
        try:
            return build_segment_with_type(var_type, value=value)
        except TypeMismatchError as type_exc:
            # Attempt to parse the value as a JSON-encoded string, if applicable.
            if not isinstance(original_value, str):
                raise
            try:
                value = json.loads(original_value)
            except ValueError:
                raise type_exc
            return build_segment_with_type(var_type, value)

    def _create_graph_engine(self, start_at: datetime, root_node_id: str):
        """Build a fresh sub-graph engine for one iteration, sharing the parent's
        variable pool but using its own runtime state and in-memory command channel."""
        # Import dependencies
        from core.workflow.entities import GraphInitParams
        from core.workflow.graph import Graph
        from core.workflow.graph_engine import GraphEngine
        from core.workflow.graph_engine.command_channels import InMemoryChannel
        from core.workflow.nodes.node_factory import DifyNodeFactory
        from core.workflow.runtime import GraphRuntimeState

        # Create GraphInitParams from node attributes
        graph_init_params = GraphInitParams(
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            workflow_id=self.workflow_id,
            graph_config=self.graph_config,
            user_id=self.user_id,
            user_from=self.user_from.value,
            invoke_from=self.invoke_from.value,
            call_depth=self.workflow_call_depth,
        )
        # Create a new GraphRuntimeState for this iteration
        graph_runtime_state_copy = GraphRuntimeState(
            variable_pool=self.graph_runtime_state.variable_pool,
            start_at=start_at.timestamp(),
        )
        # Create a new node factory with the new GraphRuntimeState
        node_factory = DifyNodeFactory(
            graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
        )
        # Initialize the loop graph with the new node factory
        loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)
        # Create a new GraphEngine for this iteration
        graph_engine = GraphEngine(
            workflow_id=self.workflow_id,
            graph=loop_graph,
            graph_runtime_state=graph_runtime_state_copy,
            command_channel=InMemoryChannel(),  # Use InMemoryChannel for sub-graphs
        )
        return graph_engine

View File

@@ -0,0 +1,49 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(Node):
    """
    Loop Start Node.

    Marker node that serves as the entry point of a loop sub-graph. It performs
    no work of its own; the parent LoopNode filters its events out of the stream.
    """

    node_type = NodeType.LOOP_START
    _node_data: LoopStartNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw config mapping and store it as typed node data."""
        self._node_data = LoopStartNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        """
        Run the node.
        """
        # Nothing to execute: simply report success so the sub-graph proceeds.
        return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)

View File

@@ -0,0 +1,85 @@
from typing import TYPE_CHECKING, final
from typing_extensions import override
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
@final
class DifyNodeFactory(NodeFactory):
    """
    Default implementation of NodeFactory that uses the traditional node mapping.

    This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
    and instantiating the appropriate node class.
    """

    def __init__(
        self,
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
    ) -> None:
        self.graph_init_params = graph_init_params
        self.graph_runtime_state = graph_runtime_state

    @override
    def create_node(self, node_config: dict[str, object]) -> Node:
        """
        Create a Node instance from node configuration data using the traditional mapping.

        :param node_config: node configuration dictionary containing type and other data
        :return: initialized Node instance
        :raises ValueError: if node type is unknown or configuration is invalid
        """
        # Get node_id from config
        node_id = node_config.get("id")
        if not is_str(node_id):
            raise ValueError("Node config missing id")
        # Validate the data payload once; it is reused below for initialization.
        # (Previously it was fetched and re-validated a second time before
        # init_node_data — redundant work removed.)
        node_data = node_config.get("data", {})
        if not is_str_dict(node_data):
            raise ValueError(f"Node {node_id} missing data information")
        node_type_str = node_data.get("type")
        if not is_str(node_type_str):
            raise ValueError(f"Node {node_id} missing or invalid type information")
        try:
            node_type = NodeType(node_type_str)
        except ValueError:
            # Suppress the chained context: the original ValueError adds no detail.
            raise ValueError(f"Unknown node type: {node_type_str}") from None
        # Resolve the implementation class.
        # NOTE(review): the node's own "version" field is ignored here — the
        # latest registered class is always used; confirm this is intentional.
        node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
        if not node_mapping:
            raise ValueError(f"No class mapping found for node type: {node_type}")
        node_class = node_mapping.get(LATEST_VERSION)
        if not node_class:
            raise ValueError(f"No latest version class found for node type: {node_type}")
        # Create node instance
        node_instance = node_class(
            id=node_id,
            config=node_config,
            graph_init_params=self.graph_init_params,
            graph_runtime_state=self.graph_runtime_state,
        )
        # Initialize node with the already-validated data mapping
        node_instance.init_node_data(node_data)
        return node_instance

View File

@@ -0,0 +1,165 @@
from collections.abc import Mapping
from core.workflow.enums import NodeType
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.trigger_plugin import TriggerEventNode
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
# Sentinel version key: maps a node type to its most recent implementation class.
LATEST_VERSION = "latest"

# Registry of node implementations, keyed by node type and then by version string.
# DifyNodeFactory resolves classes through this table.
#
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
    NodeType.START: {
        LATEST_VERSION: StartNode,
        "1": StartNode,
    },
    NodeType.END: {
        LATEST_VERSION: EndNode,
        "1": EndNode,
    },
    NodeType.ANSWER: {
        LATEST_VERSION: AnswerNode,
        "1": AnswerNode,
    },
    NodeType.LLM: {
        LATEST_VERSION: LLMNode,
        "1": LLMNode,
    },
    NodeType.KNOWLEDGE_RETRIEVAL: {
        LATEST_VERSION: KnowledgeRetrievalNode,
        "1": KnowledgeRetrievalNode,
    },
    NodeType.IF_ELSE: {
        LATEST_VERSION: IfElseNode,
        "1": IfElseNode,
    },
    NodeType.CODE: {
        LATEST_VERSION: CodeNode,
        "1": CodeNode,
    },
    NodeType.TEMPLATE_TRANSFORM: {
        LATEST_VERSION: TemplateTransformNode,
        "1": TemplateTransformNode,
    },
    NodeType.QUESTION_CLASSIFIER: {
        LATEST_VERSION: QuestionClassifierNode,
        "1": QuestionClassifierNode,
    },
    NodeType.HTTP_REQUEST: {
        LATEST_VERSION: HttpRequestNode,
        "1": HttpRequestNode,
    },
    NodeType.TOOL: {
        LATEST_VERSION: ToolNode,
        # This is an issue that caused problems before.
        # Logically, we shouldn't use two different versions to point to the same class here,
        # but in order to maintain compatibility with historical data, this approach has been retained.
        "2": ToolNode,
        "1": ToolNode,
    },
    NodeType.VARIABLE_AGGREGATOR: {
        LATEST_VERSION: VariableAggregatorNode,
        "1": VariableAggregatorNode,
    },
    NodeType.LEGACY_VARIABLE_AGGREGATOR: {
        LATEST_VERSION: VariableAggregatorNode,
        "1": VariableAggregatorNode,
    },  # original name of VARIABLE_AGGREGATOR
    NodeType.ITERATION: {
        LATEST_VERSION: IterationNode,
        "1": IterationNode,
    },
    NodeType.ITERATION_START: {
        LATEST_VERSION: IterationStartNode,
        "1": IterationStartNode,
    },
    NodeType.LOOP: {
        LATEST_VERSION: LoopNode,
        "1": LoopNode,
    },
    NodeType.LOOP_START: {
        LATEST_VERSION: LoopStartNode,
        "1": LoopStartNode,
    },
    NodeType.LOOP_END: {
        LATEST_VERSION: LoopEndNode,
        "1": LoopEndNode,
    },
    NodeType.PARAMETER_EXTRACTOR: {
        LATEST_VERSION: ParameterExtractorNode,
        "1": ParameterExtractorNode,
    },
    NodeType.VARIABLE_ASSIGNER: {
        LATEST_VERSION: VariableAssignerNodeV2,
        "1": VariableAssignerNodeV1,
        "2": VariableAssignerNodeV2,
    },
    NodeType.DOCUMENT_EXTRACTOR: {
        LATEST_VERSION: DocumentExtractorNode,
        "1": DocumentExtractorNode,
    },
    NodeType.LIST_OPERATOR: {
        LATEST_VERSION: ListOperatorNode,
        "1": ListOperatorNode,
    },
    NodeType.AGENT: {
        LATEST_VERSION: AgentNode,
        # This is an issue that caused problems before.
        # Logically, we shouldn't use two different versions to point to the same class here,
        # but in order to maintain compatibility with historical data, this approach has been retained.
        "2": AgentNode,
        "1": AgentNode,
    },
    NodeType.HUMAN_INPUT: {
        LATEST_VERSION: HumanInputNode,
        "1": HumanInputNode,
    },
    NodeType.DATASOURCE: {
        LATEST_VERSION: DatasourceNode,
        "1": DatasourceNode,
    },
    NodeType.KNOWLEDGE_INDEX: {
        LATEST_VERSION: KnowledgeIndexNode,
        "1": KnowledgeIndexNode,
    },
    NodeType.TRIGGER_WEBHOOK: {
        LATEST_VERSION: TriggerWebhookNode,
        "1": TriggerWebhookNode,
    },
    NodeType.TRIGGER_PLUGIN: {
        LATEST_VERSION: TriggerEventNode,
        "1": TriggerEventNode,
    },
    NodeType.TRIGGER_SCHEDULE: {
        LATEST_VERSION: TriggerScheduleNode,
        "1": TriggerScheduleNode,
    },
}

View File

@@ -0,0 +1,3 @@
from .parameter_extractor_node import ParameterExtractorNode
__all__ = ["ParameterExtractorNode"]

View File

@@ -0,0 +1,129 @@
from typing import Annotated, Any, Literal
from pydantic import (
BaseModel,
BeforeValidator,
Field,
field_validator,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
# Legacy type names accepted for backward compatibility with older node configs;
# _validate_type maps them onto SegmentType members.
_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"
# All type names a Parameter Extractor parameter may declare: a subset of
# SegmentType members plus the two legacy string aliases above.
_VALID_PARAMETER_TYPES = frozenset(
    [
        SegmentType.STRING,  # "string",
        SegmentType.NUMBER,  # "number",
        SegmentType.BOOLEAN,
        SegmentType.ARRAY_STRING,
        SegmentType.ARRAY_NUMBER,
        SegmentType.ARRAY_OBJECT,
        SegmentType.ARRAY_BOOLEAN,
        _OLD_BOOL_TYPE_NAME,  # old boolean type used by Parameter Extractor node
        _OLD_SELECT_TYPE_NAME,  # string type with enumeration choices.
    ]
)
def _validate_type(parameter_type: str) -> SegmentType:
    """Normalize a declared parameter type name to a SegmentType.

    Legacy names are mapped for backward compatibility: "bool" -> BOOLEAN and
    "select" (string constrained by options) -> STRING.

    :raises ValueError: if the name is not in _VALID_PARAMETER_TYPES.
    """
    if parameter_type not in _VALID_PARAMETER_TYPES:
        # Fix: corrected typo "allowd" -> "allowed" in the error message.
        raise ValueError(f"type {parameter_type} is not allowed to use in Parameter Extractor node.")
    if parameter_type == _OLD_BOOL_TYPE_NAME:
        return SegmentType.BOOLEAN
    elif parameter_type == _OLD_SELECT_TYPE_NAME:
        return SegmentType.STRING
    return SegmentType(parameter_type)
class ParameterConfig(BaseModel):
    """Declaration of a single parameter the extractor must produce."""

    name: str
    type: Annotated[SegmentType, BeforeValidator(_validate_type)]
    options: list[str] | None = None
    description: str
    required: bool

    @field_validator("name", mode="before")
    @classmethod
    def validate_name(cls, value) -> str:
        # Reject empty names and the two names reserved for extractor bookkeeping.
        if not value:
            raise ValueError("Parameter name is required")
        reserved = {"__reason", "__is_success"}
        if value in reserved:
            raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
        return str(value)

    def is_array_type(self) -> bool:
        """Whether this parameter's declared type is an array type."""
        return self.type.is_array_type()

    def element_type(self) -> SegmentType:
        """Return the element type of an array parameter.

        Raises a ValueError if the parameter's type is not an array type.
        """
        inner = self.type.element_type()
        # At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
        # `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`
        # (see _VALID_PARAMETER_TYPES), so the element type is never None.
        assert inner is not None, f"the element type should not be None, {self.type=}"
        return inner
class ParameterExtractorNodeData(BaseNodeData):
    """
    Parameter Extractor Node Data.
    """

    model: ModelConfig
    query: list[str]
    parameters: list[ParameterConfig]
    instruction: str | None = None
    memory: MemoryConfig | None = None
    reasoning_mode: Literal["function_call", "prompt"]
    vision: VisionConfig = Field(default_factory=VisionConfig)

    @field_validator("reasoning_mode", mode="before")
    @classmethod
    def set_reasoning_mode(cls, v) -> str:
        # Default to "function_call" when the stored value is empty or None.
        return v or "function_call"

    def get_parameter_json_schema(self):
        """
        Get parameter json schema.

        Builds a JSON-Schema object describing all declared parameters, for use
        as a function-calling/tool schema.

        :return: parameter json schema
        """
        parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
        for parameter in self.parameters:
            parameter_schema: dict[str, Any] = {"description": parameter.description}
            if parameter.type == SegmentType.STRING:
                parameter_schema["type"] = "string"
            elif parameter.type.is_array_type():
                parameter_schema["type"] = "array"
                element_type = parameter.type.element_type()
                if element_type is None:
                    raise AssertionError("element type should not be None.")
                parameter_schema["items"] = {"type": element_type.value}
            else:
                # NOTE(review): this stores the SegmentType member itself rather
                # than .value (unlike the array branch above) — presumably
                # SegmentType is a str-based enum so it serializes as its string
                # value; confirm.
                parameter_schema["type"] = parameter.type
            if parameter.options:
                parameter_schema["enum"] = parameter.options
            parameters["properties"][parameter.name] = parameter_schema
            if parameter.required:
                parameters["required"].append(parameter.name)
        return parameters

View File

@@ -0,0 +1,75 @@
from typing import Any
from core.variables.types import SegmentType
# Error hierarchy for the Parameter Extractor node. All errors derive from
# ParameterExtractorNodeError (itself a ValueError) so callers can catch the
# whole family with one except clause.
class ParameterExtractorNodeError(ValueError):
    """Base error for ParameterExtractorNode."""


class InvalidModelTypeError(ParameterExtractorNodeError):
    """Raised when the model is not a Large Language Model."""


class ModelSchemaNotFoundError(ParameterExtractorNodeError):
    """Raised when the model schema is not found."""


class InvalidInvokeResultError(ParameterExtractorNodeError):
    """Raised when the invoke result is invalid."""


class InvalidTextContentTypeError(ParameterExtractorNodeError):
    """Raised when the text content type is invalid."""


class InvalidNumberOfParametersError(ParameterExtractorNodeError):
    """Raised when the number of parameters is invalid."""


class RequiredParameterMissingError(ParameterExtractorNodeError):
    """Raised when a required parameter is missing."""


class InvalidSelectValueError(ParameterExtractorNodeError):
    """Raised when a select value is invalid."""


class InvalidNumberValueError(ParameterExtractorNodeError):
    """Raised when a number value is invalid."""


class InvalidBoolValueError(ParameterExtractorNodeError):
    """Raised when a bool value is invalid."""


class InvalidStringValueError(ParameterExtractorNodeError):
    """Raised when a string value is invalid."""


class InvalidArrayValueError(ParameterExtractorNodeError):
    """Raised when an array value is invalid."""


class InvalidModelModeError(ParameterExtractorNodeError):
    """Raised when the model mode is invalid."""
class InvalidValueTypeError(ParameterExtractorNodeError):
    """Raised when an extracted value does not match the declared parameter type."""

    def __init__(
        self,
        /,
        parameter_name: str,
        expected_type: SegmentType,
        actual_type: SegmentType | None,
        value: Any,
    ):
        # Keep the offending details on the instance for programmatic handling.
        self.parameter_name = parameter_name
        self.expected_type = expected_type
        self.actual_type = actual_type
        self.value = value
        super().__init__(
            f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
            f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
        )

View File

@@ -0,0 +1,858 @@
import contextlib
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
InvalidSelectValueError,
InvalidTextContentTypeError,
InvalidValueTypeError,
ModelSchemaNotFoundError,
ParameterExtractorNodeError,
RequiredParameterMissingError,
)
from .prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_PROMPT,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
COMPLETION_GENERATE_JSON_PROMPT,
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
FUNCTION_CALLING_EXTRACTOR_NAME,
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT,
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
)
logger = logging.getLogger(__name__)
def extract_json(text):
    """Return the first balanced JSON object/array at the start of *text*.

    *text* is expected to begin at (or before) a '{' or '['.  Scanning stops
    at the character that closes the outermost opener and the balanced prefix
    is returned.  A closer that appears before any opener, or that mismatches
    the current opener, truncates the text at that position instead.  Returns
    None when the text ends before the structure is balanced.
    """
    opener_for = {"}": "{", "]": "["}
    depth_stack = []
    position = 0
    while position < len(text):
        char = text[position]
        if char in ("{", "["):
            depth_stack.append(char)
        elif char in ("}", "]"):
            if not depth_stack:
                # Closer with nothing open: everything before it is the result.
                return text[:position]
            if depth_stack[-1] == opener_for[char]:
                depth_stack.pop()
                if not depth_stack:
                    # Outermost structure just closed — include this character.
                    return text[: position + 1]
            else:
                # Mismatched closer: give back what was scanned so far.
                return text[:position]
        position += 1
    return None
class ParameterExtractorNode(Node):
    """
    Parameter Extractor Node.

    Uses an LLM to pull structured parameters out of natural-language input,
    via native function/tool calling when the model supports it, otherwise via
    prompt engineering with JSON extraction from the text response.
    """

    node_type = NodeType.PARAMETER_EXTRACTOR
    # Typed node configuration; populated by init_node_data().
    _node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]):
    # Validate the raw node config and bind it as the typed node data.
    self._node_data = ParameterExtractorNodeData.model_validate(data)

def _get_error_strategy(self) -> ErrorStrategy | None:
    return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
    return self._node_data.retry_config

def _get_title(self) -> str:
    return self._node_data.title

def _get_description(self) -> str | None:
    return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
    return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
    return self._node_data

# Lazily-populated caches for the resolved model handles;
# see _fetch_model_config().
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
    """
    Default node configuration offered to the workflow editor.

    The `filters` argument is accepted for interface compatibility but is
    not used here.
    """
    return {
        "model": {
            "prompt_templates": {
                "completion_model": {
                    "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
                    # Stop generating once the model would start a new human turn.
                    "stop": ["Human:"],
                }
            }
        }
    }

@classmethod
def version(cls) -> str:
    # Node schema version, used for workflow (de)serialization compatibility.
    return "1"
def _run(self):
    """
    Run the node: extract structured parameters from the configured query.

    Chooses native function calling when the model advertises tool-call
    support AND reasoning_mode is "function_call"; otherwise falls back to
    prompt engineering.  Extraction failures do not fail the node — they are
    reported through the `__is_success` / `__reason` outputs so downstream
    nodes can branch on them.  Only model-invocation errors produce a FAILED
    status.
    """
    node_data = self._node_data
    variable = self.graph_runtime_state.variable_pool.get(node_data.query)
    # A missing query variable degrades to an empty string rather than failing.
    query = variable.text if variable else ""
    variable_pool = self.graph_runtime_state.variable_pool
    # Vision inputs are only fetched when the feature is enabled on the node.
    files = (
        llm_utils.fetch_files(
            variable_pool=variable_pool,
            selector=node_data.vision.configs.variable_selector,
        )
        if node_data.vision.enabled
        else []
    )
    model_instance, model_config = self._fetch_model_config(node_data.model)
    if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
        raise InvalidModelTypeError("Model is not a Large Language Model")
    llm_model = model_instance.model_type_instance
    model_schema = llm_model.get_model_schema(
        model=model_config.model,
        credentials=model_config.credentials,
    )
    if not model_schema:
        raise ModelSchemaNotFoundError("Model schema not found")
    # fetch memory
    memory = llm_utils.fetch_memory(
        variable_pool=variable_pool,
        app_id=self.app_id,
        node_data_memory=node_data.memory,
        model_instance=model_instance,
    )
    if (
        set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
        and node_data.reasoning_mode == "function_call"
    ):
        # use function call
        prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
            node_data=node_data,
            query=query,
            variable_pool=self.graph_runtime_state.variable_pool,
            model_config=model_config,
            memory=memory,
            files=files,
            vision_detail=node_data.vision.configs.detail,
        )
    else:
        # use prompt engineering
        prompt_messages = self._generate_prompt_engineering_prompt(
            data=node_data,
            query=query,
            variable_pool=self.graph_runtime_state.variable_pool,
            model_config=model_config,
            memory=memory,
            files=files,
            vision_detail=node_data.vision.configs.detail,
        )
        prompt_message_tools = []
    inputs = {
        "query": query,
        "files": [f.to_dict() for f in files],
        "parameters": jsonable_encoder(node_data.parameters),
        "instruction": jsonable_encoder(node_data.instruction),
    }
    process_data = {
        "model_mode": model_config.mode,
        "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
            model_mode=model_config.mode, prompt_messages=prompt_messages
        ),
        "usage": None,
        # Only the first tool (the extractor function) is recorded.
        "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
        "tool_call": None,
        "model_provider": model_config.provider,
        "model_name": model_config.model,
    }
    try:
        text, usage, tool_call = self._invoke(
            node_data_model=node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            tools=prompt_message_tools,
            stop=model_config.stop,
        )
        process_data["usage"] = jsonable_encoder(usage)
        process_data["tool_call"] = jsonable_encoder(tool_call)
        process_data["llm_text"] = text
    except ParameterExtractorNodeError as e:
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=inputs,
            process_data=process_data,
            outputs={"__is_success": 0, "__reason": str(e)},
            error=str(e),
            metadata={},
        )
    except Exception as e:
        # Unexpected invocation errors are surfaced in the outputs as well,
        # so callers inspecting outputs see a consistent shape.
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=inputs,
            process_data=process_data,
            outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
            error=str(e),
            metadata={},
        )
    error = None
    # Prefer structured tool-call arguments; otherwise parse the text reply.
    if tool_call:
        result = self._extract_json_from_tool_call(tool_call)
    else:
        result = self._extract_complete_json_response(text)
    if not result:
        # Fall back to type-appropriate zero values instead of failing the node.
        result = self._generate_default_result(node_data)
        error = "Failed to extract result from function call or text response, using empty result."
    try:
        result = self._validate_result(data=node_data, result=result or {})
    except ParameterExtractorNodeError as e:
        # Validation failure is reported via __reason but the node still succeeds.
        error = str(e)
    # transform result into standard format
    result = self._transform_result(data=node_data, result=result or {})
    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        inputs=inputs,
        process_data=process_data,
        outputs={
            "__is_success": 1 if not error else 0,
            "__reason": error,
            "__usage": jsonable_encoder(usage),
            **result,
        },
        metadata={
            WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
            WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
            WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
        },
        llm_usage=usage,
    )
def _invoke(
    self,
    node_data_model: ModelConfig,
    model_instance: ModelInstance,
    prompt_messages: list[PromptMessage],
    tools: list[PromptMessageTool],
    stop: list[str],
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
    """
    Invoke the LLM synchronously (non-streaming).

    Returns the response text, token usage, and the first tool call if the
    model made any (additional tool calls are ignored).

    Raises:
        InvalidTextContentTypeError: if the model returns non-string content.
    """
    invoke_result = model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters=node_data_model.completion_params,
        tools=tools,
        stop=stop,
        stream=False,
        user=self.user_id,
    )
    # handle invoke result
    text = invoke_result.message.content or ""
    if not isinstance(text, str):
        raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
    usage = invoke_result.usage
    # Only the first tool call is used for extraction.
    tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
    # deduct quota
    llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
    return text, usage, tool_call
def _generate_function_call_prompt(
    self,
    node_data: ParameterExtractorNodeData,
    query: str,
    variable_pool: VariablePool,
    model_config: ModelConfigWithCredentialsEntity,
    memory: TokenBufferMemory | None,
    files: Sequence[File],
    vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
    """
    Generate function call prompt.

    Wraps the query and the parameter JSON schema into the extractor user
    template, renders the chat template, splices few-shot tool-call examples
    in front of the last user message, and returns the messages together with
    the single extractor tool definition.
    """
    query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(
        content=query, structure=json.dumps(node_data.get_parameter_json_schema())
    )
    prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
    rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
    prompt_template = self._get_function_calling_prompt_template(
        node_data, query, variable_pool, memory, rest_token
    )
    prompt_messages = prompt_transform.get_prompt(
        prompt_template=prompt_template,
        inputs={},
        query="",
        files=files,
        context="",
        memory_config=node_data.memory,
        # History is already rendered into the system prompt by
        # _get_function_calling_prompt_template, so no memory object is passed.
        memory=None,
        model_config=model_config,
        image_detail_config=vision_detail,
    )
    # find last user message
    last_user_message_idx = -1
    for i, prompt_message in enumerate(prompt_messages):
        if prompt_message.role == PromptMessageRole.USER:
            last_user_message_idx = i
    # add function call messages before last user message
    example_messages = []
    for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE:
        # NOTE: `id` shadows the builtin within this loop body.
        id = uuid.uuid4().hex
        example_messages.extend(
            [
                UserPromptMessage(content=example["user"]["query"]),
                AssistantPromptMessage(
                    content=example["assistant"]["text"],
                    tool_calls=[
                        AssistantPromptMessage.ToolCall(
                            id=id,
                            type="function",
                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                name=example["assistant"]["function_call"]["name"],
                                arguments=json.dumps(example["assistant"]["function_call"]["parameters"]),
                            ),
                        )
                    ],
                ),
                ToolPromptMessage(
                    content="Great! You have called the function with the correct parameters.", tool_call_id=id
                ),
                AssistantPromptMessage(
                    content="I have extracted the parameters, let's move on.",
                ),
            ]
        )
    # Insert the examples immediately before the final user turn.
    prompt_messages = (
        prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:]
    )
    # generate tool
    tool = PromptMessageTool(
        name=FUNCTION_CALLING_EXTRACTOR_NAME,
        description="Extract parameters from the natural language text",
        parameters=node_data.get_parameter_json_schema(),
    )
    return prompt_messages, [tool]
def _generate_prompt_engineering_prompt(
    self,
    data: ParameterExtractorNodeData,
    query: str,
    variable_pool: VariablePool,
    model_config: ModelConfigWithCredentialsEntity,
    memory: TokenBufferMemory | None,
    files: Sequence[File],
    vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
    """Build prompt messages via prompt engineering, dispatching on model mode."""
    model_mode = ModelMode(data.model.mode)
    if model_mode == ModelMode.COMPLETION:
        builder = self._generate_prompt_engineering_completion_prompt
    elif model_mode == ModelMode.CHAT:
        builder = self._generate_prompt_engineering_chat_prompt
    else:
        raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
    return builder(
        node_data=data,
        query=query,
        variable_pool=variable_pool,
        model_config=model_config,
        memory=memory,
        files=files,
        vision_detail=vision_detail,
    )
def _generate_prompt_engineering_completion_prompt(
    self,
    node_data: ParameterExtractorNodeData,
    query: str,
    variable_pool: VariablePool,
    model_config: ModelConfigWithCredentialsEntity,
    memory: TokenBufferMemory | None,
    files: Sequence[File],
    vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
    """
    Generate completion prompt.

    The parameter JSON schema is injected through the `structure` template
    input; unlike the chat variant, the memory object itself is handed to
    the prompt transform here.
    """
    prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
    rest_token = self._calculate_rest_token(
        node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
    )
    prompt_template = self._get_prompt_engineering_prompt_template(
        node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
    )
    prompt_messages = prompt_transform.get_prompt(
        prompt_template=prompt_template,
        inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
        query="",
        files=files,
        context="",
        memory_config=node_data.memory,
        memory=memory,
        model_config=model_config,
        image_detail_config=vision_detail,
    )
    return prompt_messages
def _generate_prompt_engineering_chat_prompt(
    self,
    node_data: ParameterExtractorNodeData,
    query: str,
    variable_pool: VariablePool,
    model_config: ModelConfigWithCredentialsEntity,
    memory: TokenBufferMemory | None,
    files: Sequence[File],
    vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
    """
    Generate chat prompt.

    Formats the query and schema into the JSON-generation user template,
    renders the chat template, then splices few-shot JSON examples in front
    of the last user message.
    """
    prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
    rest_token = self._calculate_rest_token(
        node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
    )
    prompt_template = self._get_prompt_engineering_prompt_template(
        node_data=node_data,
        query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
            structure=json.dumps(node_data.get_parameter_json_schema()), text=query
        ),
        variable_pool=variable_pool,
        memory=memory,
        max_token_limit=rest_token,
    )
    prompt_messages = prompt_transform.get_prompt(
        prompt_template=prompt_template,
        inputs={},
        query="",
        files=files,
        context="",
        memory_config=node_data.memory,
        # History is rendered into the system prompt by the template helper,
        # so no memory object is passed here.
        memory=None,
        model_config=model_config,
        image_detail_config=vision_detail,
    )
    # find last user message
    last_user_message_idx = -1
    for i, prompt_message in enumerate(prompt_messages):
        if prompt_message.role == PromptMessageRole.USER:
            last_user_message_idx = i
    # add example messages before last user message
    example_messages = []
    for example in CHAT_EXAMPLE:
        example_messages.extend(
            [
                UserPromptMessage(
                    content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
                        structure=json.dumps(example["user"]["json"]),
                        text=example["user"]["query"],
                    )
                ),
                AssistantPromptMessage(
                    content=json.dumps(example["assistant"]["json"]),
                ),
            ]
        )
    prompt_messages = (
        prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:]
    )
    return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
    """
    Validate the extracted result against the configured parameter schema.

    Returns the (unmodified) result on success.

    Raises:
        InvalidNumberOfParametersError: result key count != parameter count
            (extra AND missing keys both fail this strict check).
        RequiredParameterMissingError: a required parameter is absent.
        InvalidValueTypeError: a value does not match its declared type.
        InvalidSelectValueError: a string value is outside the configured
            `select` options.
    """
    if len(data.parameters) != len(result):
        raise InvalidNumberOfParametersError("Invalid number of parameters")
    for parameter in data.parameters:
        if parameter.required and parameter.name not in result:
            raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
        param_value = result.get(parameter.name)
        # ArrayValidation.ALL: every element of an array value must match.
        if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
            inferred_type = SegmentType.infer_segment_type(param_value)
            raise InvalidValueTypeError(
                parameter_name=parameter.name,
                expected_type=parameter.type,
                actual_type=inferred_type,
                value=param_value,
            )
        # `select`-style parameters must match one of the configured options.
        if parameter.type == SegmentType.STRING and parameter.options:
            if param_value not in parameter.options:
                raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
    return result
@staticmethod
def _transform_number(value: int | float | str | bool) -> int | float | None:
"""
Attempts to transform the input into an integer or float.
Returns:
int or float: The transformed number if the conversion is successful.
None: If the transformation fails.
Note:
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
"""
if isinstance(value, bool):
return int(value)
elif isinstance(value, (int, float)):
return value
elif isinstance(value, str):
if "." in value:
try:
return float(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
else:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
    """
    Transform result into standard format.

    Coerces each extracted value to its declared parameter type; values that
    cannot be coerced are dropped, and any parameter still missing afterwards
    receives a type-appropriate zero value ("" / 0 / False / empty array
    segment).  Every configured parameter is therefore present in the output.
    """
    transformed_result: dict[str, Any] = {}
    for parameter in data.parameters:
        if parameter.name in result:
            param_value = result[parameter.name]
            # transform value
            if parameter.type == SegmentType.NUMBER:
                transformed = self._transform_number(param_value)
                if transformed is not None:
                    transformed_result[parameter.name] = transformed
            elif parameter.type == SegmentType.BOOLEAN:
                # ints are accepted as booleans (truthiness semantics).
                if isinstance(result[parameter.name], (bool, int)):
                    transformed_result[parameter.name] = bool(result[parameter.name])
            # elif isinstance(result[parameter.name], str):
            #     if result[parameter.name].lower() in ["true", "false"]:
            #         transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
            elif parameter.type == SegmentType.STRING:
                if isinstance(param_value, str):
                    transformed_result[parameter.name] = param_value
            elif parameter.is_array_type():
                if isinstance(param_value, list):
                    nested_type = parameter.element_type()
                    assert nested_type is not None
                    # Build an empty typed segment, then append only the items
                    # whose type matches the declared element type; mismatching
                    # items are silently dropped.
                    segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
                    transformed_result[parameter.name] = segment_value
                    for item in param_value:
                        if nested_type == SegmentType.NUMBER:
                            transformed = self._transform_number(item)
                            if transformed is not None:
                                segment_value.value.append(transformed)
                        elif nested_type == SegmentType.STRING:
                            if isinstance(item, str):
                                segment_value.value.append(item)
                        elif nested_type == SegmentType.OBJECT:
                            if isinstance(item, dict):
                                segment_value.value.append(item)
                        elif nested_type == SegmentType.BOOLEAN:
                            if isinstance(item, bool):
                                segment_value.value.append(item)
        if parameter.name not in transformed_result:
            # Fill a zero value for parameters that were missing or rejected above.
            if parameter.type.is_array_type():
                transformed_result[parameter.name] = build_segment_with_type(
                    segment_type=SegmentType(parameter.type), value=[]
                )
            elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
                transformed_result[parameter.name] = ""
            elif parameter.type == SegmentType.NUMBER:
                transformed_result[parameter.name] = 0
            elif parameter.type == SegmentType.BOOLEAN:
                transformed_result[parameter.name] = False
            else:
                raise AssertionError("this statement should be unreachable.")
    return transformed_result
def _extract_complete_json_response(self, result: str) -> dict | None:
    """
    Extract the first parseable JSON object/array from an LLM text response.

    Scans for every '{' or '[' in the text and tries to parse the balanced
    span starting there (via `extract_json`); the first successful parse is
    returned.  Malformed candidates are skipped so a later offset may still
    succeed.  Returns None after logging when nothing parses.
    """
    for idx, char in enumerate(result):
        if char in ("{", "["):
            json_str = extract_json(result[idx:])
            if json_str:
                # Skip candidates that are balanced but not valid JSON.
                with contextlib.suppress(Exception):
                    return cast(dict, json.loads(json_str))
    # Log message fixed: previously the typo "extra error".
    logger.info("extract error: %s", result)
    return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
    """
    Extract a JSON object from a function/tool call's arguments string.

    Returns None when there is no tool call or it carries no arguments.
    Delegates the scan-and-parse work to `_extract_complete_json_response`
    so both extraction paths share identical behavior (the logic was
    previously duplicated verbatim here).
    """
    if not tool_call or not tool_call.function.arguments:
        return None
    return self._extract_complete_json_response(tool_call.function.arguments)
def _generate_default_result(self, data: ParameterExtractorNodeData):
"""
Generate default result.
"""
result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0
elif parameter.type == "boolean":
result[parameter.name] = False
elif parameter.type in {"string", "select"}:
result[parameter.name] = ""
return result
def _get_function_calling_prompt_template(
    self,
    node_data: ParameterExtractorNodeData,
    query: str,
    variable_pool: VariablePool,
    memory: TokenBufferMemory | None,
    max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
    """
    Build the chat template for the function-calling path.

    Chat history (token-limited) and the variable-resolved instruction are
    rendered directly into the system prompt.  Only chat-mode models are
    supported on this path.
    """
    model_mode = ModelMode(node_data.model.mode)
    input_text = query
    memory_str = ""
    # Resolve workflow variable templates inside the user instruction.
    instruction = variable_pool.convert_template(node_data.instruction or "").text
    if memory and node_data.memory and node_data.memory.window:
        memory_str = memory.get_history_prompt_text(
            max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
        )
    if model_mode == ModelMode.CHAT:
        system_prompt_messages = ChatModelMessage(
            role=PromptMessageRole.SYSTEM,
            text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
        )
        user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
        return [system_prompt_messages, user_prompt_message]
    else:
        raise InvalidModelModeError(f"Model mode {model_mode} not support.")
def _get_prompt_engineering_prompt_template(
    self,
    node_data: ParameterExtractorNodeData,
    query: str,
    variable_pool: VariablePool,
    memory: TokenBufferMemory | None,
    max_token_limit: int = 2000,
):
    """
    Build the prompt template for the prompt-engineering path.

    Chat mode returns a system + user message pair; completion mode returns
    a single completion template with history/query/instruction formatted in.
    """
    model_mode = ModelMode(node_data.model.mode)
    input_text = query
    memory_str = ""
    # Resolve workflow variable templates inside the user instruction.
    instruction = variable_pool.convert_template(node_data.instruction or "").text
    if memory and node_data.memory and node_data.memory.window:
        memory_str = memory.get_history_prompt_text(
            max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
        )
    if model_mode == ModelMode.CHAT:
        system_prompt_messages = ChatModelMessage(
            role=PromptMessageRole.SYSTEM,
            text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction),
        )
        user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
        return [system_prompt_messages, user_prompt_message]
    elif model_mode == ModelMode.COMPLETION:
        # NOTE(review): the "γγγ" markers escape literal braces in the
        # completion template.  These replaces strip the marker together with
        # its adjacent brace, so the rendered example loses its braces —
        # confirm this is the intended rendering.
        return CompletionModelPromptTemplate(
            text=COMPLETION_GENERATE_JSON_PROMPT.format(
                histories=memory_str, text=input_text, instruction=instruction
            )
            .replace("{γγγ", "")
            .replace("}γγγ", "")
        )
    else:
        raise InvalidModelModeError(f"Model mode {model_mode} not support.")
def _calculate_rest_token(
    self,
    node_data: ParameterExtractorNodeData,
    query: str,
    variable_pool: VariablePool,
    model_config: ModelConfigWithCredentialsEntity,
    context: str | None,
) -> int:
    """
    Estimate how many context tokens remain for chat history.

    Builds the same template the real invocation would use, counts its
    tokens (plus a 1000-token allowance for tool-call overhead), and
    subtracts that and the configured max_tokens from the model's context
    size.  Defaults to 2000 when the model does not declare a context size.
    Never returns a negative value.
    """
    prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
    model_instance, model_config = self._fetch_model_config(node_data.model)
    if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
        raise InvalidModelTypeError("Model is not a Large Language Model")
    llm_model = model_instance.model_type_instance
    model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
    if not model_schema:
        raise ModelSchemaNotFoundError("Model schema not found")
    # BUG FIX: this previously tested
    # {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} — a
    # duplicated member — so models supporting only TOOL_CALL were estimated
    # with the wrong template.  Now aligned with the feature check in _run().
    if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
        prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
    else:
        prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
    prompt_messages = prompt_transform.get_prompt(
        prompt_template=prompt_template,
        inputs={},
        query="",
        files=[],
        context=context,
        memory_config=node_data.memory,
        memory=None,
        model_config=model_config,
    )
    rest_tokens = 2000
    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if model_context_tokens:
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
        curr_message_tokens = (
            model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000
        )  # add 1000 to ensure tool call messages
        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or (
                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
            ):
                max_tokens = (
                    model_config.parameters.get(parameter_rule.name)
                    or model_config.parameters.get(parameter_rule.use_template or "")
                ) or 0
        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
        rest_tokens = max(rest_tokens, 0)
    return rest_tokens
def _fetch_model_config(
    self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    """
    Fetch model config.

    Resolves the model instance and credentials once, then caches both on
    the node so repeated calls (e.g. from _calculate_rest_token) reuse the
    same handles.
    """
    if not self._model_instance or not self._model_config:
        self._model_instance, self._model_config = llm_utils.fetch_model_config(
            tenant_id=self.tenant_id, node_data_model=node_data_model
        )
    return self._model_instance, self._model_config
@classmethod
def _extract_variable_selector_to_variable_mapping(
    cls,
    *,
    graph_config: Mapping[str, Any],
    node_id: str,
    node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
    """
    Map node-scoped variable names to the value selectors they read.

    Includes the query selector plus every selector referenced from the
    instruction template; keys are prefixed with "<node_id>.".
    `graph_config` is accepted for interface compatibility but unused here.
    """
    # Create typed NodeData from dict
    typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
    variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
    if typed_node_data.instruction:
        selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
        for selector in selectors:
            variable_mapping[selector.variable] = selector.value_selector
    variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
    return variable_mapping

View File

@@ -0,0 +1,184 @@
from typing import Any
FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters"
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
### Task
Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria.
### Memory
Here is the chat history between the human and assistant, provided within <histories> tags:
<histories>
\x7bhistories\x7d
</histories>
### Instructions:
Some additional information is provided below. Always adhere to these instructions as closely as possible:
<instruction>
\x7binstruction\x7d
</instruction>
Steps:
1. Review the chat history provided within the <histories> tags.
2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
3. Generate a well-formatted output using the defined functions and arguments.
4. Use the `extract_parameter` function to create structured outputs with appropriate parameters.
5. Do not include any XML tags in your output.
### Example
To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples.
### Final Output
Produce well-formatted function calls in json without XML tags, as shown in the example.
""" # noqa: E501
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside <context></context> XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside <structure></structure> XML tags.
<context>
\x7bcontent\x7d
</context>
<structure>
\x7bstructure\x7d
</structure>
""" # noqa: E501
FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [
{
"user": {
"query": "What is the weather today in SF?",
"function": {
"name": FUNCTION_CALLING_EXTRACTOR_NAME,
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get the weather information",
"required": True,
},
},
"required": ["location"],
},
},
},
"assistant": {
"text": "I need always call the function with the correct parameters."
" in this case, I need to call the function with the location parameter.",
"function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}},
},
},
{
"user": {
"query": "I want to eat some apple pie.",
"function": {
"name": FUNCTION_CALLING_EXTRACTOR_NAME,
"parameters": {
"type": "object",
"properties": {"food": {"type": "string", "description": "The food to eat", "required": True}},
"required": ["food"],
},
},
},
"assistant": {
"text": "I need always call the function with the correct parameters."
" in this case, I need to call the function with the food parameter.",
"function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}},
},
},
]
COMPLETION_GENERATE_JSON_PROMPT = """### Instructions:
Some extra information are provided below, I should always follow the instructions as possible as I can.
<instructions>
{instruction}
</instructions>
### Extract parameter Workflow
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
<information to be extracted>
{{ structure }}
</information to be extracted>
Step 1: Carefully read the input and understand the structure of the expected output.
Step 2: Extract relevant parameters from the provided text based on the name and description of object.
Step 3: Structure the extracted parameters to JSON object as specified in <structure>.
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### Structure
Here is the structure of the expected output, I should always follow the output structure.
{{γγγ
'properties1': 'relevant text extracted from input',
'properties2': 'relevant text extracted from input',
}}γγγ
### Input Text
Inside <text></text> XML tags, there is a text that I should extract parameters and convert to a JSON object.
<text>
{text}
</text>
### Answer
I should always output a valid JSON object. Output nothing other than the JSON object.
```JSON
""" # noqa: E501
CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object.
The structure of the JSON object you can found in the instructions.
### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### Instructions:
Some extra information are provided below, you should always follow the instructions as possible as you can.
<instructions>
{instructions}
</instructions>
"""
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure
Here is the structure of the JSON object, you should always follow the structure.
<structure>
{structure}
</structure>
### Text to be converted to JSON
Inside <text></text> XML tags, there is a text that you should convert to a JSON object.
<text>
{text}
</text>
"""
CHAT_EXAMPLE = [
{
"user": {
"query": "What is the weather today in SF?",
"json": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get the weather information",
"required": True,
}
},
"required": ["location"],
},
},
"assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}},
},
{
"user": {
"query": "I want to eat some apple pie.",
"json": {
"type": "object",
"properties": {"food": {"type": "string", "description": "The food to eat", "required": True}},
"required": ["food"],
},
},
"assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}},
},
]

View File

@@ -0,0 +1,4 @@
from .entities import QuestionClassifierNodeData
from .question_classifier_node import QuestionClassifierNode
__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"]

View File

@@ -0,0 +1,28 @@
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
class ClassConfig(BaseModel):
    """One selectable classification target (a branch of the classifier)."""

    id: str
    name: str
class QuestionClassifierNodeData(BaseNodeData):
    """Configuration for the question-classifier node."""

    # Selector of the workflow variable holding the query to classify.
    query_variable_selector: list[str]
    model: ModelConfig
    # Candidate classes the LLM chooses between.
    classes: list[ClassConfig]
    # Optional extra instruction appended to the classification prompt.
    instruction: str | None = None
    memory: MemoryConfig | None = None
    vision: VisionConfig = Field(default_factory=VisionConfig)

    @property
    def structured_output_enabled(self) -> bool:
        # NOTE(QuantumGhost): Temporary workaround for issue #20725
        # (https://github.com/langgenius/dify/issues/20725).
        #
        # The proper fix would be to make `QuestionClassifierNode` inherit
        # from `BaseNode` instead of `LLMNode`.
        return False

View File

@@ -0,0 +1,6 @@
class QuestionClassifierNodeError(ValueError):
    """Base class for QuestionClassifierNode errors."""


class InvalidModelTypeError(QuestionClassifierNodeError):
    """Raised when the model is not a Large Language Model."""

View File

@@ -0,0 +1,398 @@
import json
import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from libs.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
from .exc import InvalidModelTypeError
from .template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
QUESTION_CLASSIFIER_COMPLETION_PROMPT,
QUESTION_CLASSIFIER_SYSTEM_PROMPT,
QUESTION_CLASSIFIER_USER_PROMPT_1,
QUESTION_CLASSIFIER_USER_PROMPT_2,
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.runtime import GraphRuntimeState
class QuestionClassifierNode(Node):
    """Branch node that asks an LLM to classify a query into one of the configured classes.

    Renders a few-shot classification prompt, invokes the configured model,
    parses the JSON answer, and routes execution by setting
    ``edge_source_handle`` to the chosen class id. A ``ValueError`` raised
    while invoking/parsing fails the node with the captured usage attached.
    """
    node_type = NodeType.QUESTION_CLASSIFIER
    # BRANCH: downstream edge is selected by the returned edge_source_handle.
    execution_type = NodeExecutionType.BRANCH
    _node_data: QuestionClassifierNodeData
    # Files produced by the model during invocation (multimodal outputs).
    _file_outputs: list["File"]
    _llm_file_saver: LLMFileSaver
    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        *,
        llm_file_saver: LLMFileSaver | None = None,
    ):
        """Initialize the node; ``llm_file_saver`` is an optional override (presumably for testing — default builds a FileSaverImpl)."""
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        # LLM file outputs, used for MultiModal outputs.
        self._file_outputs = []
        if llm_file_saver is None:
            llm_file_saver = FileSaverImpl(
                user_id=graph_init_params.user_id,
                tenant_id=graph_init_params.tenant_id,
            )
        self._llm_file_saver = llm_file_saver
    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw node config dict into typed ``QuestionClassifierNodeData``."""
        self._node_data = QuestionClassifierNodeData.model_validate(data)
    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy
    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config
    def _get_title(self) -> str:
        return self._node_data.title
    def _get_description(self) -> str | None:
        return self._node_data.desc
    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict
    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data
    @classmethod
    def version(cls) -> str:
        """Schema version of this node implementation."""
        return "1"
    def _run(self) -> NodeRunResult:
        """Classify the query with the LLM and return a result routed to the chosen class.

        Falls back to the first configured class when the model output does
        not contain a known ``category_id``.
        """
        node_data = self._node_data
        variable_pool = self.graph_runtime_state.variable_pool
        # extract variables
        variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
        query = variable.value if variable else None
        variables = {"query": query}
        # fetch model config
        model_instance, model_config = llm_utils.fetch_model_config(
            tenant_id=self.tenant_id,
            node_data_model=node_data.model,
        )
        # fetch memory
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
            app_id=self.app_id,
            node_data_memory=node_data.memory,
            model_instance=model_instance,
        )
        # fetch instruction
        # NOTE(review): this renders the instruction template in place on node_data;
        # a second run of the same node object would re-render already-rendered text — confirm intended.
        node_data.instruction = node_data.instruction or ""
        node_data.instruction = variable_pool.convert_template(node_data.instruction).text
        files = (
            llm_utils.fetch_files(
                variable_pool=variable_pool,
                selector=node_data.vision.configs.variable_selector,
            )
            if node_data.vision.enabled
            else []
        )
        # fetch prompt messages
        rest_token = self._calculate_rest_token(
            node_data=node_data,
            query=query or "",
            model_config=model_config,
            context="",
        )
        prompt_template = self._get_prompt_template(
            node_data=node_data,
            query=query or "",
            memory=memory,
            max_token_limit=rest_token,
        )
        # Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...).
        # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
        # two consecutive user prompts will be generated, causing model's error.
        # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
        prompt_messages, stop = LLMNode.fetch_prompt_messages(
            prompt_template=prompt_template,
            sys_query="",
            memory=memory,
            model_config=model_config,
            sys_files=files,
            vision_enabled=node_data.vision.enabled,
            vision_detail=node_data.vision.configs.detail,
            variable_pool=variable_pool,
            jinja2_variables=[],
            tenant_id=self.tenant_id,
        )
        result_text = ""
        usage = LLMUsage.empty_usage()
        finish_reason = None
        try:
            # handle invoke result
            generator = LLMNode.invoke_llm(
                node_data_model=node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                user_id=self.user_id,
                structured_output_enabled=False,
                structured_output=None,
                file_saver=self._llm_file_saver,
                file_outputs=self._file_outputs,
                node_id=self._node_id,
                node_type=self.node_type,
            )
            # Only the completion event is consumed; intermediate events are ignored.
            for event in generator:
                if isinstance(event, ModelInvokeCompletedEvent):
                    result_text = event.text
                    usage = event.usage
                    finish_reason = event.finish_reason
                    break
            # Render variable templates inside class names before matching/output.
            rendered_classes = [
                c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes
            ]
            # Default to the first class; overridden below when the model returns a known category_id.
            category_name = rendered_classes[0].name
            category_id = rendered_classes[0].id
            # Strip reasoning-model <think>...</think> sections before JSON parsing.
            if "<think>" in result_text:
                result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
            result_text_json = parse_and_check_json_markdown(result_text, [])
            if "category_name" in result_text_json and "category_id" in result_text_json:
                category_id_result = result_text_json["category_id"]
                classes = rendered_classes
                classes_map = {class_.id: class_.name for class_ in classes}
                category_ids = [_class.id for _class in classes]
                if category_id_result in category_ids:
                    category_name = classes_map[category_id_result]
                    category_id = category_id_result
            process_data = {
                "model_mode": model_config.mode,
                "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                    model_mode=model_config.mode, prompt_messages=prompt_messages
                ),
                "usage": jsonable_encoder(usage),
                "finish_reason": finish_reason,
                "model_provider": model_config.provider,
                "model_name": model_config.model,
            }
            outputs = {
                "class_name": category_name,
                "class_id": category_id,
                "usage": jsonable_encoder(usage),
            }
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                process_data=process_data,
                outputs=outputs,
                # Routes the workflow to the edge matching the chosen class id.
                edge_source_handle=category_id,
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                },
                llm_usage=usage,
            )
        except ValueError as e:
            # Covers JSON-parse failures and QuestionClassifierNodeError subclasses (both are ValueError).
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                },
                llm_usage=usage,
            )
    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Map this node's referenced variables (query + instruction templates) to their selectors, keyed by ``node_id.name``."""
        # graph_config is not used in this node type
        # Create typed NodeData from dict
        typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
        variable_mapping = {"query": typed_node_data.query_variable_selector}
        variable_selectors: list[VariableSelector] = []
        if typed_node_data.instruction:
            variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
            variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
        variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
        return variable_mapping
    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """
        Get default config of node.
        :param filters: filter by node config parameters (not used in this implementation).
        :return:
        """
        # filters parameter is not used in this node type
        return {"type": "question-classifier", "config": {"instructions": ""}}
    def _calculate_rest_token(
        self,
        node_data: QuestionClassifierNodeData,
        query: str,
        model_config: ModelConfigWithCredentialsEntity,
        context: str | None,
    ) -> int:
        """Estimate the token budget left for memory/history.

        Builds a probe prompt (with a 2000-token history budget), then returns
        context_size - max_tokens - prompt_tokens, clamped at 0. Falls back to
        2000 when the model schema exposes no context size.
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},
            query="",
            files=[],
            context=context,
            memory_config=node_data.memory,
            memory=None,
            model_config=model_config,
        )
        rest_tokens = 2000
        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )
            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
            max_tokens = 0
            # Find the configured completion budget under either the rule name or its template alias.
            for parameter_rule in model_config.model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_config.parameters.get(parameter_rule.name)
                        or model_config.parameters.get(parameter_rule.use_template or "")
                    ) or 0
            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)
        return rest_tokens
    def _get_prompt_template(
        self,
        node_data: QuestionClassifierNodeData,
        query: str,
        memory: TokenBufferMemory | None,
        max_token_limit: int = 2000,
    ):
        """Build the few-shot classification prompt.

        Chat mode returns a list of chat messages (system + two worked
        examples + the real query); completion mode returns a single prompt
        template. Raises InvalidModelTypeError for any other model mode.
        """
        model_mode = ModelMode(node_data.model.mode)
        classes = node_data.classes
        categories = []
        for class_ in classes:
            category = {"category_id": class_.id, "category_name": class_.name}
            categories.append(category)
        instruction = node_data.instruction or ""
        input_text = query
        memory_str = ""
        if memory:
            memory_str = memory.get_history_prompt_text(
                max_token_limit=max_token_limit,
                message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
            )
        prompt_messages: list[LLMNodeChatModelMessage] = []
        if model_mode == ModelMode.CHAT:
            system_prompt_messages = LLMNodeChatModelMessage(
                role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
            )
            prompt_messages.append(system_prompt_messages)
            user_prompt_message_1 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1
            )
            prompt_messages.append(user_prompt_message_1)
            assistant_prompt_message_1 = LLMNodeChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
            )
            prompt_messages.append(assistant_prompt_message_1)
            user_prompt_message_2 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
            )
            prompt_messages.append(user_prompt_message_2)
            assistant_prompt_message_2 = LLMNodeChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
            )
            prompt_messages.append(assistant_prompt_message_2)
            user_prompt_message_3 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER,
                text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(
                    input_text=input_text,
                    categories=json.dumps(categories, ensure_ascii=False),
                    classification_instructions=instruction,
                ),
            )
            prompt_messages.append(user_prompt_message_3)
            return prompt_messages
        elif model_mode == ModelMode.COMPLETION:
            return LLMNodeCompletionModelPromptTemplate(
                text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
                    histories=memory_str,
                    input_text=input_text,
                    categories=json.dumps(categories, ensure_ascii=False),
                    classification_instructions=instruction,
                )
            )
        else:
            raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

View File

@@ -0,0 +1,76 @@
# Prompt constants for the question-classifier node. The string contents are
# part of the runtime behavior (they are formatted and sent to the model), so
# they must not be edited cosmetically — including their existing typos.
# System prompt (chat mode): task definition plus {histories} memory slot.
QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
### Job Description',
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
### Task
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
### Format
The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
""" # noqa: E501
# Few-shot example 1 (chat mode): user turn.
QUESTION_CLASSIFIER_USER_PROMPT_1 = """
{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
"categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}],
"classification_instructions": ["classify the text based on the feedback provided by customer"]}
""" # noqa: E501
# Few-shot example 1 (chat mode): expected assistant answer (fenced JSON).
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
```json
{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
"category_id": "f5660049-284f-41a7-b301-fd24176a711c",
"category_name": "Customer Service"}
```
"""
# Few-shot example 2 (chat mode): user turn.
QUESTION_CLASSIFIER_USER_PROMPT_2 = """
{"input_text": ["bad service, slow to bring the food"],
"categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}],
"classification_instructions": []}
""" # noqa: E501
# Few-shot example 2 (chat mode): expected assistant answer.
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
```json
{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f",
"category_name": "Experience"}
```
"""
# Final user turn (chat mode): str.format slots input_text / categories / classification_instructions.
QUESTION_CLASSIFIER_USER_PROMPT_3 = """
{{"input_text": ["{input_text}"],
"categories": {categories},
"classification_instructions": ["{classification_instructions}"]}}
"""
# Single-shot prompt for completion-mode models; doubled braces are literal JSON braces under str.format.
QUESTION_CLASSIFIER_COMPLETION_PROMPT = """
### Job Description
You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories.
### Task
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
### Format
The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}}
Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}}
User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}}
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
</example>
### Memory
Here are the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### User Input
{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}}
### Assistant Output
""" # noqa: E501

View File

@@ -0,0 +1,3 @@
from .start_node import StartNode
__all__ = ["StartNode"]

View File

@@ -0,0 +1,14 @@
from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity
from core.workflow.nodes.base import BaseNodeData
class StartNodeData(BaseNodeData):
    """
    Start Node Data.

    Holds the input variable declarations exposed by a workflow's start node.
    """
    # Declared workflow input variables; defaults to an empty sequence.
    variables: Sequence[VariableEntity] = Field(default_factory=list)

View File

@@ -0,0 +1,53 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.start.entities import StartNodeData
class StartNode(Node):
    """Workflow entry node.

    Emits the user-provided inputs, augmented with the variable pool's system
    variables (namespaced under the system variable node id), as both the
    node's inputs and outputs.
    """
    node_type = NodeType.START
    execution_type = NodeExecutionType.ROOT
    _node_data: StartNodeData
    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw node config dict into typed ``StartNodeData``."""
        self._node_data = StartNodeData.model_validate(data)
    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy
    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config
    def _get_title(self) -> str:
        return self._node_data.title
    def _get_description(self) -> str | None:
        return self._node_data.desc
    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict
    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data
    @classmethod
    def version(cls) -> str:
        """Schema version of this node implementation."""
        return "1"
    def _run(self) -> NodeRunResult:
        """Collect user inputs plus namespaced system variables and succeed immediately."""
        node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
        system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
        # TODO: System variables should be directly accessible, no need for special handling
        # Set system variables as node outputs, keyed as "<sys-node-id>.<name>".
        # (Iterate items() instead of indexing the dict per key.)
        for name, value in system_inputs.items():
            node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + name] = value
        outputs = dict(node_inputs)
        return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)

View File

@@ -0,0 +1,3 @@
from .template_transform_node import TemplateTransformNode
__all__ = ["TemplateTransformNode"]

View File

@@ -0,0 +1,11 @@
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
class TemplateTransformNodeData(BaseNodeData):
    """
    Template Transform Node Data.

    ``variables`` binds template argument names to variable-pool selectors;
    ``template`` is the Jinja2 template text rendered by the node.
    """
    variables: list[VariableSelector]
    template: str

View File

@@ -0,0 +1,93 @@
from collections.abc import Mapping, Sequence
from typing import Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
# Upper bound (characters) on rendered template output, taken from deployment config.
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node):
    """Renders a Jinja2 template against values resolved from the variable pool."""
    node_type = NodeType.TEMPLATE_TRANSFORM
    _node_data: TemplateTransformNodeData
    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw node config dict into typed ``TemplateTransformNodeData``."""
        self._node_data = TemplateTransformNodeData.model_validate(data)
    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy
    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config
    def _get_title(self) -> str:
        return self._node_data.title
    def _get_description(self) -> str | None:
        return self._node_data.desc
    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict
    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data
    @classmethod
    def version(cls) -> str:
        """Schema version of this node implementation."""
        return "1"
    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """
        Get default config of node.
        :param filters: filter by node config parameters.
        :return:
        """
        return {
            "type": "template-transform",
            "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
        }
    def _run(self) -> NodeRunResult:
        """Resolve configured variables, render the template, and enforce the output-length cap."""
        pool = self.graph_runtime_state.variable_pool
        inputs: dict[str, Any] = {}
        for selector in self._node_data.variables:
            segment = pool.get(selector.value_selector)
            # Preserve truthiness check: empty/None segments map to None.
            inputs[selector.variable] = segment.to_object() if segment else None
        try:
            rendered = CodeExecutor.execute_workflow_code_template(
                language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=inputs
            )
        except CodeExecutionError as e:
            return NodeRunResult(inputs=inputs, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
        output_text = rendered["result"]
        if len(output_text) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
            return NodeRunResult(
                inputs=inputs,
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters",
            )
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, outputs={"output": output_text}
        )
    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
    ) -> Mapping[str, Sequence[str]]:
        """Map each configured template variable (keyed ``node_id.name``) to its selector."""
        typed = TemplateTransformNodeData.model_validate(node_data)
        mapping: dict[str, Sequence[str]] = {}
        for selector in typed.variables:
            mapping[node_id + "." + selector.variable] = selector.value_selector
        return mapping

View File

@@ -0,0 +1,3 @@
from .tool_node import ToolNode
__all__ = ["ToolNode"]

View File

@@ -0,0 +1,84 @@
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData
class ToolEntity(BaseModel):
    """Static identity and configuration of the tool bound to a tool node."""
    provider_id: str
    provider_type: ToolProviderType
    provider_name: str  # redundancy
    tool_name: str
    tool_label: str  # redundancy
    tool_configurations: dict[str, Any]
    credential_id: str | None = None
    plugin_unique_identifier: str | None = None  # redundancy
    @field_validator("tool_configurations", mode="before")
    @classmethod
    def validate_tool_configurations(cls, value, values: ValidationInfo):
        """Require ``tool_configurations`` to be a dictionary.

        NOTE(review): the previous implementation also looped over
        ``values.data.get("tool_configurations", {})`` — but in a ``before``
        validator ``values.data`` can never contain the field being validated,
        so that per-entry primitive-type check was dead code (and the loop
        rebound ``value``). The dead loop is removed here without changing
        observable behavior; enabling a real per-entry check could start
        rejecting stored configurations with non-primitive values, so confirm
        intent before adding one.
        """
        if not isinstance(value, dict):
            raise ValueError("tool_configurations must be a dictionary")
        return value
class ToolNodeData(BaseNodeData, ToolEntity):
    """Node-level configuration for a tool node: tool identity plus parameter bindings."""
    class ToolInput(BaseModel):
        """One bound tool parameter: a raw ``value`` and how to interpret it (``type``)."""
        # TODO: check this type
        # NOTE(review): Union[Any, list[str]] collapses to Any under typing rules;
        # the effective constraint is enforced by `check_type` below.
        value: Union[Any, list[str]]
        type: Literal["mixed", "variable", "constant"]
        @field_validator("type", mode="before")
        @classmethod
        def check_type(cls, value, validation_info: ValidationInfo):
            # Cross-validates `value` against the declared `type`. This relies on
            # `value` being declared before `type` (pydantic validates fields in
            # declaration order), so it is already in validation_info.data here.
            typ = value
            value = validation_info.data.get("value")
            # A missing value is accepted for any type; upstream filtering drops None entries.
            if value is None:
                return typ
            if typ == "mixed" and not isinstance(value, str):
                raise ValueError("value must be a string")
            elif typ == "variable":
                if not isinstance(value, list):
                    raise ValueError("value must be a list")
                for val in value:
                    if not isinstance(val, str):
                        raise ValueError("value must be a list of strings")
            elif typ == "constant" and not isinstance(value, str | int | float | bool | dict):
                raise ValueError("value must be a string, int, float, bool or dict")
            return typ
    tool_parameters: dict[str, ToolInput]
    # The version of the tool parameter.
    # If this value is None, it indicates this is a previous version
    # and requires using the legacy parameter parsing rules.
    tool_node_version: str | None = None
    @field_validator("tool_parameters", mode="before")
    @classmethod
    def filter_none_tool_inputs(cls, value):
        """Drop parameter entries with no usable value before model validation."""
        if not isinstance(value, dict):
            return value
        return {
            key: tool_input
            for key, tool_input in value.items()
            if tool_input is not None and cls._has_valid_value(tool_input)
        }
    @staticmethod
    def _has_valid_value(tool_input):
        """Check if the value is valid"""
        # Accepts either a raw dict ({"value": ...}) or an object with a `value` attribute.
        if isinstance(tool_input, dict):
            return tool_input.get("value") is not None
        return getattr(tool_input, "value", None) is not None

View File

@@ -0,0 +1,16 @@
class ToolNodeError(ValueError):
    """Base exception for tool node errors."""
class ToolParameterError(ToolNodeError):
    """Exception raised for errors in tool parameters."""
class ToolFileError(ToolNodeError):
    """Exception raised for errors related to tool files."""

View File

@@ -0,0 +1,521 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import ToolNodeData
from .exc import (
ToolFileError,
ToolNodeError,
ToolParameterError,
)
if TYPE_CHECKING:
from core.workflow.runtime import VariablePool
class ToolNode(Node):
"""
Tool Node
"""
node_type = NodeType.TOOL
_node_data: ToolNodeData
    def init_node_data(self, data: Mapping[str, Any]):
        """Validate the raw node config dict into typed ``ToolNodeData``."""
        self._node_data = ToolNodeData.model_validate(data)
    @classmethod
    def version(cls) -> str:
        """Schema version of this node implementation."""
        return "1"
    def _run(self) -> Generator[NodeEventBase, None, None]:
        """
        Run the tool node.

        Resolves the tool runtime and its parameters, invokes the tool, and
        yields node events (streaming events come from ``_transform_message``).
        Every failure path yields a FAILED ``StreamCompletedEvent`` instead of
        raising, so callers only consume events.
        """
        from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
        node_data = self._node_data
        # fetch tool icon
        tool_info = {
            "provider_type": node_data.provider_type.value,
            "provider_id": node_data.provider_id,
            "plugin_unique_identifier": node_data.plugin_unique_identifier,
        }
        # get tool runtime
        try:
            from core.tools.tool_manager import ToolManager
            # This is an issue that caused problems before.
            # Logically, we shouldn't use the node_data.version field for judgment
            # But for backward compatibility with historical data
            # this version field judgment is still preserved here.
            variable_pool: VariablePool | None = None
            if node_data.version != "1" or node_data.tool_node_version is not None:
                variable_pool = self.graph_runtime_state.variable_pool
            tool_runtime = ToolManager.get_workflow_tool_runtime(
                self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
            )
        except ToolNodeError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                    error=f"Failed to get tool runtime: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
            return
        # get parameters
        tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
        parameters = self._generate_parameters(
            tool_parameters=tool_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self._node_data,
        )
        # Second pass renders templates in their loggable form for tracing.
        parameters_for_log = self._generate_parameters(
            tool_parameters=tool_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self._node_data,
            for_log=True,
        )
        # get conversation id
        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
        try:
            message_stream = ToolEngine.generic_invoke(
                tool=tool_runtime,
                tool_parameters=parameters,
                user_id=self.user_id,
                workflow_tool_callback=DifyWorkflowCallbackHandler(),
                workflow_call_depth=self.workflow_call_depth,
                app_id=self.app_id,
                conversation_id=conversation_id.text if conversation_id else None,
            )
        except ToolNodeError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                    error=f"Failed to invoke tool: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
            return
        try:
            # convert tool messages; re-yields the events it produces while draining the stream
            _ = yield from self._transform_message(
                messages=message_stream,
                tool_info=tool_info,
                parameters_for_log=parameters_for_log,
                user_id=self.user_id,
                tenant_id=self.tenant_id,
                node_id=self._node_id,
                tool_runtime=tool_runtime,
            )
        except ToolInvokeError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                    error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
        except PluginInvokeError as e:
            # Plugin-side failure: surface the plugin's user-friendly message.
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                    error=e.to_user_friendly_error(plugin_name=node_data.provider_name),
                    error_type=type(e).__name__,
                )
            )
        except PluginDaemonClientSideError as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
                    error=f"Failed to invoke tool, error: {e.description}",
                    error_type=type(e).__name__,
                )
            )
def _generate_parameters(
    self,
    *,
    tool_parameters: Sequence[ToolParameter],
    variable_pool: "VariablePool",
    node_data: ToolNodeData,
    for_log: bool = False,
) -> dict[str, Any]:
    """
    Resolve the node's configured tool inputs into concrete parameter values.

    Args:
        tool_parameters: The tool's declared parameter schemas.
        variable_pool: Pool used to resolve variable references and templates.
        node_data: Tool node configuration holding the raw inputs.
        for_log: When True, templates render their log form instead of plain text.

    Returns:
        Mapping of parameter name to resolved value. Names without a matching
        schema resolve to None; optional inputs whose variable is missing are
        omitted from the result entirely.

    Raises:
        ToolParameterError: when a required variable is missing or an input has
            an unknown type.
    """
    schema_by_name = {schema.name: schema for schema in tool_parameters}
    resolved: dict[str, Any] = {}
    for name, tool_input in node_data.tool_parameters.items():
        schema = schema_by_name.get(name)
        if schema is None:
            # configured input has no matching schema on the tool
            resolved[name] = None
            continue
        if tool_input.type == "variable":
            variable = variable_pool.get(tool_input.value)
            if variable is None:
                if schema.required:
                    raise ToolParameterError(f"Variable {tool_input.value} does not exist")
                # optional and unresolved: leave the key out
                continue
            resolved[name] = variable.value
        elif tool_input.type in {"mixed", "constant"}:
            segment_group = variable_pool.convert_template(str(tool_input.value))
            resolved[name] = segment_group.log if for_log else segment_group.text
        else:
            raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
    return resolved
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
    """Return the files held in the sys.files system variable, or [] when absent/empty."""
    files_variable = variable_pool.get(["sys", SystemVariableKey.FILES])
    assert isinstance(files_variable, ArrayAnyVariable | ArrayAnySegment)
    if not files_variable:
        return []
    return list(files_variable.value)
def _transform_message(
    self,
    messages: Generator[ToolInvokeMessage, None, None],
    tool_info: Mapping[str, Any],
    parameters_for_log: dict[str, Any],
    user_id: str,
    tenant_id: str,
    node_id: str,
    tool_runtime: Tool,
) -> Generator[NodeEventBase, None, LLMUsage]:
    """
    Convert ToolInvokeMessages into stream events and aggregated outputs.

    Consumes the tool's message stream, streaming text/variable chunks as they
    arrive and collecting:
      - ``text``: concatenated textual output
      - ``files``: files produced by the tool (image/binary links, blobs,
        file messages, file-bearing links)
      - ``json``: JSON payloads, normalized to a non-empty list of dicts
      - named variables (streamed incrementally or assigned once)
    Finishes by yielding final (is_final) chunks for every streamed selector,
    then a StreamCompletedEvent carrying the aggregated outputs.

    Returns:
        The tool's LLM usage so the caller can account for tokens.

    Raises:
        ToolFileError: when a referenced tool file does not exist.
        ToolNodeError: on malformed variable/file messages.
    """
    # transform message and handle file storage
    from core.plugin.impl.plugin import PluginInstaller

    message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
        messages=messages,
        user_id=user_id,
        tenant_id=tenant_id,
        conversation_id=None,
    )
    text = ""
    files: list[File] = []
    # NOTE: renamed from `json` — the old name shadowed the stdlib `json`
    # module imported at the top of this file.
    json_results: list[dict] = []
    variables: dict[str, Any] = {}
    for message in message_stream:
        if message.type in {
            ToolInvokeMessage.MessageType.IMAGE_LINK,
            ToolInvokeMessage.MessageType.BINARY_LINK,
            ToolInvokeMessage.MessageType.IMAGE,
        }:
            assert isinstance(message.message, ToolInvokeMessage.TextMessage)
            url = message.message.text
            if message.meta:
                transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
            else:
                transfer_method = FileTransferMethod.TOOL_FILE
            # tool file id is the last URL path segment without its extension
            tool_file_id = str(url).split("/")[-1].split(".")[0]
            with Session(db.engine) as session:
                stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                tool_file = session.scalar(stmt)
                if tool_file is None:
                    raise ToolFileError(f"Tool file {tool_file_id} does not exist")
            mapping = {
                "tool_file_id": tool_file_id,
                "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
                "transfer_method": transfer_method,
                "url": url,
            }
            file = file_factory.build_from_mapping(
                mapping=mapping,
                tenant_id=tenant_id,
            )
            files.append(file)
        elif message.type == ToolInvokeMessage.MessageType.BLOB:
            # get tool file id
            assert isinstance(message.message, ToolInvokeMessage.TextMessage)
            assert message.meta
            tool_file_id = message.message.text.split("/")[-1].split(".")[0]
            with Session(db.engine) as session:
                stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                tool_file = session.scalar(stmt)
                if tool_file is None:
                    raise ToolFileError(f"tool file {tool_file_id} not exists")
            mapping = {
                "tool_file_id": tool_file_id,
                "transfer_method": FileTransferMethod.TOOL_FILE,
            }
            files.append(
                file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=tenant_id,
                )
            )
        elif message.type == ToolInvokeMessage.MessageType.TEXT:
            assert isinstance(message.message, ToolInvokeMessage.TextMessage)
            text += message.message.text
            yield StreamChunkEvent(
                selector=[node_id, "text"],
                chunk=message.message.text,
                is_final=False,
            )
        elif message.type == ToolInvokeMessage.MessageType.JSON:
            assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
            # JSON message handling for tool node
            if message.message.json_object:
                json_results.append(message.message.json_object)
        elif message.type == ToolInvokeMessage.MessageType.LINK:
            assert isinstance(message.message, ToolInvokeMessage.TextMessage)
            # Check if this LINK message is a file link
            file_obj = (message.meta or {}).get("file")
            if isinstance(file_obj, File):
                files.append(file_obj)
                stream_text = f"File: {message.message.text}\n"
            else:
                stream_text = f"Link: {message.message.text}\n"
            text += stream_text
            yield StreamChunkEvent(
                selector=[node_id, "text"],
                chunk=stream_text,
                is_final=False,
            )
        elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
            assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
            variable_name = message.message.variable_name
            variable_value = message.message.variable_value
            if message.message.stream:
                if not isinstance(variable_value, str):
                    raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
                # accumulate streamed variable text chunk by chunk
                if variable_name not in variables:
                    variables[variable_name] = ""
                variables[variable_name] += variable_value
                yield StreamChunkEvent(
                    selector=[node_id, variable_name],
                    chunk=variable_value,
                    is_final=False,
                )
            else:
                variables[variable_name] = variable_value
        elif message.type == ToolInvokeMessage.MessageType.FILE:
            assert message.meta is not None
            assert isinstance(message.meta, dict)
            # Validate that meta contains a 'file' key
            if "file" not in message.meta:
                raise ToolNodeError("File message is missing 'file' key in meta")
            # Validate that the file is an instance of File
            if not isinstance(message.meta["file"], File):
                raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
            files.append(message.meta["file"])
        elif message.type == ToolInvokeMessage.MessageType.LOG:
            assert isinstance(message.message, ToolInvokeMessage.LogMessage)
            if message.message.metadata:
                icon = tool_info.get("icon", "")
                dict_metadata = dict(message.message.metadata)
                if dict_metadata.get("provider"):
                    # enrich log metadata with the provider's icon(s): prefer the
                    # installed plugin's icon, then a matching builtin tool's
                    manager = PluginInstaller()
                    plugins = manager.list_plugins(tenant_id)
                    try:
                        current_plugin = next(
                            plugin
                            for plugin in plugins
                            if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
                        )
                        icon = current_plugin.declaration.icon
                    except StopIteration:
                        pass
                    icon_dark = None
                    try:
                        builtin_tool = next(
                            provider
                            for provider in BuiltinToolManageService.list_builtin_tools(
                                user_id,
                                tenant_id,
                            )
                            if provider.name == dict_metadata["provider"]
                        )
                        icon = builtin_tool.icon
                        icon_dark = builtin_tool.icon_dark
                    except StopIteration:
                        pass
                    dict_metadata["icon"] = icon
                    dict_metadata["icon_dark"] = icon_dark
                    message.message.metadata = dict_metadata
    # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
    json_output: list[dict[str, Any]] = []
    # normalize JSON into a non-empty list of dicts
    if json_results:
        json_output.extend(json_results)
    else:
        json_output.append({"data": []})
    # Send final chunk events for all streamed outputs
    # Final chunk for text stream
    yield StreamChunkEvent(
        selector=[self._node_id, "text"],
        chunk="",
        is_final=True,
    )
    # Final chunks for any streamed variables
    for var_name in variables:
        yield StreamChunkEvent(
            selector=[self._node_id, var_name],
            chunk="",
            is_final=True,
        )
    usage = self._extract_tool_usage(tool_runtime)
    metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
        WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
    }
    # only attach token accounting when the tool actually consumed tokens
    if usage.total_tokens > 0:
        metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
        metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
        metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
    yield StreamCompletedEvent(
        node_run_result=NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
            metadata=metadata,
            inputs=parameters_for_log,
            llm_usage=usage,
        )
    )
    return usage
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
    """Return LLM usage recorded by a workflow tool; other tool kinds report empty usage."""
    if not isinstance(tool_runtime, WorkflowTool):
        return LLMUsage.empty_usage()
    return tool_runtime.latest_usage
@classmethod
def _extract_variable_selector_to_variable_mapping(
    cls,
    *,
    graph_config: Mapping[str, Any],
    node_id: str,
    node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
    """
    Extract variable selector to variable mapping.

    :param graph_config: graph config (unused here; part of the shared interface)
    :param node_id: this node's id, used to namespace the resulting keys
    :param node_data: raw node data dict, validated into ToolNodeData
    :return: mapping of "<node_id>.<variable>" to its value selector
    """
    # Create typed NodeData from dict
    typed_node_data = ToolNodeData.model_validate(node_data)
    result: dict[str, Sequence[str]] = {}
    for parameter_name in typed_node_data.tool_parameters:
        # renamed from `input` to avoid shadowing the builtin
        tool_input = typed_node_data.tool_parameters[parameter_name]
        if tool_input.type == "mixed":
            # template string: pull every {{#...#}} selector it references
            assert isinstance(tool_input.value, str)
            selectors = VariableTemplateParser(tool_input.value).extract_variable_selectors()
            for selector in selectors:
                result[selector.variable] = selector.value_selector
        elif tool_input.type == "variable":
            # a direct selector: the joined selector doubles as the key
            selector_key = ".".join(tool_input.value)
            result[f"#{selector_key}#"] = tool_input.value
        elif tool_input.type == "constant":
            # constants reference no variables
            pass
    # namespace all keys under this node's id
    result = {node_id + "." + key: value for key, value in result.items()}
    return result
def _get_error_strategy(self) -> ErrorStrategy | None:
    """Error strategy configured on this node, if any."""
    return self._node_data.error_strategy

def _get_retry_config(self) -> RetryConfig:
    """Retry configuration for this node."""
    return self._node_data.retry_config

def _get_title(self) -> str:
    """Node title shown in the workflow editor."""
    return self._node_data.title

def _get_description(self) -> str | None:
    """Optional node description."""
    return self._node_data.desc

def _get_default_value_dict(self) -> dict[str, Any]:
    """Default output values used by the fail-branch/default-value error strategy."""
    return self._node_data.default_value_dict

def get_base_node_data(self) -> BaseNodeData:
    """Typed node data as its BaseNodeData interface."""
    return self._node_data

@property
def retry(self) -> bool:
    """Whether retry is enabled for this node."""
    return self._node_data.retry_config.retry_enabled

View File

@@ -0,0 +1,3 @@
from .trigger_event_node import TriggerEventNode
__all__ = ["TriggerEventNode"]

View File

@@ -0,0 +1,77 @@
from collections.abc import Mapping
from typing import Any, Literal, Union
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.trigger.entities.entities import EventParameter
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.trigger_plugin.exc import TriggerEventParameterError
class TriggerEventNodeData(BaseNodeData):
    """Plugin trigger node data."""

    class TriggerEventInput(BaseModel):
        # `value` accepts anything up front; check_type below enforces the
        # shape that matches `type`.
        value: Union[Any, list[str]]
        type: Literal["mixed", "variable", "constant"]

        @field_validator("type", mode="before")
        @classmethod
        def check_type(cls, value, validation_info: ValidationInfo):
            """Validate that `value` has the shape required by `type`.

            Declared on the `type` field so that `value` (declared first) is
            already available in ``validation_info.data``.
            """
            # renamed locals to avoid shadowing the builtin `type`
            input_type = value
            input_value = validation_info.data.get("value")
            if input_value is None:
                return input_type
            if input_type == "mixed" and not isinstance(input_value, str):
                raise ValueError("value must be a string")
            if input_type == "variable":
                if not isinstance(input_value, list):
                    raise ValueError("value must be a list")
                for val in input_value:
                    if not isinstance(val, str):
                        raise ValueError("value must be a list of strings")
            # fixed: the message previously omitted `list`, which the
            # isinstance check accepts
            if input_type == "constant" and not isinstance(input_value, str | int | float | bool | dict | list):
                raise ValueError("value must be a string, int, float, bool, dict or list")
            return input_type

    title: str
    desc: str | None = None
    plugin_id: str = Field(..., description="Plugin ID")
    provider_id: str = Field(..., description="Provider ID")
    event_name: str = Field(..., description="Event name")
    subscription_id: str = Field(..., description="Subscription ID")
    plugin_unique_identifier: str = Field(..., description="Plugin unique identifier")
    event_parameters: Mapping[str, TriggerEventInput] = Field(default_factory=dict, description="Trigger parameters")

    def resolve_parameters(
        self,
        *,
        parameter_schemas: Mapping[str, EventParameter],
    ) -> Mapping[str, Any]:
        """
        Generate parameters based on the given plugin trigger parameters.

        Args:
            parameter_schemas (Mapping[str, EventParameter]): The mapping of parameter schemas.

        Returns:
            Mapping[str, Any]: A dictionary containing the generated parameters.
            Parameters without a matching schema resolve to None.

        Raises:
            TriggerEventParameterError: if an input is not of type "constant".
        """
        result: dict[str, Any] = {}
        for parameter_name in self.event_parameters:
            parameter: EventParameter | None = parameter_schemas.get(parameter_name)
            if not parameter:
                result[parameter_name] = None
                continue
            event_input = self.event_parameters[parameter_name]
            # trigger node only supports constant input
            if event_input.type != "constant":
                raise TriggerEventParameterError(f"Unknown plugin trigger input type '{event_input.type}'")
            result[parameter_name] = event_input.value
        return result

View File

@@ -0,0 +1,10 @@
class TriggerEventNodeError(ValueError):
    """Base exception for plugin trigger node errors."""


class TriggerEventParameterError(TriggerEventNodeError):
    """Exception raised for errors in plugin trigger parameters."""

View File

@@ -0,0 +1,89 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import TriggerEventNodeData
class TriggerEventNode(Node):
    """Root node for a plugin-triggered workflow.

    Makes the trigger's event data — user inputs plus system variables —
    available to downstream nodes as this node's outputs.
    """

    node_type = NodeType.TRIGGER_PLUGIN
    execution_type = NodeExecutionType.ROOT

    _node_data: TriggerEventNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        """Validate raw config into typed TriggerEventNodeData."""
        self._node_data = TriggerEventNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """Return the default (empty) configuration for a newly added node."""
        return {
            "type": "plugin",
            "config": {
                "title": "",
                "plugin_id": "",
                "provider_id": "",
                "event_name": "",
                "subscription_id": "",
                "plugin_unique_identifier": "",
                "event_parameters": {},
            },
        }

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        """Surface trigger variables as this node's inputs and outputs."""
        pool = self.graph_runtime_state.variable_pool
        merged_inputs: dict[str, Any] = dict(pool.user_inputs)
        # TODO: System variables should be directly accessible, no need for special handling
        # Expose system variables too, namespaced under the system node id.
        merged_inputs.update(
            {
                SYSTEM_VARIABLE_NODE_ID + "." + name: value
                for name, value in pool.system_variables.to_dict().items()
            }
        )
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=merged_inputs,
            outputs=dict(merged_inputs),
            metadata={
                WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
                    "provider_id": self._node_data.provider_id,
                    "event_name": self._node_data.event_name,
                    "plugin_unique_identifier": self._node_data.plugin_unique_identifier,
                },
            },
        )

View File

@@ -0,0 +1,3 @@
from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode
__all__ = ["TriggerScheduleNode"]

View File

@@ -0,0 +1,49 @@
from typing import Literal, Union
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
class TriggerScheduleNodeData(BaseNodeData):
    """
    Trigger Schedule Node Data
    """

    # "visual" mode presumably reads frequency/visual_config while "cron" mode
    # reads cron_expression — TODO confirm against the scheduler that consumes this
    mode: str = Field(default="visual", description="Schedule mode: visual or cron")
    # only meaningful in visual mode
    frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
    # only meaningful in cron mode
    cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
    # expected shape appears to match VisualConfig below — TODO confirm
    visual_config: dict | None = Field(default=None, description="Visual configuration details")
    timezone: str = Field(default="UTC", description="Timezone for schedule execution")
class ScheduleConfig(BaseModel):
    """Concrete schedule for one node: a cron expression plus timezone."""

    node_id: str
    cron_expression: str
    timezone: str = "UTC"


class SchedulePlanUpdate(BaseModel):
    """Partial update of a schedule plan; None presumably means "leave unchanged" — TODO confirm against consumer."""

    node_id: str | None = None
    cron_expression: str | None = None
    timezone: str | None = None
class VisualConfig(BaseModel):
    """Visual configuration for schedule trigger.

    Which fields apply depends on the selected frequency; unused fields keep
    their defaults.
    """

    # For hourly frequency
    on_minute: int | None = Field(default=0, ge=0, le=59, description="Minute of the hour (0-59)")
    # For daily, weekly, monthly frequencies
    time: str | None = Field(default="12:00 AM", description="Time in 12-hour format (e.g., '2:30 PM')")
    # For weekly frequency
    weekdays: list[Literal["sun", "mon", "tue", "wed", "thu", "fri", "sat"]] | None = Field(
        default=None, description="List of weekdays to run on"
    )
    # For monthly frequency
    monthly_days: list[Union[int, Literal["last"]]] | None = Field(
        default=None, description="Days of month to run on (1-31 or 'last')"
    )

View File

@@ -0,0 +1,31 @@
from core.workflow.nodes.base.exc import BaseNodeError
class ScheduleNodeError(BaseNodeError):
    """Base schedule node error."""


class ScheduleNotFoundError(ScheduleNodeError):
    """Schedule not found error."""


class ScheduleConfigError(ScheduleNodeError):
    """Schedule configuration error."""


class ScheduleExecutionError(ScheduleNodeError):
    """Schedule execution error."""


class TenantOwnerNotFoundError(ScheduleExecutionError):
    """Tenant owner not found error for schedule execution."""

View File

@@ -0,0 +1,69 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData
class TriggerScheduleNode(Node):
    """Root node that starts a workflow on a schedule.

    The schedule itself (visual or cron) is evaluated elsewhere; when fired,
    this node surfaces the run's user inputs and system variables as outputs.
    """

    node_type = NodeType.TRIGGER_SCHEDULE
    execution_type = NodeExecutionType.ROOT

    _node_data: TriggerScheduleNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        """Validate raw config into typed TriggerScheduleNodeData.

        Uses model_validate (instead of ``TriggerScheduleNodeData(**data)``)
        so any Mapping is accepted directly, consistent with the other
        trigger nodes.
        """
        self._node_data = TriggerScheduleNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """Default configuration shown when the node is first added."""
        return {
            "type": "trigger-schedule",
            "config": {
                "mode": "visual",
                "frequency": "daily",
                "visual_config": {"time": "12:00 AM", "on_minute": 0, "weekdays": ["sun"], "monthly_days": [1]},
                "timezone": "UTC",
            },
        }

    def _run(self) -> NodeRunResult:
        """Expose user inputs and system variables as this node's outputs."""
        node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
        system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
        # TODO: System variables should be directly accessible, no need for special handling
        # Set system variables as node outputs.
        for var in system_inputs:
            node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
        outputs = dict(node_inputs)
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            outputs=outputs,
        )

Some files were not shown because too many files have changed in this diff Show More