dify
This commit is contained in:
132
dify/api/core/workflow/README.md
Normal file
132
dify/api/core/workflow/README.md
Normal file
@@ -0,0 +1,132 @@
|
||||
# Workflow
|
||||
|
||||
## Project Overview
|
||||
|
||||
This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
The graph engine follows a layered architecture with strict dependency rules:
|
||||
|
||||
1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution
|
||||
|
||||
- **Manager** - External control interface for stop/pause/resume commands
|
||||
- **Worker** - Node execution runtime
|
||||
- **Command Processing** - Handles control commands (abort, pause, resume)
|
||||
- **Event Management** - Event propagation and layer notifications
|
||||
- **Graph Traversal** - Edge processing and skip propagation
|
||||
- **Response Coordinator** - Path tracking and session management
|
||||
- **Layers** - Pluggable middleware (debug logging, execution limits)
|
||||
- **Command Channels** - Communication channels (InMemory, Redis)
|
||||
|
||||
1. **Graph** (`graph/`) - Graph structure and runtime state
|
||||
|
||||
- **Graph Template** - Workflow definition
|
||||
- **Edge** - Node connections with conditions
|
||||
- **Runtime State Protocol** - State management interface
|
||||
|
||||
1. **Nodes** (`nodes/`) - Node implementations
|
||||
|
||||
- **Base** - Abstract node classes and variable parsing
|
||||
- **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc.
|
||||
|
||||
1. **Events** (`node_events/`) - Event system
|
||||
|
||||
- **Base** - Event protocols
|
||||
- **Node Events** - Node lifecycle events
|
||||
|
||||
1. **Entities** (`entities/`) - Domain models
|
||||
|
||||
- **Variable Pool** - Variable storage
|
||||
- **Graph Init Params** - Initialization configuration
|
||||
|
||||
## Key Design Patterns
|
||||
|
||||
### Command Channel Pattern
|
||||
|
||||
External workflow control via Redis or in-memory channels:
|
||||
|
||||
```python
|
||||
# Send stop command to running workflow
|
||||
channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
|
||||
channel.send_command(AbortCommand(reason="User requested"))
|
||||
```
|
||||
|
||||
### Layer System
|
||||
|
||||
Extensible middleware for cross-cutting concerns:
|
||||
|
||||
```python
|
||||
engine = GraphEngine(graph)
|
||||
engine.layer(DebugLoggingLayer(level="INFO"))
|
||||
engine.layer(ExecutionLimitsLayer(max_nodes=100))
|
||||
```
|
||||
|
||||
### Event-Driven Architecture
|
||||
|
||||
All node executions emit events for monitoring and integration:
|
||||
|
||||
- `NodeRunStartedEvent` - Node execution begins
|
||||
- `NodeRunSucceededEvent` - Node completes successfully
|
||||
- `NodeRunFailedEvent` - Node encounters error
|
||||
- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle
|
||||
|
||||
### Variable Pool
|
||||
|
||||
Centralized variable storage with namespace isolation:
|
||||
|
||||
```python
|
||||
# Variables scoped by node_id
|
||||
pool.add(["node1", "output"], value)
|
||||
result = pool.get(["node1", "output"])
|
||||
```
|
||||
|
||||
## Import Architecture Rules
|
||||
|
||||
The codebase enforces strict layering via import-linter:
|
||||
|
||||
1. **Workflow Layers** (top to bottom):
|
||||
|
||||
- graph_engine → graph_events → graph → nodes → node_events → entities
|
||||
|
||||
1. **Graph Engine Internal Layers**:
|
||||
|
||||
- orchestration → command_processing → event_management → graph_traversal → domain
|
||||
|
||||
1. **Domain Isolation**:
|
||||
|
||||
- Domain models cannot import from infrastructure layers
|
||||
|
||||
1. **Command Channel Independence**:
|
||||
|
||||
- InMemory and Redis channels must remain independent
|
||||
|
||||
## Common Tasks
|
||||
|
||||
### Adding a New Node Type
|
||||
|
||||
1. Create node class in `nodes/<node_type>/`
|
||||
1. Inherit from `BaseNode` or appropriate base class
|
||||
1. Implement `_run()` method
|
||||
1. Register in `nodes/node_mapping.py`
|
||||
1. Add tests in `tests/unit_tests/core/workflow/nodes/`
|
||||
|
||||
### Implementing a Custom Layer
|
||||
|
||||
1. Create class inheriting from `Layer` base
|
||||
1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()`
|
||||
1. Add to engine via `engine.layer()`
|
||||
|
||||
### Debugging Workflow Execution
|
||||
|
||||
Enable debug logging layer:
|
||||
|
||||
```python
|
||||
debug_layer = DebugLoggingLayer(
|
||||
level="DEBUG",
|
||||
include_inputs=True,
|
||||
include_outputs=True
|
||||
)
|
||||
```
|
||||
0
dify/api/core/workflow/__init__.py
Normal file
0
dify/api/core/workflow/__init__.py
Normal file
4
dify/api/core/workflow/constants.py
Normal file
4
dify/api/core/workflow/constants.py
Normal file
@@ -0,0 +1,4 @@
|
||||
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"
|
||||
39
dify/api/core/workflow/conversation_variable_updater.py
Normal file
39
dify/api/core/workflow/conversation_variable_updater.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import abc
|
||||
from typing import Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
|
||||
|
||||
class ConversationVariableUpdater(Protocol):
|
||||
"""
|
||||
ConversationVariableUpdater defines an abstraction for updating conversation variable values.
|
||||
|
||||
It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
|
||||
conversation variables.
|
||||
|
||||
Implementations may choose to batch updates. If batching is used, the `flush` method
|
||||
should be implemented to persist buffered changes, and `update`
|
||||
should handle buffering accordingly.
|
||||
|
||||
Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
|
||||
are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, conversation_id: str, variable: "Variable"):
|
||||
"""
|
||||
Updates the value of the specified conversation variable in the underlying storage.
|
||||
|
||||
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
|
||||
:param variable: The `Variable` instance containing the updated value.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def flush(self):
|
||||
"""
|
||||
Flushes all pending updates to the underlying storage system.
|
||||
|
||||
If the implementation does not buffer updates, this method can be a no-op.
|
||||
"""
|
||||
pass
|
||||
17
dify/api/core/workflow/entities/__init__.py
Normal file
17
dify/api/core/workflow/entities/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from ..runtime.graph_runtime_state import GraphRuntimeState
|
||||
from ..runtime.variable_pool import VariablePool
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
from .workflow_pause import WorkflowPauseEntity
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"GraphRuntimeState",
|
||||
"VariablePool",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
"WorkflowPauseEntity",
|
||||
]
|
||||
8
dify/api/core/workflow/entities/agent.py
Normal file
8
dify/api/core/workflow/entities/agent.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentNodeStrategyInit(BaseModel):
|
||||
"""Agent node strategy initialization data."""
|
||||
|
||||
name: str
|
||||
icon: str | None = None
|
||||
20
dify/api/core/workflow/entities/graph_init_params.py
Normal file
20
dify/api/core/workflow/entities/graph_init_params.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
workflow_id: str = Field(..., description="workflow id")
|
||||
graph_config: Mapping[str, Any] = Field(..., description="graph config")
|
||||
user_id: str = Field(..., description="user id")
|
||||
user_from: str = Field(
|
||||
..., description="user from, account or end-user"
|
||||
) # Should be UserFrom enum: 'account' | 'end-user'
|
||||
invoke_from: str = Field(
|
||||
..., description="invoke from, service-api, web-app, explore or debugger"
|
||||
) # Should be InvokeFrom enum: 'service-api' | 'web-app' | 'explore' | 'debugger'
|
||||
call_depth: int = Field(..., description="call depth")
|
||||
49
dify/api/core/workflow/entities/pause_reason.py
Normal file
49
dify/api/core/workflow/entities/pause_reason.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, ClassVar, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Discriminator, Tag
|
||||
|
||||
|
||||
class _PauseReasonType(StrEnum):
|
||||
HUMAN_INPUT_REQUIRED = auto()
|
||||
SCHEDULED_PAUSE = auto()
|
||||
|
||||
|
||||
class _PauseReasonBase(BaseModel):
|
||||
TYPE: ClassVar[_PauseReasonType]
|
||||
|
||||
|
||||
class HumanInputRequired(_PauseReasonBase):
|
||||
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
|
||||
class SchedulingPause(_PauseReasonBase):
|
||||
TYPE = _PauseReasonType.SCHEDULED_PAUSE
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
|
||||
if isinstance(v, _PauseReasonBase):
|
||||
return v.TYPE
|
||||
elif isinstance(v, dict):
|
||||
reason_type_str = v.get("TYPE")
|
||||
if reason_type_str is None:
|
||||
return None
|
||||
try:
|
||||
reason_type = _PauseReasonType(reason_type_str)
|
||||
except ValueError:
|
||||
return None
|
||||
return reason_type
|
||||
else:
|
||||
# return None if the discriminator value isn't found
|
||||
return None
|
||||
|
||||
|
||||
PauseReason: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
|
||||
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
|
||||
),
|
||||
Discriminator(_get_pause_reason_discriminator),
|
||||
]
|
||||
72
dify/api/core/workflow/entities/workflow_execution.py
Normal file
72
dify/api/core/workflow/entities/workflow_execution.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Domain entities for workflow execution.
|
||||
|
||||
Models are independent of the storage mechanism and don't contain
|
||||
implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class WorkflowExecution(BaseModel):
|
||||
"""
|
||||
Domain model for workflow execution based on WorkflowRun but without
|
||||
user, tenant, and app attributes.
|
||||
"""
|
||||
|
||||
id_: str = Field(...)
|
||||
workflow_id: str = Field(...)
|
||||
workflow_version: str = Field(...)
|
||||
workflow_type: WorkflowType = Field(...)
|
||||
graph: Mapping[str, Any] = Field(...)
|
||||
|
||||
inputs: Mapping[str, Any] = Field(...)
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
|
||||
status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
|
||||
error_message: str = Field(default="")
|
||||
total_tokens: int = Field(default=0)
|
||||
total_steps: int = Field(default=0)
|
||||
exceptions_count: int = Field(default=0)
|
||||
|
||||
started_at: datetime = Field(...)
|
||||
finished_at: datetime | None = None
|
||||
|
||||
@property
|
||||
def elapsed_time(self) -> float:
|
||||
"""
|
||||
Calculate elapsed time in seconds.
|
||||
If workflow is not finished, use current time.
|
||||
"""
|
||||
end_time = self.finished_at or naive_utc_now()
|
||||
return (end_time - self.started_at).total_seconds()
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
*,
|
||||
id_: str,
|
||||
workflow_id: str,
|
||||
workflow_type: WorkflowType,
|
||||
workflow_version: str,
|
||||
graph: Mapping[str, Any],
|
||||
inputs: Mapping[str, Any],
|
||||
started_at: datetime,
|
||||
) -> "WorkflowExecution":
|
||||
return WorkflowExecution(
|
||||
id_=id_,
|
||||
workflow_id=workflow_id,
|
||||
workflow_type=workflow_type,
|
||||
workflow_version=workflow_version,
|
||||
graph=graph,
|
||||
inputs=inputs,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
started_at=started_at,
|
||||
)
|
||||
147
dify/api/core/workflow/entities/workflow_node_execution.py
Normal file
147
dify/api/core/workflow/entities/workflow_node_execution.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Domain entities for workflow node execution.
|
||||
|
||||
This module contains the domain model for workflow node execution, which is used
|
||||
by the core workflow module. These models are independent of the storage mechanism
|
||||
and don't contain implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class WorkflowNodeExecution(BaseModel):
|
||||
"""
|
||||
Domain model for workflow node execution.
|
||||
|
||||
This model represents the core business entity of a node execution,
|
||||
without implementation details like tenant_id, app_id, etc.
|
||||
|
||||
Note: User/context-specific fields (triggered_from, created_by, created_by_role)
|
||||
have been moved to the repository implementation to keep the domain model clean.
|
||||
These fields are still accepted in the constructor for backward compatibility,
|
||||
but they are not stored in the model.
|
||||
"""
|
||||
|
||||
# --------- Core identification fields ---------
|
||||
|
||||
# Unique identifier for this execution record, used when persisting to storage.
|
||||
# Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382').
|
||||
id: str
|
||||
|
||||
# Optional secondary ID for cross-referencing purposes.
|
||||
#
|
||||
# NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
|
||||
# While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
|
||||
# In most scenarios, `id` should be used as the primary identifier.
|
||||
node_execution_id: str | None = None
|
||||
workflow_id: str # ID of the workflow this node belongs to
|
||||
workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging)
|
||||
# --------- Core identification fields ends ---------
|
||||
|
||||
# Execution positioning and flow
|
||||
index: int # Sequence number for ordering in trace visualization
|
||||
predecessor_node_id: str | None = None # ID of the node that executed before this one
|
||||
node_id: str # ID of the node being executed
|
||||
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
|
||||
title: str # Display title of the node
|
||||
|
||||
# Execution data
|
||||
# The `inputs` and `outputs` fields hold the full content
|
||||
inputs: Mapping[str, Any] | None = None # Input variables used by this node
|
||||
process_data: Mapping[str, Any] | None = None # Intermediate processing data
|
||||
outputs: Mapping[str, Any] | None = None # Output variables produced by this node
|
||||
|
||||
# Execution state
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
|
||||
error: str | None = None # Error message if execution failed
|
||||
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
|
||||
|
||||
# Additional metadata
|
||||
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.)
|
||||
|
||||
# Timing information
|
||||
created_at: datetime # When execution started
|
||||
finished_at: datetime | None = None # When execution completed
|
||||
|
||||
_truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None)
|
||||
_truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None)
|
||||
_truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None)
|
||||
|
||||
def get_truncated_inputs(self) -> Mapping[str, Any] | None:
|
||||
return self._truncated_inputs
|
||||
|
||||
def get_truncated_outputs(self) -> Mapping[str, Any] | None:
|
||||
return self._truncated_outputs
|
||||
|
||||
def get_truncated_process_data(self) -> Mapping[str, Any] | None:
|
||||
return self._truncated_process_data
|
||||
|
||||
def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None):
|
||||
self._truncated_inputs = truncated_inputs
|
||||
|
||||
def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None):
|
||||
self._truncated_outputs = truncated_outputs
|
||||
|
||||
def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None):
|
||||
self._truncated_process_data = truncated_process_data
|
||||
|
||||
def get_response_inputs(self) -> Mapping[str, Any] | None:
|
||||
inputs = self.get_truncated_inputs()
|
||||
if inputs:
|
||||
return inputs
|
||||
return self.inputs
|
||||
|
||||
@property
|
||||
def inputs_truncated(self):
|
||||
return self._truncated_inputs is not None
|
||||
|
||||
@property
|
||||
def outputs_truncated(self):
|
||||
return self._truncated_outputs is not None
|
||||
|
||||
@property
|
||||
def process_data_truncated(self):
|
||||
return self._truncated_process_data is not None
|
||||
|
||||
def get_response_outputs(self) -> Mapping[str, Any] | None:
|
||||
outputs = self.get_truncated_outputs()
|
||||
if outputs is not None:
|
||||
return outputs
|
||||
return self.outputs
|
||||
|
||||
def get_response_process_data(self) -> Mapping[str, Any] | None:
|
||||
process_data = self.get_truncated_process_data()
|
||||
if process_data is not None:
|
||||
return process_data
|
||||
return self.process_data
|
||||
|
||||
def update_from_mapping(
|
||||
self,
|
||||
inputs: Mapping[str, Any] | None = None,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Update the model from mappings.
|
||||
|
||||
Args:
|
||||
inputs: The inputs to update
|
||||
process_data: The process data to update
|
||||
outputs: The outputs to update
|
||||
metadata: The metadata to update
|
||||
"""
|
||||
if inputs is not None:
|
||||
self.inputs = dict(inputs)
|
||||
if process_data is not None:
|
||||
self.process_data = dict(process_data)
|
||||
if outputs is not None:
|
||||
self.outputs = dict(outputs)
|
||||
if metadata is not None:
|
||||
self.metadata = dict(metadata)
|
||||
61
dify/api/core/workflow/entities/workflow_pause.py
Normal file
61
dify/api/core/workflow/entities/workflow_pause.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Domain entities for workflow pause management.
|
||||
|
||||
This module contains the domain model for workflow pause, which is used
|
||||
by the core workflow module. These models are independent of the storage mechanism
|
||||
and don't contain implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class WorkflowPauseEntity(ABC):
|
||||
"""
|
||||
Abstract base class for workflow pause entities.
|
||||
|
||||
This domain model represents a paused workflow execution state,
|
||||
without implementation details like tenant_id, app_id, etc.
|
||||
It provides the interface for managing workflow pause/resume operations
|
||||
and state persistence through file storage.
|
||||
|
||||
The `WorkflowPauseEntity` is never reused. If a workflow execution pauses multiple times,
|
||||
it will generate multiple `WorkflowPauseEntity` records.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def id(self) -> str:
|
||||
"""The identifier of current WorkflowPauseEntity"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def workflow_execution_id(self) -> str:
|
||||
"""The identifier of the workflow execution record the pause associated with.
|
||||
Correspond to `WorkflowExecution.id`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_state(self) -> bytes:
|
||||
"""
|
||||
Retrieve the serialized workflow state from storage.
|
||||
|
||||
This method should load and return the workflow execution state
|
||||
that was saved when the workflow was paused. The state contains
|
||||
all necessary information to resume the workflow execution.
|
||||
|
||||
Returns:
|
||||
bytes: The serialized workflow state containing
|
||||
execution context, variable values, node states, etc.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def resumed_at(self) -> datetime | None:
|
||||
"""`resumed_at` return the resumption time of the current pause, or `None` if
|
||||
the pause is not resumed yet.
|
||||
"""
|
||||
pass
|
||||
262
dify/api/core/workflow/enums.py
Normal file
262
dify/api/core/workflow/enums.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class NodeState(StrEnum):
|
||||
"""State of a node or edge during workflow execution."""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
TAKEN = "taken"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
class SystemVariableKey(StrEnum):
|
||||
"""
|
||||
System Variables.
|
||||
"""
|
||||
|
||||
QUERY = "query"
|
||||
FILES = "files"
|
||||
CONVERSATION_ID = "conversation_id"
|
||||
USER_ID = "user_id"
|
||||
DIALOGUE_COUNT = "dialogue_count"
|
||||
APP_ID = "app_id"
|
||||
WORKFLOW_ID = "workflow_id"
|
||||
WORKFLOW_EXECUTION_ID = "workflow_run_id"
|
||||
TIMESTAMP = "timestamp"
|
||||
# RAG Pipeline
|
||||
DOCUMENT_ID = "document_id"
|
||||
ORIGINAL_DOCUMENT_ID = "original_document_id"
|
||||
BATCH = "batch"
|
||||
DATASET_ID = "dataset_id"
|
||||
DATASOURCE_TYPE = "datasource_type"
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
INVOKE_FROM = "invoke_from"
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
ANSWER = "answer"
|
||||
LLM = "llm"
|
||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||
KNOWLEDGE_INDEX = "knowledge-index"
|
||||
IF_ELSE = "if-else"
|
||||
CODE = "code"
|
||||
TEMPLATE_TRANSFORM = "template-transform"
|
||||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
DATASOURCE = "datasource"
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
LOOP_START = "loop-start"
|
||||
LOOP_END = "loop-end"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # Fake start node for iteration.
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
VARIABLE_ASSIGNER = "assigner"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
AGENT = "agent"
|
||||
TRIGGER_WEBHOOK = "trigger-webhook"
|
||||
TRIGGER_SCHEDULE = "trigger-schedule"
|
||||
TRIGGER_PLUGIN = "trigger-plugin"
|
||||
HUMAN_INPUT = "human-input"
|
||||
|
||||
@property
|
||||
def is_trigger_node(self) -> bool:
|
||||
"""Check if this node type is a trigger node."""
|
||||
return self in [
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
]
|
||||
|
||||
@property
|
||||
def is_start_node(self) -> bool:
|
||||
"""Check if this node type can serve as a workflow entry point."""
|
||||
return self in [
|
||||
NodeType.START,
|
||||
NodeType.DATASOURCE,
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
]
|
||||
|
||||
|
||||
class NodeExecutionType(StrEnum):
|
||||
"""Node execution type classification."""
|
||||
|
||||
EXECUTABLE = "executable" # Regular nodes that execute and produce outputs
|
||||
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
|
||||
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
|
||||
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
|
||||
ROOT = "root" # Nodes that can serve as execution entry points
|
||||
|
||||
|
||||
class ErrorStrategy(StrEnum):
|
||||
FAIL_BRANCH = "fail-branch"
|
||||
DEFAULT_VALUE = "default-value"
|
||||
|
||||
|
||||
class FailBranchSourceHandle(StrEnum):
|
||||
FAILED = "fail-branch"
|
||||
SUCCESS = "success-branch"
|
||||
|
||||
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
Workflow Type Enum for domain layer
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
|
||||
class WorkflowExecutionStatus(StrEnum):
|
||||
# State diagram for the workflw status:
|
||||
# (@) means start, (*) means end
|
||||
#
|
||||
# ┌------------------>------------------------->------------------->--------------┐
|
||||
# | |
|
||||
# | ┌-----------------------<--------------------┐ |
|
||||
# ^ | | |
|
||||
# | | ^ |
|
||||
# | V | |
|
||||
# ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V
|
||||
# | Scheduled |------->| Running |---------------------->| paused | |
|
||||
# └-----------┘ └-----------------------┘ └-----------┘ |
|
||||
# | | | | | | |
|
||||
# | | | | | | |
|
||||
# ^ | | | V V |
|
||||
# | | | | | ┌---------┐ |
|
||||
# (@) | | | └------------------------>| Stopped |<----┘
|
||||
# | | | └---------┘
|
||||
# | | | |
|
||||
# | | V V
|
||||
# | | ┌-----------┐ |
|
||||
# | | | Succeeded |------------->--------------┤
|
||||
# | | └-----------┘ |
|
||||
# | V V
|
||||
# | +--------┐ |
|
||||
# | | Failed |---------------------->----------------┤
|
||||
# | └--------┘ |
|
||||
# V V
|
||||
# ┌---------------------┐ |
|
||||
# | Partially Succeeded |---------------------->-----------------┘--------> (*)
|
||||
# └---------------------┘
|
||||
#
|
||||
# Mermaid diagram:
|
||||
#
|
||||
# ---
|
||||
# title: State diagram for Workflow run state
|
||||
# ---
|
||||
# stateDiagram-v2
|
||||
# scheduled: Scheduled
|
||||
# running: Running
|
||||
# succeeded: Succeeded
|
||||
# failed: Failed
|
||||
# partial_succeeded: Partial Succeeded
|
||||
# paused: Paused
|
||||
# stopped: Stopped
|
||||
#
|
||||
# [*] --> scheduled:
|
||||
# scheduled --> running: Start Execution
|
||||
# running --> paused: Human input required
|
||||
# paused --> running: human input added
|
||||
# paused --> stopped: User stops execution
|
||||
# running --> succeeded: Execution finishes without any error
|
||||
# running --> failed: Execution finishes with errors
|
||||
# running --> stopped: User stops execution
|
||||
# running --> partial_succeeded: some execution occurred and handled during execution
|
||||
#
|
||||
# scheduled --> stopped: User stops execution
|
||||
#
|
||||
# succeeded --> [*]
|
||||
# failed --> [*]
|
||||
# partial_succeeded --> [*]
|
||||
# stopped --> [*]
|
||||
|
||||
# `SCHEDULED` means that the workflow is scheduled to run, but has not
|
||||
# started running yet. (maybe due to possible worker saturation.)
|
||||
#
|
||||
# This enum value is currently unused.
|
||||
SCHEDULED = "scheduled"
|
||||
|
||||
# `RUNNING` means the workflow is exeuting.
|
||||
RUNNING = "running"
|
||||
|
||||
# `SUCCEEDED` means the execution of workflow succeed without any error.
|
||||
SUCCEEDED = "succeeded"
|
||||
|
||||
# `FAILED` means the execution of workflow failed without some errors.
|
||||
FAILED = "failed"
|
||||
|
||||
# `STOPPED` means the execution of workflow was stopped, either manually
|
||||
# by the user, or automatically by the Dify application (E.G. the moderation
|
||||
# mechanism.)
|
||||
STOPPED = "stopped"
|
||||
|
||||
# `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
|
||||
# execution, but they were successfully handled (e.g., by using an error
|
||||
# strategy such as "fail branch" or "default value").
|
||||
PARTIAL_SUCCEEDED = "partial-succeeded"
|
||||
|
||||
# `PAUSED` indicates that the workflow execution is temporarily paused
|
||||
# (e.g., awaiting human input) and is expected to resume later.
|
||||
PAUSED = "paused"
|
||||
|
||||
def is_ended(self) -> bool:
|
||||
return self in _END_STATE
|
||||
|
||||
|
||||
_END_STATE = frozenset(
|
||||
[
|
||||
WorkflowExecutionStatus.SUCCEEDED,
|
||||
WorkflowExecutionStatus.FAILED,
|
||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
WorkflowExecutionStatus.STOPPED,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
TRIGGER_INFO = "trigger_info"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
PENDING = "pending" # Node is scheduled but not yet executing
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
EXCEPTION = "exception"
|
||||
STOPPED = "stopped"
|
||||
PAUSED = "paused"
|
||||
|
||||
# Legacy statuses - kept for backward compatibility
|
||||
RETRY = "retry" # Legacy: replaced by retry mechanism in error handling
|
||||
16
dify/api/core/workflow/errors.py
Normal file
16
dify/api/core/workflow/errors.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
|
||||
class WorkflowNodeRunFailedError(Exception):
|
||||
def __init__(self, node: Node, err_msg: str):
|
||||
self._node = node
|
||||
self._error = err_msg
|
||||
super().__init__(f"Node {node.title} run failed: {err_msg}")
|
||||
|
||||
@property
|
||||
def node(self) -> Node:
|
||||
return self._node
|
||||
|
||||
@property
|
||||
def error(self) -> str:
|
||||
return self._error
|
||||
11
dify/api/core/workflow/graph/__init__.py
Normal file
11
dify/api/core/workflow/graph/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .edge import Edge
|
||||
from .graph import Graph, GraphBuilder, NodeFactory
|
||||
from .graph_template import GraphTemplate
|
||||
|
||||
__all__ = [
|
||||
"Edge",
|
||||
"Graph",
|
||||
"GraphBuilder",
|
||||
"GraphTemplate",
|
||||
"NodeFactory",
|
||||
]
|
||||
15
dify/api/core/workflow/graph/edge.py
Normal file
15
dify/api/core/workflow/graph/edge.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
|
||||
@dataclass
|
||||
class Edge:
|
||||
"""Edge connecting two nodes in a workflow graph."""
|
||||
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
tail: str = "" # tail node id (source)
|
||||
head: str = "" # head node id (target)
|
||||
source_handle: str = "source" # source handle for conditional branching
|
||||
state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state
|
||||
465
dify/api/core/workflow/graph/graph.py
Normal file
465
dify/api/core/workflow/graph/graph.py
Normal file
@@ -0,0 +1,465 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, cast, final
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .edge import Edge
|
||||
from .validation import get_graph_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeFactory(Protocol):
|
||||
"""
|
||||
Protocol for creating Node instances from node data dictionaries.
|
||||
|
||||
This protocol decouples the Graph class from specific node mapping implementations,
|
||||
allowing for different node creation strategies while maintaining type safety.
|
||||
"""
|
||||
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data.
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@final
|
||||
class Graph:
|
||||
"""Graph representation with nodes and edges for workflow execution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
nodes: dict[str, Node] | None = None,
|
||||
edges: dict[str, Edge] | None = None,
|
||||
in_edges: dict[str, list[str]] | None = None,
|
||||
out_edges: dict[str, list[str]] | None = None,
|
||||
root_node: Node,
|
||||
):
|
||||
"""
|
||||
Initialize Graph instance.
|
||||
|
||||
:param nodes: graph nodes mapping (node id: node object)
|
||||
:param edges: graph edges mapping (edge id: edge object)
|
||||
:param in_edges: incoming edges mapping (node id: list of edge ids)
|
||||
:param out_edges: outgoing edges mapping (node id: list of edge ids)
|
||||
:param root_node: root node object
|
||||
"""
|
||||
self.nodes = nodes or {}
|
||||
self.edges = edges or {}
|
||||
self.in_edges = in_edges or {}
|
||||
self.out_edges = out_edges or {}
|
||||
self.root_node = root_node
|
||||
|
||||
@classmethod
|
||||
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
|
||||
"""
|
||||
Parse node configurations and build a mapping of node IDs to configs.
|
||||
|
||||
:param node_configs: list of node configuration dictionaries
|
||||
:return: mapping of node ID to node config
|
||||
"""
|
||||
node_configs_map: dict[str, dict[str, object]] = {}
|
||||
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id or not isinstance(node_id, str):
|
||||
continue
|
||||
|
||||
node_configs_map[node_id] = node_config
|
||||
|
||||
return node_configs_map
|
||||
|
||||
@classmethod
|
||||
def _find_root_node_id(
|
||||
cls,
|
||||
node_configs_map: Mapping[str, Mapping[str, object]],
|
||||
edge_configs: Sequence[Mapping[str, object]],
|
||||
root_node_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find the root node ID if not specified.
|
||||
|
||||
:param node_configs_map: mapping of node ID to node config
|
||||
:param edge_configs: list of edge configurations
|
||||
:param root_node_id: explicitly specified root node ID
|
||||
:return: determined root node ID
|
||||
"""
|
||||
if root_node_id:
|
||||
if root_node_id not in node_configs_map:
|
||||
raise ValueError(f"Root node id {root_node_id} not found in the graph")
|
||||
return root_node_id
|
||||
|
||||
# Find nodes with no incoming edges
|
||||
nodes_with_incoming: set[str] = set()
|
||||
for edge_config in edge_configs:
|
||||
target = edge_config.get("target")
|
||||
if isinstance(target, str):
|
||||
nodes_with_incoming.add(target)
|
||||
|
||||
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
|
||||
|
||||
# Prefer START node if available
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid].get("data")
|
||||
if not is_str_dict(node_data):
|
||||
continue
|
||||
node_type = node_data.get("type")
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if NodeType(node_type).is_start_node:
|
||||
start_node_id = nid
|
||||
break
|
||||
|
||||
root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
|
||||
|
||||
if not root_node_id:
|
||||
raise ValueError("Unable to determine root node ID")
|
||||
|
||||
return root_node_id
|
||||
|
||||
@classmethod
|
||||
def _build_edges(
|
||||
cls, edge_configs: list[dict[str, object]]
|
||||
) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
|
||||
"""
|
||||
Build edge objects and mappings from edge configurations.
|
||||
|
||||
:param edge_configs: list of edge configurations
|
||||
:return: tuple of (edges dict, in_edges dict, out_edges dict)
|
||||
"""
|
||||
edges: dict[str, Edge] = {}
|
||||
in_edges: dict[str, list[str]] = defaultdict(list)
|
||||
out_edges: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
edge_counter = 0
|
||||
for edge_config in edge_configs:
|
||||
source = edge_config.get("source")
|
||||
target = edge_config.get("target")
|
||||
|
||||
if not is_str(source) or not is_str(target):
|
||||
continue
|
||||
|
||||
# Create edge
|
||||
edge_id = f"edge_{edge_counter}"
|
||||
edge_counter += 1
|
||||
|
||||
source_handle = edge_config.get("sourceHandle", "source")
|
||||
if not is_str(source_handle):
|
||||
continue
|
||||
|
||||
edge = Edge(
|
||||
id=edge_id,
|
||||
tail=source,
|
||||
head=target,
|
||||
source_handle=source_handle,
|
||||
)
|
||||
|
||||
edges[edge_id] = edge
|
||||
out_edges[source].append(edge_id)
|
||||
in_edges[target].append(edge_id)
|
||||
|
||||
return edges, dict(in_edges), dict(out_edges)
|
||||
|
||||
@classmethod
|
||||
def _create_node_instances(
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, object]],
|
||||
node_factory: "NodeFactory",
|
||||
) -> dict[str, Node]:
|
||||
"""
|
||||
Create node instances from configurations using the node factory.
|
||||
|
||||
:param node_configs_map: mapping of node ID to node config
|
||||
:param node_factory: factory for creating node instances
|
||||
:return: mapping of node ID to node instance
|
||||
"""
|
||||
nodes: dict[str, Node] = {}
|
||||
|
||||
for node_id, node_config in node_configs_map.items():
|
||||
try:
|
||||
node_instance = node_factory.create_node(node_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to create node instance for node_id %s", node_id)
|
||||
raise
|
||||
nodes[node_id] = node_instance
|
||||
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def new(cls) -> "GraphBuilder":
|
||||
"""Create a fluent builder for assembling a graph programmatically."""
|
||||
|
||||
return GraphBuilder(graph_cls=cls)
|
||||
|
||||
@classmethod
|
||||
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
|
||||
"""
|
||||
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
|
||||
|
||||
:param nodes: mapping of node ID to node instance
|
||||
"""
|
||||
for node in nodes.values():
|
||||
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
node.execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
@classmethod
|
||||
def _mark_inactive_root_branches(
|
||||
cls,
|
||||
nodes: dict[str, Node],
|
||||
edges: dict[str, Edge],
|
||||
in_edges: dict[str, list[str]],
|
||||
out_edges: dict[str, list[str]],
|
||||
active_root_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Mark nodes and edges from inactive root branches as skipped.
|
||||
|
||||
Algorithm:
|
||||
1. Mark inactive root nodes as skipped
|
||||
2. For skipped nodes, mark all their outgoing edges as skipped
|
||||
3. For each edge marked as skipped, check its target node:
|
||||
- If ALL incoming edges are skipped, mark the node as skipped
|
||||
- Otherwise, leave the node state unchanged
|
||||
|
||||
:param nodes: mapping of node ID to node instance
|
||||
:param edges: mapping of edge ID to edge instance
|
||||
:param in_edges: mapping of node ID to incoming edge IDs
|
||||
:param out_edges: mapping of node ID to outgoing edge IDs
|
||||
:param active_root_id: ID of the active root node
|
||||
"""
|
||||
# Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
|
||||
top_level_roots: list[str] = [
|
||||
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
|
||||
]
|
||||
|
||||
# If there's only one root or the active root is not a top-level root, no marking needed
|
||||
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
|
||||
return
|
||||
|
||||
# Mark inactive root nodes as skipped
|
||||
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
|
||||
for root_id in inactive_roots:
|
||||
if root_id in nodes:
|
||||
nodes[root_id].state = NodeState.SKIPPED
|
||||
|
||||
# Recursively mark downstream nodes and edges
|
||||
def mark_downstream(node_id: str) -> None:
|
||||
"""Recursively mark downstream nodes and edges as skipped."""
|
||||
if nodes[node_id].state != NodeState.SKIPPED:
|
||||
return
|
||||
# If this node is skipped, mark all its outgoing edges as skipped
|
||||
out_edge_ids = out_edges.get(node_id, [])
|
||||
for edge_id in out_edge_ids:
|
||||
edge = edges[edge_id]
|
||||
edge.state = NodeState.SKIPPED
|
||||
|
||||
# Check the target node of this edge
|
||||
target_node = nodes[edge.head]
|
||||
in_edge_ids = in_edges.get(target_node.id, [])
|
||||
in_edge_states = [edges[eid].state for eid in in_edge_ids]
|
||||
|
||||
# If all incoming edges are skipped, mark the node as skipped
|
||||
if all(state == NodeState.SKIPPED for state in in_edge_states):
|
||||
target_node.state = NodeState.SKIPPED
|
||||
# Recursively process downstream nodes
|
||||
mark_downstream(target_node.id)
|
||||
|
||||
# Process each inactive root and its downstream nodes
|
||||
for root_id in inactive_roots:
|
||||
mark_downstream(root_id)
|
||||
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, object],
|
||||
node_factory: "NodeFactory",
|
||||
root_node_id: str | None = None,
|
||||
) -> "Graph":
|
||||
"""
|
||||
Initialize graph
|
||||
|
||||
:param graph_config: graph config containing nodes and edges
|
||||
:param node_factory: factory for creating node instances from config data
|
||||
:param root_node_id: root node id
|
||||
:return: graph instance
|
||||
"""
|
||||
# Parse configs
|
||||
edge_configs = graph_config.get("edges", [])
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[dict[str, object]], node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
|
||||
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
# Find root node
|
||||
root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
|
||||
|
||||
# Build edges
|
||||
edges, in_edges, out_edges = cls._build_edges(edge_configs)
|
||||
|
||||
# Create node instances
|
||||
nodes = cls._create_node_instances(node_configs_map, node_factory)
|
||||
|
||||
# Promote fail-branch nodes to branch execution type at graph level
|
||||
cls._promote_fail_branch_nodes(nodes)
|
||||
|
||||
# Get root node instance
|
||||
root_node = nodes[root_node_id]
|
||||
|
||||
# Mark inactive root branches as skipped
|
||||
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
||||
|
||||
# Create and return the graph
|
||||
graph = cls(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
in_edges=in_edges,
|
||||
out_edges=out_edges,
|
||||
root_node=root_node,
|
||||
)
|
||||
|
||||
# Validate the graph structure using built-in validators
|
||||
get_graph_validator().validate(graph)
|
||||
|
||||
return graph
|
||||
|
||||
@property
|
||||
def node_ids(self) -> list[str]:
|
||||
"""
|
||||
Get list of node IDs (compatibility property for existing code)
|
||||
|
||||
:return: list of node IDs
|
||||
"""
|
||||
return list(self.nodes.keys())
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> list[Edge]:
|
||||
"""
|
||||
Get all outgoing edges from a node (V2 method)
|
||||
|
||||
:param node_id: node id
|
||||
:return: list of outgoing edges
|
||||
"""
|
||||
edge_ids = self.out_edges.get(node_id, [])
|
||||
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|
||||
|
||||
def get_incoming_edges(self, node_id: str) -> list[Edge]:
|
||||
"""
|
||||
Get all incoming edges to a node (V2 method)
|
||||
|
||||
:param node_id: node id
|
||||
:return: list of incoming edges
|
||||
"""
|
||||
edge_ids = self.in_edges.get(node_id, [])
|
||||
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|
||||
|
||||
|
||||
@final
|
||||
class GraphBuilder:
|
||||
"""Fluent helper for constructing simple graphs, primarily for tests."""
|
||||
|
||||
def __init__(self, *, graph_cls: type[Graph]):
|
||||
self._graph_cls = graph_cls
|
||||
self._nodes: list[Node] = []
|
||||
self._nodes_by_id: dict[str, Node] = {}
|
||||
self._edges: list[Edge] = []
|
||||
self._edge_counter = 0
|
||||
|
||||
def add_root(self, node: Node) -> "GraphBuilder":
|
||||
"""Register the root node. Must be called exactly once."""
|
||||
|
||||
if self._nodes:
|
||||
raise ValueError("Root node has already been added")
|
||||
self._register_node(node)
|
||||
self._nodes.append(node)
|
||||
return self
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
node: Node,
|
||||
*,
|
||||
from_node_id: str | None = None,
|
||||
source_handle: str = "source",
|
||||
) -> "GraphBuilder":
|
||||
"""Append a node and connect it from the specified predecessor."""
|
||||
|
||||
if not self._nodes:
|
||||
raise ValueError("Root node must be added before adding other nodes")
|
||||
|
||||
predecessor_id = from_node_id or self._nodes[-1].id
|
||||
if predecessor_id not in self._nodes_by_id:
|
||||
raise ValueError(f"Predecessor node '{predecessor_id}' not found")
|
||||
|
||||
predecessor = self._nodes_by_id[predecessor_id]
|
||||
self._register_node(node)
|
||||
self._nodes.append(node)
|
||||
|
||||
edge_id = f"edge_{self._edge_counter}"
|
||||
self._edge_counter += 1
|
||||
edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle)
|
||||
self._edges.append(edge)
|
||||
|
||||
return self
|
||||
|
||||
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
|
||||
"""Connect two existing nodes without adding a new node."""
|
||||
|
||||
if tail not in self._nodes_by_id:
|
||||
raise ValueError(f"Tail node '{tail}' not found")
|
||||
if head not in self._nodes_by_id:
|
||||
raise ValueError(f"Head node '{head}' not found")
|
||||
|
||||
edge_id = f"edge_{self._edge_counter}"
|
||||
self._edge_counter += 1
|
||||
edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle)
|
||||
self._edges.append(edge)
|
||||
|
||||
return self
|
||||
|
||||
def build(self) -> Graph:
|
||||
"""Materialize the graph instance from the accumulated nodes and edges."""
|
||||
|
||||
if not self._nodes:
|
||||
raise ValueError("Cannot build an empty graph")
|
||||
|
||||
nodes = {node.id: node for node in self._nodes}
|
||||
edges = {edge.id: edge for edge in self._edges}
|
||||
in_edges: dict[str, list[str]] = defaultdict(list)
|
||||
out_edges: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for edge in self._edges:
|
||||
out_edges[edge.tail].append(edge.id)
|
||||
in_edges[edge.head].append(edge.id)
|
||||
|
||||
return self._graph_cls(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
in_edges=dict(in_edges),
|
||||
out_edges=dict(out_edges),
|
||||
root_node=self._nodes[0],
|
||||
)
|
||||
|
||||
def _register_node(self, node: Node) -> None:
|
||||
if not node.id:
|
||||
raise ValueError("Node must have a non-empty id")
|
||||
if node.id in self._nodes_by_id:
|
||||
raise ValueError(f"Duplicate node id detected: {node.id}")
|
||||
self._nodes_by_id[node.id] = node
|
||||
20
dify/api/core/workflow/graph/graph_template.py
Normal file
20
dify/api/core/workflow/graph/graph_template.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GraphTemplate(BaseModel):
|
||||
"""
|
||||
Graph Template for container nodes and subgraph expansion
|
||||
|
||||
According to GraphEngine V2 spec, GraphTemplate contains:
|
||||
- nodes: mapping of node definitions
|
||||
- edges: mapping of edge definitions
|
||||
- root_ids: list of root node IDs
|
||||
- output_selectors: list of output selectors for the template
|
||||
"""
|
||||
|
||||
nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping")
|
||||
edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping")
|
||||
root_ids: list[str] = Field(default_factory=list, description="root node IDs")
|
||||
output_selectors: list[str] = Field(default_factory=list, description="output selectors")
|
||||
161
dify/api/core/workflow/graph/validation.py
Normal file
161
dify/api/core/workflow/graph/validation.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import Graph
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GraphValidationIssue:
|
||||
"""Immutable value object describing a single validation issue."""
|
||||
|
||||
code: str
|
||||
message: str
|
||||
node_id: str | None = None
|
||||
|
||||
|
||||
class GraphValidationError(ValueError):
|
||||
"""Raised when graph validation fails."""
|
||||
|
||||
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
|
||||
if not issues:
|
||||
raise ValueError("GraphValidationError requires at least one issue.")
|
||||
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
|
||||
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class GraphValidationRule(Protocol):
|
||||
"""Protocol that individual validation rules must satisfy."""
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
"""Validate the provided graph and return any discovered issues."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _EdgeEndpointValidator:
|
||||
"""Ensures all edges reference existing nodes."""
|
||||
|
||||
missing_node_code: str = "MISSING_NODE"
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
issues: list[GraphValidationIssue] = []
|
||||
for edge in graph.edges.values():
|
||||
if edge.tail not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.missing_node_code,
|
||||
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
|
||||
node_id=edge.tail,
|
||||
)
|
||||
)
|
||||
if edge.head not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.missing_node_code,
|
||||
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
|
||||
node_id=edge.head,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _RootNodeValidator:
|
||||
"""Validates root node invariants."""
|
||||
|
||||
invalid_root_code: str = "INVALID_ROOT"
|
||||
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
root_node = graph.root_node
|
||||
issues: list[GraphValidationIssue] = []
|
||||
if root_node.id not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.invalid_root_code,
|
||||
message=f"Root node '{root_node.id}' is missing from the node registry.",
|
||||
node_id=root_node.id,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
node_type = getattr(root_node, "node_type", None)
|
||||
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.invalid_root_code,
|
||||
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
|
||||
node_id=root_node.id,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GraphValidator:
|
||||
"""Coordinates execution of graph validation rules."""
|
||||
|
||||
rules: tuple[GraphValidationRule, ...]
|
||||
|
||||
def validate(self, graph: Graph) -> None:
|
||||
"""Validate the graph against all configured rules."""
|
||||
issues: list[GraphValidationIssue] = []
|
||||
for rule in self.rules:
|
||||
issues.extend(rule.validate(graph))
|
||||
|
||||
if issues:
|
||||
raise GraphValidationError(issues)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _TriggerStartExclusivityValidator:
|
||||
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
|
||||
|
||||
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
start_node_id: str | None = None
|
||||
trigger_node_ids: list[str] = []
|
||||
|
||||
for node in graph.nodes.values():
|
||||
node_type = getattr(node, "node_type", None)
|
||||
if not isinstance(node_type, NodeType):
|
||||
continue
|
||||
|
||||
if node_type == NodeType.START:
|
||||
start_node_id = node.id
|
||||
elif node_type.is_trigger_node:
|
||||
trigger_node_ids.append(node.id)
|
||||
|
||||
if start_node_id and trigger_node_ids:
|
||||
trigger_list = ", ".join(trigger_node_ids)
|
||||
return [
|
||||
GraphValidationIssue(
|
||||
code=self.conflict_code,
|
||||
message=(
|
||||
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
|
||||
),
|
||||
node_id=start_node_id,
|
||||
)
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
|
||||
_EdgeEndpointValidator(),
|
||||
_RootNodeValidator(),
|
||||
_TriggerStartExclusivityValidator(),
|
||||
)
|
||||
|
||||
|
||||
def get_graph_validator() -> GraphValidator:
|
||||
"""Construct the validator composed of default rules."""
|
||||
return GraphValidator(_DEFAULT_RULES)
|
||||
3
dify/api/core/workflow/graph_engine/__init__.py
Normal file
3
dify/api/core/workflow/graph_engine/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .graph_engine import GraphEngine
|
||||
|
||||
__all__ = ["GraphEngine"]
|
||||
@@ -0,0 +1,33 @@
|
||||
# Command Channels
|
||||
|
||||
Channel implementations for external workflow control.
|
||||
|
||||
## Components
|
||||
|
||||
### InMemoryChannel
|
||||
|
||||
Thread-safe in-memory queue for single-process deployments.
|
||||
|
||||
- `fetch_commands()` - Get pending commands
|
||||
- `send_command()` - Add command to queue
|
||||
|
||||
### RedisChannel
|
||||
|
||||
Redis-based queue for distributed deployments.
|
||||
|
||||
- `fetch_commands()` - Get commands with JSON deserialization
|
||||
- `send_command()` - Store commands with TTL
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# Local execution
|
||||
channel = InMemoryChannel()
|
||||
channel.send_command(AbortCommand(graph_id="workflow-123"))
|
||||
|
||||
# Distributed execution
|
||||
redis_channel = RedisChannel(
|
||||
redis_client=redis_client,
|
||||
channel_key="workflow:123:commands"
|
||||
)
|
||||
```
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Command channel implementations for GraphEngine."""
|
||||
|
||||
from .in_memory_channel import InMemoryChannel
|
||||
from .redis_channel import RedisChannel
|
||||
|
||||
__all__ = ["InMemoryChannel", "RedisChannel"]
|
||||
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
In-memory implementation of CommandChannel for local/testing scenarios.
|
||||
|
||||
This implementation uses a thread-safe queue for command communication
|
||||
within a single process. Each instance handles commands for one workflow execution.
|
||||
"""
|
||||
|
||||
from queue import Queue
|
||||
from typing import final
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryChannel:
|
||||
"""
|
||||
In-memory command channel implementation using a thread-safe queue.
|
||||
|
||||
Each instance is dedicated to a single GraphEngine/workflow execution.
|
||||
Suitable for local development, testing, and single-instance deployments.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the in-memory channel with a single queue."""
|
||||
self._queue: Queue[GraphEngineCommand] = Queue()
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch all pending commands from the queue.
|
||||
|
||||
Returns:
|
||||
List of pending commands (drains the queue)
|
||||
"""
|
||||
commands: list[GraphEngineCommand] = []
|
||||
|
||||
# Drain all available commands from the queue
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
command = self._queue.get_nowait()
|
||||
commands.append(command)
|
||||
except Exception:
|
||||
break
|
||||
|
||||
return commands
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to this channel's queue.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
self._queue.put(command)
|
||||
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Redis-based implementation of CommandChannel for distributed scenarios.
|
||||
|
||||
This implementation uses Redis lists for command queuing, supporting
|
||||
multi-instance deployments and cross-server communication.
|
||||
Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
|
||||
@final
|
||||
class RedisChannel:
|
||||
"""
|
||||
Redis-based command channel implementation for distributed systems.
|
||||
|
||||
Each instance uses a unique Redis key for its command queue.
|
||||
Commands are JSON-serialized for transport.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: "RedisClientWrapper",
|
||||
channel_key: str,
|
||||
command_ttl: int = 3600,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Redis channel.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client instance
|
||||
channel_key: Unique key for this channel's command queue
|
||||
command_ttl: TTL for command keys in seconds (default: 3600)
|
||||
"""
|
||||
self._redis = redis_client
|
||||
self._key = channel_key
|
||||
self._command_ttl = command_ttl
|
||||
self._pending_key = f"{channel_key}:pending"
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch all pending commands from Redis.
|
||||
|
||||
Returns:
|
||||
List of pending commands (drains the Redis list)
|
||||
"""
|
||||
if not self._has_pending_commands():
|
||||
return []
|
||||
|
||||
commands: list[GraphEngineCommand] = []
|
||||
|
||||
# Use pipeline for atomic operations
|
||||
with self._redis.pipeline() as pipe:
|
||||
# Get all commands and clear the list atomically
|
||||
pipe.lrange(self._key, 0, -1)
|
||||
pipe.delete(self._key)
|
||||
results = pipe.execute()
|
||||
|
||||
# Parse commands from JSON
|
||||
if results[0]:
|
||||
for command_json in results[0]:
|
||||
try:
|
||||
command_data = json.loads(command_json)
|
||||
command = self._deserialize_command(command_data)
|
||||
if command:
|
||||
commands.append(command)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Skip invalid commands
|
||||
continue
|
||||
|
||||
return commands
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to Redis.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
command_json = json.dumps(command.model_dump())
|
||||
|
||||
# Push to list and set expiry
|
||||
with self._redis.pipeline() as pipe:
|
||||
pipe.rpush(self._key, command_json)
|
||||
pipe.expire(self._key, self._command_ttl)
|
||||
pipe.set(self._pending_key, "1", ex=self._command_ttl)
|
||||
pipe.execute()
|
||||
|
||||
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
|
||||
"""
|
||||
Deserialize a command from dictionary data.
|
||||
|
||||
Args:
|
||||
data: Command data dictionary
|
||||
|
||||
Returns:
|
||||
Deserialized command or None if invalid
|
||||
"""
|
||||
command_type_value = data.get("command_type")
|
||||
if not isinstance(command_type_value, str):
|
||||
return None
|
||||
|
||||
try:
|
||||
command_type = CommandType(command_type_value)
|
||||
|
||||
if command_type == CommandType.ABORT:
|
||||
return AbortCommand.model_validate(data)
|
||||
if command_type == CommandType.PAUSE:
|
||||
return PauseCommand.model_validate(data)
|
||||
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def _has_pending_commands(self) -> bool:
|
||||
"""
|
||||
Check and consume the pending marker to avoid unnecessary list reads.
|
||||
|
||||
Returns:
|
||||
True if commands should be fetched from Redis.
|
||||
"""
|
||||
with self._redis.pipeline() as pipe:
|
||||
pipe.get(self._pending_key)
|
||||
pipe.delete(self._pending_key)
|
||||
pending_value, _ = pipe.execute()
|
||||
|
||||
return pending_value is not None
|
||||
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Command processing subsystem for graph engine.
|
||||
|
||||
This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
"PauseCommandHandler",
|
||||
]
|
||||
@@ -0,0 +1,33 @@
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class AbortCommandHandler(CommandHandler):
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, AbortCommand)
|
||||
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
|
||||
execution.abort(command.reason or "User requested abort")
|
||||
|
||||
|
||||
@final
|
||||
class PauseCommandHandler(CommandHandler):
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, PauseCommand)
|
||||
logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
|
||||
# Convert string reason to PauseReason if needed
|
||||
reason = command.reason
|
||||
pause_reason = SchedulingPause(message=reason)
|
||||
execution.pause(pause_reason)
|
||||
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Main command processor for handling external commands.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Protocol, final
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
from ..protocols.command_channel import CommandChannel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommandHandler(Protocol):
|
||||
"""Protocol for command handlers."""
|
||||
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
|
||||
|
||||
|
||||
@final
|
||||
class CommandProcessor:
|
||||
"""
|
||||
Processes external commands sent to the engine.
|
||||
|
||||
This polls the command channel and dispatches commands to
|
||||
appropriate handlers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command_channel: CommandChannel,
|
||||
graph_execution: GraphExecution,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the command processor.
|
||||
|
||||
Args:
|
||||
command_channel: Channel for receiving commands
|
||||
graph_execution: Graph execution aggregate
|
||||
"""
|
||||
self._command_channel = command_channel
|
||||
self._graph_execution = graph_execution
|
||||
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
|
||||
|
||||
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
|
||||
"""
|
||||
Register a handler for a command type.
|
||||
|
||||
Args:
|
||||
command_type: Type of command to handle
|
||||
handler: Handler for the command
|
||||
"""
|
||||
self._handlers[command_type] = handler
|
||||
|
||||
def process_commands(self) -> None:
|
||||
"""Check for and process any pending commands."""
|
||||
try:
|
||||
commands = self._command_channel.fetch_commands()
|
||||
for command in commands:
|
||||
self._handle_command(command)
|
||||
except Exception as e:
|
||||
logger.warning("Error processing commands: %s", e)
|
||||
|
||||
def _handle_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Handle a single command.
|
||||
|
||||
Args:
|
||||
command: The command to handle
|
||||
"""
|
||||
handler = self._handlers.get(type(command))
|
||||
if handler:
|
||||
try:
|
||||
handler.handle(command, self._graph_execution)
|
||||
except Exception:
|
||||
logger.exception("Error handling command %s", command.__class__.__name__)
|
||||
else:
|
||||
logger.warning("No handler registered for command: %s", command.__class__.__name__)
|
||||
14
dify/api/core/workflow/graph_engine/domain/__init__.py
Normal file
14
dify/api/core/workflow/graph_engine/domain/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Domain models for graph engine.
|
||||
|
||||
This package contains the core domain entities, value objects, and aggregates
|
||||
that represent the business concepts of workflow graph execution.
|
||||
"""
|
||||
|
||||
from .graph_execution import GraphExecution
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
__all__ = [
|
||||
"GraphExecution",
|
||||
"NodeExecution",
|
||||
]
|
||||
240
dify/api/core/workflow/graph_engine/domain/graph_execution.py
Normal file
240
dify/api/core/workflow/graph_engine/domain/graph_execution.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""GraphExecution aggregate root managing the overall graph execution state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from importlib import import_module
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
|
||||
class GraphExecutionErrorState(BaseModel):
|
||||
"""Serializable representation of an execution error."""
|
||||
|
||||
module: str = Field(description="Module containing the exception class")
|
||||
qualname: str = Field(description="Qualified name of the exception class")
|
||||
message: str | None = Field(default=None, description="Exception message string")
|
||||
|
||||
|
||||
class NodeExecutionState(BaseModel):
|
||||
"""Serializable representation of a node execution entity."""
|
||||
|
||||
node_id: str
|
||||
state: NodeState = Field(default=NodeState.UNKNOWN)
|
||||
retry_count: int = Field(default=0)
|
||||
execution_id: str | None = Field(default=None)
|
||||
error: str | None = Field(default=None)
|
||||
|
||||
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Pydantic model describing serialized GraphExecution state."""
|
||||
|
||||
type: Literal["GraphExecution"] = Field(default="GraphExecution")
|
||||
version: str = Field(default="1.0")
|
||||
workflow_id: str
|
||||
started: bool = Field(default=False)
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
paused: bool = Field(default=False)
|
||||
pause_reason: PauseReason | None = Field(default=None)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
|
||||
|
||||
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
|
||||
"""Convert an exception into its serializable representation."""
|
||||
|
||||
if error is None:
|
||||
return None
|
||||
|
||||
return GraphExecutionErrorState(
|
||||
module=error.__class__.__module__,
|
||||
qualname=error.__class__.__qualname__,
|
||||
message=str(error),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
|
||||
"""Locate an exception class from its module and qualified name."""
|
||||
|
||||
module = import_module(module_name)
|
||||
attr: object = module
|
||||
for part in qualname.split("."):
|
||||
attr = getattr(attr, part)
|
||||
|
||||
if isinstance(attr, type) and issubclass(attr, Exception):
|
||||
return attr
|
||||
|
||||
raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")
|
||||
|
||||
|
||||
def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
|
||||
"""Reconstruct an exception instance from serialized data."""
|
||||
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
exception_class = _resolve_exception_class(state.module, state.qualname)
|
||||
if state.message is None:
|
||||
return exception_class()
|
||||
return exception_class(state.message)
|
||||
except Exception:
|
||||
# Fallback to RuntimeError when reconstruction fails
|
||||
if state.message is None:
|
||||
return RuntimeError(state.qualname)
|
||||
return RuntimeError(state.message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphExecution:
|
||||
"""
|
||||
Aggregate root for graph execution.
|
||||
|
||||
This manages the overall execution state of a workflow graph,
|
||||
coordinating between multiple node executions.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
started: bool = False
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
paused: bool = False
|
||||
pause_reason: PauseReason | None = None
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Mark the graph execution as started."""
|
||||
if self.started:
|
||||
raise RuntimeError("Graph execution already started")
|
||||
self.started = True
|
||||
|
||||
def complete(self) -> None:
|
||||
"""Mark the graph execution as completed."""
|
||||
if not self.started:
|
||||
raise RuntimeError("Cannot complete execution that hasn't started")
|
||||
if self.completed:
|
||||
raise RuntimeError("Graph execution already completed")
|
||||
self.completed = True
|
||||
|
||||
def abort(self, reason: str) -> None:
|
||||
"""Abort the graph execution."""
|
||||
self.aborted = True
|
||||
self.error = RuntimeError(f"Aborted: {reason}")
|
||||
|
||||
def pause(self, reason: PauseReason) -> None:
|
||||
"""Pause the graph execution without marking it complete."""
|
||||
if self.completed:
|
||||
raise RuntimeError("Cannot pause execution that has completed")
|
||||
if self.aborted:
|
||||
raise RuntimeError("Cannot pause execution that has been aborted")
|
||||
if self.paused:
|
||||
return
|
||||
self.paused = True
|
||||
self.pause_reason = reason
|
||||
|
||||
def fail(self, error: Exception) -> None:
|
||||
"""Mark the graph execution as failed."""
|
||||
self.error = error
|
||||
self.completed = True
|
||||
|
||||
def get_or_create_node_execution(self, node_id: str) -> NodeExecution:
|
||||
"""Get or create a node execution entity."""
|
||||
if node_id not in self.node_executions:
|
||||
self.node_executions[node_id] = NodeExecution(node_id=node_id)
|
||||
return self.node_executions[node_id]
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the execution is currently running."""
|
||||
return self.started and not self.completed and not self.aborted and not self.paused
|
||||
|
||||
@property
|
||||
def is_paused(self) -> bool:
|
||||
"""Check if the execution is currently paused."""
|
||||
return self.paused
|
||||
|
||||
@property
|
||||
def has_error(self) -> bool:
|
||||
"""Check if the execution has encountered an error."""
|
||||
return self.error is not None
|
||||
|
||||
@property
|
||||
def error_message(self) -> str | None:
|
||||
"""Get the error message if an error exists."""
|
||||
if not self.error:
|
||||
return None
|
||||
return str(self.error)
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize the aggregate state into a JSON string."""
|
||||
|
||||
node_states = [
|
||||
NodeExecutionState(
|
||||
node_id=node_id,
|
||||
state=node_execution.state,
|
||||
retry_count=node_execution.retry_count,
|
||||
execution_id=node_execution.execution_id,
|
||||
error=node_execution.error,
|
||||
)
|
||||
for node_id, node_execution in sorted(self.node_executions.items())
|
||||
]
|
||||
|
||||
state = GraphExecutionState(
|
||||
workflow_id=self.workflow_id,
|
||||
started=self.started,
|
||||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
paused=self.paused,
|
||||
pause_reason=self.pause_reason,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
)
|
||||
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore aggregate state from a serialized JSON string."""
|
||||
|
||||
state = GraphExecutionState.model_validate_json(data)
|
||||
|
||||
if state.type != "GraphExecution":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported serialized version: {state.version}")
|
||||
|
||||
if self.workflow_id != state.workflow_id:
|
||||
raise ValueError("Serialized workflow_id does not match aggregate identity")
|
||||
|
||||
self.started = state.started
|
||||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.paused = state.paused
|
||||
self.pause_reason = state.pause_reason
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
item.node_id: NodeExecution(
|
||||
node_id=item.node_id,
|
||||
state=item.state,
|
||||
retry_count=item.retry_count,
|
||||
execution_id=item.execution_id,
|
||||
error=item.error,
|
||||
)
|
||||
for item in state.node_executions
|
||||
}
|
||||
|
||||
def record_node_failure(self) -> None:
|
||||
"""Increment the count of node failures encountered during execution."""
|
||||
self.exceptions_count += 1
|
||||
45
dify/api/core/workflow/graph_engine/domain/node_execution.py
Normal file
45
dify/api/core/workflow/graph_engine/domain/node_execution.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
NodeExecution entity representing a node's execution state.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeExecution:
|
||||
"""
|
||||
Entity representing the execution state of a single node.
|
||||
|
||||
This is a mutable entity that tracks the runtime state of a node
|
||||
during graph execution.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
state: NodeState = NodeState.UNKNOWN
|
||||
retry_count: int = 0
|
||||
execution_id: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
def mark_started(self, execution_id: str) -> None:
|
||||
"""Mark the node as started with an execution ID."""
|
||||
self.state = NodeState.TAKEN
|
||||
self.execution_id = execution_id
|
||||
|
||||
def mark_taken(self) -> None:
|
||||
"""Mark the node as successfully completed."""
|
||||
self.state = NodeState.TAKEN
|
||||
self.error = None
|
||||
|
||||
def mark_failed(self, error: str) -> None:
|
||||
"""Mark the node as failed with an error."""
|
||||
self.error = error
|
||||
|
||||
def mark_skipped(self) -> None:
|
||||
"""Mark the node as skipped."""
|
||||
self.state = NodeState.SKIPPED
|
||||
|
||||
def increment_retry(self) -> None:
|
||||
"""Increment the retry count for this node."""
|
||||
self.retry_count += 1
|
||||
39
dify/api/core/workflow/graph_engine/entities/commands.py
Normal file
39
dify/api/core/workflow/graph_engine/entities/commands.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
GraphEngine command entities for external control.
|
||||
|
||||
This module defines command types that can be sent to a running GraphEngine
|
||||
instance to control its execution flow.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
"""Types of commands that can be sent to GraphEngine."""
|
||||
|
||||
ABORT = "abort"
|
||||
PAUSE = "pause"
|
||||
|
||||
|
||||
class GraphEngineCommand(BaseModel):
|
||||
"""Base class for all GraphEngine commands."""
|
||||
|
||||
command_type: CommandType = Field(..., description="Type of command")
|
||||
payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")
|
||||
|
||||
|
||||
class AbortCommand(GraphEngineCommand):
|
||||
"""Command to abort a running workflow execution."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
|
||||
reason: str | None = Field(default=None, description="Optional reason for abort")
|
||||
|
||||
|
||||
class PauseCommand(GraphEngineCommand):
|
||||
"""Command to pause a running workflow execution."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
||||
reason: str = Field(default="unknown reason", description="reason for pause")
|
||||
211
dify/api/core/workflow/graph_engine/error_handler.py
Normal file
211
dify/api/core/workflow/graph_engine/error_handler.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Main error handler that coordinates error strategies.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy as ErrorStrategyEnum,
|
||||
)
|
||||
from core.workflow.enums import (
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetryEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .domain import GraphExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class ErrorHandler:
|
||||
"""
|
||||
Coordinates error handling strategies for node failures.
|
||||
|
||||
This acts as a facade for the various error strategies,
|
||||
selecting and applying the appropriate strategy based on
|
||||
node configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
|
||||
"""
|
||||
Initialize the error handler.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
graph_execution: The graph execution state
|
||||
"""
|
||||
self._graph = graph
|
||||
self._graph_execution = graph_execution
|
||||
|
||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle a node failure event.
|
||||
|
||||
Selects and applies the appropriate error strategy based on
|
||||
the node's configuration.
|
||||
|
||||
Args:
|
||||
event: The node failure event
|
||||
|
||||
Returns:
|
||||
Optional new event to process, or None to abort
|
||||
"""
|
||||
node = self._graph.nodes[event.node_id]
|
||||
# Get retry count from NodeExecution
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
retry_count = node_execution.retry_count
|
||||
|
||||
# First check if retry is configured and not exhausted
|
||||
if node.retry and retry_count < node.retry_config.max_retries:
|
||||
result = self._handle_retry(event, retry_count)
|
||||
if result:
|
||||
# Retry count will be incremented when NodeRunRetryEvent is handled
|
||||
return result
|
||||
|
||||
# Apply configured error strategy
|
||||
strategy = node.error_strategy
|
||||
|
||||
match strategy:
|
||||
case None:
|
||||
return self._handle_abort(event)
|
||||
case ErrorStrategyEnum.FAIL_BRANCH:
|
||||
return self._handle_fail_branch(event)
|
||||
case ErrorStrategyEnum.DEFAULT_VALUE:
|
||||
return self._handle_default_value(event)
|
||||
|
||||
def _handle_abort(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Handle error by aborting execution.
|
||||
|
||||
This is the default strategy when no other strategy is specified.
|
||||
It stops the entire graph execution when a node fails.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
|
||||
Returns:
|
||||
None - signals abortion
|
||||
"""
|
||||
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
|
||||
# Return None to signal that execution should stop
|
||||
|
||||
def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int):
|
||||
"""
|
||||
Handle error by retrying the node.
|
||||
|
||||
This strategy re-attempts node execution up to a configured
|
||||
maximum number of retries with configurable intervals.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
retry_count: Current retry attempt count
|
||||
|
||||
Returns:
|
||||
NodeRunRetryEvent if retry should occur, None otherwise
|
||||
"""
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
# Check if we've exceeded max retries
|
||||
if not node.retry or retry_count >= node.retry_config.max_retries:
|
||||
return None
|
||||
|
||||
# Wait for retry interval
|
||||
time.sleep(node.retry_config.retry_interval_seconds)
|
||||
|
||||
# Create retry event
|
||||
return NodeRunRetryEvent(
|
||||
id=event.id,
|
||||
node_title=node.title,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_run_result=event.node_run_result,
|
||||
start_at=event.start_at,
|
||||
error=event.error,
|
||||
retry_index=retry_count + 1,
|
||||
)
|
||||
|
||||
def _handle_fail_branch(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Handle error by taking the fail branch.
|
||||
|
||||
This strategy converts failures to exceptions and routes execution
|
||||
through a designated fail-branch edge.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
|
||||
Returns:
|
||||
NodeRunExceptionEvent to continue via fail branch
|
||||
"""
|
||||
outputs = {
|
||||
"error_message": event.node_run_result.error,
|
||||
"error_type": event.node_run_result.error_type,
|
||||
}
|
||||
|
||||
return NodeRunExceptionEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
edge_source_handle="fail-branch",
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH,
|
||||
},
|
||||
),
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_default_value(self, event: NodeRunFailedEvent):
|
||||
"""
|
||||
Handle error by using default values.
|
||||
|
||||
This strategy allows nodes to fail gracefully by providing
|
||||
predefined default output values.
|
||||
|
||||
Args:
|
||||
event: The failure event
|
||||
|
||||
Returns:
|
||||
NodeRunExceptionEvent with default values
|
||||
"""
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
outputs = {
|
||||
**node.default_value_dict,
|
||||
"error_message": event.node_run_result.error,
|
||||
"error_type": event.node_run_result.error_type,
|
||||
}
|
||||
|
||||
return NodeRunExceptionEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE,
|
||||
},
|
||||
),
|
||||
error=event.error,
|
||||
)
|
||||
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Event management subsystem for graph engine.
|
||||
|
||||
This package handles event routing, collection, and emission for
|
||||
workflow graph execution events.
|
||||
"""
|
||||
|
||||
from .event_handlers import EventHandler
|
||||
from .event_manager import EventManager
|
||||
|
||||
__all__ = [
|
||||
"EventHandler",
|
||||
"EventManager",
|
||||
]
|
||||
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Event handler implementations for different event types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..error_handler import ErrorHandler
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..graph_traversal import EdgeProcessor
|
||||
from .event_manager import EventManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class EventHandler:
|
||||
"""
|
||||
Registry of event handlers for different event types.
|
||||
|
||||
This centralizes the business logic for handling specific events,
|
||||
keeping it separate from the routing and collection infrastructure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_execution: GraphExecution,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
event_collector: "EventManager",
|
||||
edge_processor: "EdgeProcessor",
|
||||
state_manager: "GraphStateManager",
|
||||
error_handler: "ErrorHandler",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the event handler registry.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
graph_runtime_state: Runtime state with variable pool
|
||||
graph_execution: Graph execution aggregate
|
||||
response_coordinator: Response stream coordinator
|
||||
event_collector: Event manager for collecting events
|
||||
edge_processor: Edge processor for edge traversal
|
||||
state_manager: Unified state manager
|
||||
error_handler: Error handler
|
||||
"""
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_execution = graph_execution
|
||||
self._response_coordinator = response_coordinator
|
||||
self._event_collector = event_collector
|
||||
self._edge_processor = edge_processor
|
||||
self._state_manager = state_manager
|
||||
self._error_handler = error_handler
|
||||
|
||||
def dispatch(self, event: GraphNodeEventBase) -> None:
|
||||
"""
|
||||
Handle any node event by dispatching to the appropriate handler.
|
||||
|
||||
Args:
|
||||
event: The event to handle
|
||||
"""
|
||||
# Events in loops or iterations are always collected
|
||||
if event.in_loop_id or event.in_iteration_id:
|
||||
self._event_collector.collect(event)
|
||||
return
|
||||
return self._dispatch(event)
|
||||
|
||||
@singledispatchmethod
|
||||
def _dispatch(self, event: GraphNodeEventBase) -> None:
|
||||
self._event_collector.collect(event)
|
||||
logger.warning("Unhandled event type: %s", type(event).__name__)
|
||||
|
||||
@_dispatch.register(NodeRunIterationStartedEvent)
|
||||
@_dispatch.register(NodeRunIterationNextEvent)
|
||||
@_dispatch.register(NodeRunIterationSucceededEvent)
|
||||
@_dispatch.register(NodeRunIterationFailedEvent)
|
||||
@_dispatch.register(NodeRunLoopStartedEvent)
|
||||
@_dispatch.register(NodeRunLoopNextEvent)
|
||||
@_dispatch.register(NodeRunLoopSucceededEvent)
|
||||
@_dispatch.register(NodeRunLoopFailedEvent)
|
||||
@_dispatch.register(NodeRunAgentLogEvent)
|
||||
@_dispatch.register(NodeRunRetrieverResourceEvent)
|
||||
def _(self, event: GraphNodeEventBase) -> None:
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunStartedEvent) -> None:
|
||||
"""
|
||||
Handle node started event.
|
||||
|
||||
Args:
|
||||
event: The node started event
|
||||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
node_execution.mark_started(event.id)
|
||||
self._graph_runtime_state.increment_node_run_steps()
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event only for the first attempt; retries remain silent
|
||||
if is_initial_attempt:
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Handle stream chunk event with full processing.
|
||||
|
||||
Args:
|
||||
event: The stream chunk event
|
||||
"""
|
||||
# Process with response coordinator
|
||||
streaming_events = list(self._response_coordinator.intercept_event(event))
|
||||
|
||||
# Collect all events
|
||||
for stream_event in streaming_events:
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Handle node success by coordinating subsystems.
|
||||
|
||||
This method coordinates between different subsystems to process
|
||||
node completion, handle edges, and trigger downstream execution.
|
||||
|
||||
Args:
|
||||
event: The node succeeded event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = self._response_coordinator.intercept_event(event)
|
||||
for stream_event in streaming_events:
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
# Process edges and get ready nodes
|
||||
node = self._graph.nodes[event.node_id]
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
|
||||
# Collect streaming events from edge processing
|
||||
for edge_event in edge_streaming_events:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
# Enqueue ready nodes
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
"""Handle pause requests emitted by nodes."""
|
||||
|
||||
pause_reason = event.reason
|
||||
self._graph_execution.pause(pause_reason)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
if event.node_id in self._graph.nodes:
|
||||
self._graph.nodes[event.node_id].state = NodeState.UNKNOWN
|
||||
self._graph_runtime_state.register_paused_node(event.node_id)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
Handle node failure using error handler.
|
||||
|
||||
Args:
|
||||
event: The node failed event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
self._graph_execution.record_node_failure()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.dispatch(result)
|
||||
else:
|
||||
# Abort execution
|
||||
self._graph_execution.fail(RuntimeError(event.error))
|
||||
self._event_collector.collect(event)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunExceptionEvent) -> None:
|
||||
"""
|
||||
Handle node exception event (fail-branch strategy).
|
||||
|
||||
Args:
|
||||
event: The node exception event
|
||||
"""
|
||||
# Node continues via fail-branch/default-value, treat as completion
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Persist outputs produced by the exception strategy (e.g. default values)
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
node = self._graph.nodes[event.node_id]
|
||||
|
||||
if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
|
||||
|
||||
for edge_event in edge_streaming_events:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update response outputs if applicable
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event.node_run_result.outputs)
|
||||
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Collect the exception event for observers
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunRetryEvent) -> None:
|
||||
"""
|
||||
Handle node retry event.
|
||||
|
||||
Args:
|
||||
event: The node retry event
|
||||
"""
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.increment_retry()
|
||||
|
||||
# Finish the previous attempt before re-queuing the node
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
# Emit retry event for observers
|
||||
self._event_collector.collect(event)
|
||||
|
||||
# Re-queue node for execution
|
||||
self._state_manager.enqueue_node(event.node_id)
|
||||
self._state_manager.start_execution(event.node_id)
|
||||
|
||||
def _accumulate_node_usage(self, usage: LLMUsage) -> None:
|
||||
"""Accumulate token usage into the shared runtime state."""
|
||||
if usage.total_tokens <= 0:
|
||||
return
|
||||
|
||||
self._graph_runtime_state.add_tokens(usage.total_tokens)
|
||||
|
||||
current_usage = self._graph_runtime_state.llm_usage
|
||||
if current_usage.total_tokens == 0:
|
||||
self._graph_runtime_state.llm_usage = usage
|
||||
else:
|
||||
self._graph_runtime_state.llm_usage = current_usage.plus(usage)
|
||||
|
||||
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
|
||||
"""
|
||||
Store node outputs in the variable pool.
|
||||
|
||||
Args:
|
||||
event: The node succeeded event containing outputs
|
||||
"""
|
||||
for variable_name, variable_value in outputs.items():
|
||||
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
|
||||
|
||||
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
|
||||
# in runtime state, rather than allowing nodes to directly access runtime state.
|
||||
for key, value in outputs.items():
|
||||
if key == "answer":
|
||||
existing = self._graph_runtime_state.get_output("answer", "")
|
||||
if existing:
|
||||
self._graph_runtime_state.set_output("answer", f"{existing}{value}")
|
||||
else:
|
||||
self._graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
self._graph_runtime_state.set_output(key, value)
|
||||
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Unified event manager for collecting and emitting events.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from ..layers.base import GraphEngineLayer
|
||||
|
||||
|
||||
@final
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock implementation that allows multiple concurrent readers
|
||||
but only one writer at a time.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._read_ready = threading.Condition(threading.RLock())
|
||||
self._readers = 0
|
||||
|
||||
def acquire_read(self) -> None:
|
||||
"""Acquire a read lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers += 1
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
try:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
self._read_ready.notify_all()
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def acquire_write(self) -> None:
|
||||
"""Acquire a write lock."""
|
||||
_ = self._read_ready.acquire()
|
||||
while self._readers > 0:
|
||||
_ = self._read_ready.wait()
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release a write lock."""
|
||||
self._read_ready.release()
|
||||
|
||||
@contextmanager
|
||||
def read_lock(self):
|
||||
"""Return a context manager for read locking."""
|
||||
self.acquire_read()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.release_read()
|
||||
|
||||
@contextmanager
|
||||
def write_lock(self):
|
||||
"""Return a context manager for write locking."""
|
||||
self.acquire_write()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.release_write()
|
||||
|
||||
|
||||
@final
|
||||
class EventManager:
|
||||
"""
|
||||
Unified event manager that collects, buffers, and emits events.
|
||||
|
||||
This class combines event collection with event emission, providing
|
||||
thread-safe event management with support for notifying layers and
|
||||
streaming events to external consumers.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the event manager."""
|
||||
self._events: list[GraphEngineEvent] = []
|
||||
self._lock = ReadWriteLock()
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
self._execution_complete = threading.Event()
|
||||
|
||||
def set_layers(self, layers: list[GraphEngineLayer]) -> None:
|
||||
"""
|
||||
Set the layers to notify on event collection.
|
||||
|
||||
Args:
|
||||
layers: List of layers to notify
|
||||
"""
|
||||
self._layers = layers
|
||||
|
||||
def notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
"""Notify registered layers about an event without buffering it."""
|
||||
self._notify_layers(event)
|
||||
|
||||
def collect(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Thread-safe method to collect an event.
|
||||
|
||||
Args:
|
||||
event: The event to collect
|
||||
"""
|
||||
with self._lock.write_lock():
|
||||
self._events.append(event)
|
||||
self._notify_layers(event)
|
||||
|
||||
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
"""
|
||||
Get new events starting from a specific index.
|
||||
|
||||
Args:
|
||||
start_index: The index to start from
|
||||
|
||||
Returns:
|
||||
List of new events
|
||||
"""
|
||||
with self._lock.read_lock():
|
||||
return list(self._events[start_index:])
|
||||
|
||||
def _event_count(self) -> int:
|
||||
"""
|
||||
Get the current count of collected events.
|
||||
|
||||
Returns:
|
||||
Number of collected events
|
||||
"""
|
||||
with self._lock.read_lock():
|
||||
return len(self._events)
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete to stop the event emission generator."""
|
||||
self._execution_complete.set()
|
||||
|
||||
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generator that yields events as they're collected.
|
||||
|
||||
Yields:
|
||||
GraphEngineEvent instances as they're processed
|
||||
"""
|
||||
yielded_count = 0
|
||||
|
||||
while not self._execution_complete.is_set() or yielded_count < self._event_count():
|
||||
# Get new events since last yield
|
||||
new_events = self._get_new_events(yielded_count)
|
||||
|
||||
# Yield any new events
|
||||
for event in new_events:
|
||||
yield event
|
||||
yielded_count += 1
|
||||
|
||||
# Small sleep to avoid busy waiting
|
||||
if not self._execution_complete.is_set() and not new_events:
|
||||
time.sleep(0.001)
|
||||
|
||||
def _notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Notify all layers of an event.
|
||||
|
||||
Layer exceptions are caught and logged to prevent disrupting collection.
|
||||
|
||||
Args:
|
||||
event: The event to send to layers
|
||||
"""
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_event(event)
|
||||
except Exception:
|
||||
# Silently ignore layer errors during collection
|
||||
pass
|
||||
362
dify/api/core/workflow/graph_engine/graph_engine.py
Normal file
362
dify/api/core/workflow/graph_engine/graph_engine.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution.
|
||||
|
||||
This engine uses a modular architecture with separated packages following
|
||||
Domain-Driven Design principles for improved maintainability and testability.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
||||
from core.workflow.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
|
||||
from .entities.commands import AbortCommand, PauseCommand
|
||||
from .error_handler import ErrorHandler
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_state_manager import GraphStateManager
|
||||
from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import GraphEngineLayer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .protocols.command_channel import CommandChannel
|
||||
from .ready_queue import ReadyQueue
|
||||
from .worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngine:
|
||||
"""
|
||||
Queue-based graph execution engine.
|
||||
|
||||
Uses a modular architecture that delegates responsibilities to specialized
|
||||
subsystems, following Domain-Driven Design and SOLID principles.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow_id: str,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
command_channel: CommandChannel,
|
||||
min_workers: int | None = None,
|
||||
max_workers: int | None = None,
|
||||
scale_up_threshold: int | None = None,
|
||||
scale_down_idle_time: float | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
|
||||
self._graph_execution.workflow_id = workflow_id
|
||||
|
||||
# === Worker Management Parameters ===
|
||||
# Parameters for dynamic worker pool scaling
|
||||
self._min_workers = min_workers
|
||||
self._max_workers = max_workers
|
||||
self._scale_up_threshold = scale_up_threshold
|
||||
self._scale_down_idle_time = scale_down_idle_time
|
||||
|
||||
# === Execution Queues ===
|
||||
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
|
||||
|
||||
# Queue for events generated during execution
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
# === State Management ===
|
||||
# Unified state manager handles all node state transitions and queue operations
|
||||
self._state_manager = GraphStateManager(self._graph, self._ready_queue)
|
||||
|
||||
# === Response Coordination ===
|
||||
# Coordinates response streaming from response nodes
|
||||
self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator)
|
||||
|
||||
# === Event Management ===
|
||||
# Event manager handles both collection and emission of events
|
||||
self._event_manager = EventManager()
|
||||
|
||||
# === Error Handling ===
|
||||
# Centralized error handler for graph execution errors
|
||||
self._error_handler = ErrorHandler(self._graph, self._graph_execution)
|
||||
|
||||
# === Graph Traversal Components ===
|
||||
# Propagates skip status through the graph when conditions aren't met
|
||||
self._skip_propagator = SkipPropagator(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
)
|
||||
|
||||
# Processes edges to determine next nodes after execution
|
||||
# Also handles conditional branching and route selection
|
||||
self._edge_processor = EdgeProcessor(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
response_coordinator=self._response_coordinator,
|
||||
skip_propagator=self._skip_propagator,
|
||||
)
|
||||
|
||||
# === Command Processing ===
|
||||
# Processes external commands (e.g., abort requests)
|
||||
self._command_processor = CommandProcessor(
|
||||
command_channel=self._command_channel,
|
||||
graph_execution=self._graph_execution,
|
||||
)
|
||||
|
||||
# Register command handlers
|
||||
abort_handler = AbortCommandHandler()
|
||||
self._command_processor.register_handler(AbortCommand, abort_handler)
|
||||
|
||||
pause_handler = PauseCommandHandler()
|
||||
self._command_processor.register_handler(PauseCommand, pause_handler)
|
||||
|
||||
# === Worker Pool Setup ===
|
||||
# Capture Flask app context for worker threads
|
||||
flask_app: Flask | None = None
|
||||
try:
|
||||
app = current_app._get_current_object() # type: ignore
|
||||
if isinstance(app, Flask):
|
||||
flask_app = app
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# Capture context variables for worker threads
|
||||
context_vars = contextvars.copy_context()
|
||||
|
||||
# Create worker pool for parallel node execution
|
||||
self._worker_pool = WorkerPool(
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
flask_app=flask_app,
|
||||
context_vars=context_vars,
|
||||
min_workers=self._min_workers,
|
||||
max_workers=self._max_workers,
|
||||
scale_up_threshold=self._scale_up_threshold,
|
||||
scale_down_idle_time=self._scale_down_idle_time,
|
||||
)
|
||||
|
||||
# === Orchestration ===
|
||||
# Coordinates the overall execution lifecycle
|
||||
self._execution_coordinator = ExecutionCoordinator(
|
||||
graph_execution=self._graph_execution,
|
||||
state_manager=self._state_manager,
|
||||
command_processor=self._command_processor,
|
||||
worker_pool=self._worker_pool,
|
||||
)
|
||||
|
||||
# === Event Handler Registry ===
|
||||
# Central registry for handling all node execution events
|
||||
self._event_handler_registry = EventHandler(
|
||||
graph=self._graph,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
graph_execution=self._graph_execution,
|
||||
response_coordinator=self._response_coordinator,
|
||||
event_collector=self._event_manager,
|
||||
edge_processor=self._edge_processor,
|
||||
state_manager=self._state_manager,
|
||||
error_handler=self._error_handler,
|
||||
)
|
||||
|
||||
# Dispatches events and manages execution flow
|
||||
self._dispatcher = Dispatcher(
|
||||
event_queue=self._event_queue,
|
||||
event_handler=self._event_handler_registry,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
event_emitter=self._event_manager,
|
||||
)
|
||||
|
||||
# === Extensibility ===
|
||||
# Layers allow plugins to extend engine functionality
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
|
||||
# === Validation ===
|
||||
# Ensure all nodes share the same GraphRuntimeState instance
|
||||
self._validate_graph_state_consistency()
|
||||
|
||||
def _validate_graph_state_consistency(self) -> None:
|
||||
"""Validate that all nodes share the same GraphRuntimeState."""
|
||||
expected_state_id = id(self._graph_runtime_state)
|
||||
for node in self._graph.nodes.values():
|
||||
if id(node.graph_runtime_state) != expected_state_id:
|
||||
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
||||
|
||||
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
|
||||
"""Add a layer for extending functionality."""
|
||||
self._layers.append(layer)
|
||||
return self
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Execute the graph using the modular architecture.
|
||||
|
||||
Returns:
|
||||
Generator yielding GraphEngineEvent instances
|
||||
"""
|
||||
try:
|
||||
# Initialize layers
|
||||
self._initialize_layers()
|
||||
|
||||
is_resume = self._graph_execution.started
|
||||
if not is_resume:
|
||||
self._graph_execution.start()
|
||||
else:
|
||||
self._graph_execution.paused = False
|
||||
self._graph_execution.pause_reason = None
|
||||
|
||||
start_event = GraphRunStartedEvent()
|
||||
self._event_manager.notify_layers(start_event)
|
||||
yield start_event
|
||||
|
||||
# Start subsystems
|
||||
self._start_execution(resume=is_resume)
|
||||
|
||||
# Yield events as they occur
|
||||
yield from self._event_manager.emit_events()
|
||||
|
||||
# Handle completion
|
||||
if self._graph_execution.is_paused:
|
||||
pause_reason = self._graph_execution.pause_reason
|
||||
assert pause_reason is not None, "pause_reason should not be None when execution is paused."
|
||||
# Ensure we have a valid PauseReason for the event
|
||||
paused_event = GraphRunPausedEvent(
|
||||
reason=pause_reason,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(paused_event)
|
||||
yield paused_event
|
||||
elif self._graph_execution.aborted:
|
||||
abort_reason = "Workflow execution aborted by user command"
|
||||
if self._graph_execution.error:
|
||||
abort_reason = str(self._graph_execution.error)
|
||||
aborted_event = GraphRunAbortedEvent(
|
||||
reason=abort_reason,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(aborted_event)
|
||||
yield aborted_event
|
||||
elif self._graph_execution.has_error:
|
||||
if self._graph_execution.error:
|
||||
raise self._graph_execution.error
|
||||
else:
|
||||
outputs = self._graph_runtime_state.outputs
|
||||
exceptions_count = self._graph_execution.exceptions_count
|
||||
if exceptions_count > 0:
|
||||
partial_event = GraphRunPartialSucceededEvent(
|
||||
exceptions_count=exceptions_count,
|
||||
outputs=outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(partial_event)
|
||||
yield partial_event
|
||||
else:
|
||||
succeeded_event = GraphRunSucceededEvent(
|
||||
outputs=outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(succeeded_event)
|
||||
yield succeeded_event
|
||||
|
||||
except Exception as e:
|
||||
failed_event = GraphRunFailedEvent(
|
||||
error=str(e),
|
||||
exceptions_count=self._graph_execution.exceptions_count,
|
||||
)
|
||||
self._event_manager.notify_layers(failed_event)
|
||||
yield failed_event
|
||||
raise
|
||||
|
||||
finally:
|
||||
self._stop_execution()
|
||||
|
||||
def _initialize_layers(self) -> None:
|
||||
"""Initialize layers with context."""
|
||||
self._event_manager.set_layers(self._layers)
|
||||
# Create a read-only wrapper for the runtime state
|
||||
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.initialize(read_only_state, self._command_channel)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
|
||||
|
||||
try:
|
||||
layer.on_graph_start()
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
paused_nodes: list[str] = []
|
||||
if resume:
|
||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||
|
||||
# Start worker pool (it calculates initial workers internally)
|
||||
self._worker_pool.start()
|
||||
|
||||
# Register response nodes
|
||||
for node in self._graph.nodes.values():
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._response_coordinator.register(node.id)
|
||||
|
||||
if not resume:
|
||||
# Enqueue root node
|
||||
root_node = self._graph.root_node
|
||||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
else:
|
||||
for node_id in paused_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Start dispatcher
|
||||
self._dispatcher.start()
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
# Notify layers
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
|
||||
|
||||
# Public property accessors for attributes that need external access
|
||||
@property
|
||||
def graph_runtime_state(self) -> GraphRuntimeState:
|
||||
"""Get the graph runtime state."""
|
||||
return self._graph_runtime_state
|
||||
288
dify/api/core/workflow/graph_engine/graph_state_manager.py
Normal file
288
dify/api/core/workflow/graph_engine/graph_state_manager.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Graph state manager that combines node, edge, and execution tracking.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict, final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
|
||||
class EdgeStateAnalysis(TypedDict):
|
||||
"""Analysis result for edge states."""
|
||||
|
||||
has_unknown: bool
|
||||
has_taken: bool
|
||||
all_skipped: bool
|
||||
|
||||
|
||||
@final
|
||||
class GraphStateManager:
|
||||
def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None:
|
||||
"""
|
||||
Initialize the state manager.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
ready_queue: Queue for nodes ready to execute
|
||||
"""
|
||||
self._graph = graph
|
||||
self._ready_queue = ready_queue
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Execution tracking state
|
||||
self._executing_nodes: set[str] = set()
|
||||
|
||||
# ============= Node State Operations =============
|
||||
|
||||
def enqueue_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as TAKEN and add it to the ready queue.
|
||||
|
||||
This combines the state transition and enqueueing operations
|
||||
that always occur together when preparing a node for execution.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to enqueue
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.nodes[node_id].state = NodeState.TAKEN
|
||||
self._ready_queue.put(node_id)
|
||||
|
||||
def mark_node_skipped(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as SKIPPED.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.nodes[node_id].state = NodeState.SKIPPED
|
||||
|
||||
def is_node_ready(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is ready to be executed.
|
||||
|
||||
A node is ready when all its incoming edges from taken branches
|
||||
have been satisfied.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is ready for execution
|
||||
"""
|
||||
with self._lock:
|
||||
# Get all incoming edges to this node
|
||||
incoming_edges = self._graph.get_incoming_edges(node_id)
|
||||
|
||||
# If no incoming edges, node is always ready
|
||||
if not incoming_edges:
|
||||
return True
|
||||
|
||||
# If any edge is UNKNOWN, node is not ready
|
||||
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
|
||||
return False
|
||||
|
||||
# Node is ready if at least one edge is TAKEN
|
||||
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
|
||||
|
||||
def get_node_state(self, node_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
|
||||
Returns:
|
||||
The current node state
|
||||
"""
|
||||
with self._lock:
|
||||
return self._graph.nodes[node_id].state
|
||||
|
||||
# ============= Edge State Operations =============
|
||||
|
||||
def mark_edge_taken(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as TAKEN.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.edges[edge_id].state = NodeState.TAKEN
|
||||
|
||||
def mark_edge_skipped(self, edge_id: str) -> None:
|
||||
"""
|
||||
Mark an edge as SKIPPED.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge to mark
|
||||
"""
|
||||
with self._lock:
|
||||
self._graph.edges[edge_id].state = NodeState.SKIPPED
|
||||
|
||||
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
|
||||
"""
|
||||
Analyze the states of edges and return summary flags.
|
||||
|
||||
Args:
|
||||
edges: List of edges to analyze
|
||||
|
||||
Returns:
|
||||
Analysis result with state flags
|
||||
"""
|
||||
with self._lock:
|
||||
states = {edge.state for edge in edges}
|
||||
|
||||
return EdgeStateAnalysis(
|
||||
has_unknown=NodeState.UNKNOWN in states,
|
||||
has_taken=NodeState.TAKEN in states,
|
||||
all_skipped=states == {NodeState.SKIPPED} if states else True,
|
||||
)
|
||||
|
||||
def get_edge_state(self, edge_id: str) -> NodeState:
|
||||
"""
|
||||
Get the current state of an edge.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge
|
||||
|
||||
Returns:
|
||||
The current edge state
|
||||
"""
|
||||
with self._lock:
|
||||
return self._graph.edges[edge_id].state
|
||||
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
|
||||
"""
|
||||
Categorize branch edges into selected and unselected.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
A tuple of (selected_edges, unselected_edges)
|
||||
"""
|
||||
with self._lock:
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
selected_edges: list[Edge] = []
|
||||
unselected_edges: list[Edge] = []
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.source_handle == selected_handle:
|
||||
selected_edges.append(edge)
|
||||
else:
|
||||
unselected_edges.append(edge)
|
||||
|
||||
return selected_edges, unselected_edges
|
||||
|
||||
# ============= Execution Tracking Operations =============
|
||||
|
||||
def start_execution(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node starting execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.add(node_id)
|
||||
|
||||
def finish_execution(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node as no longer executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node finishing execution
|
||||
"""
|
||||
with self._lock:
|
||||
self._executing_nodes.discard(node_id)
|
||||
|
||||
def is_executing(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if a node is currently executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to check
|
||||
|
||||
Returns:
|
||||
True if the node is executing
|
||||
"""
|
||||
with self._lock:
|
||||
return node_id in self._executing_nodes
|
||||
|
||||
def get_executing_count(self) -> int:
|
||||
"""
|
||||
Get the count of currently executing nodes.
|
||||
|
||||
Returns:
|
||||
Number of executing nodes
|
||||
"""
|
||||
with self._lock:
|
||||
return len(self._executing_nodes)
|
||||
|
||||
def get_executing_nodes(self) -> set[str]:
|
||||
"""
|
||||
Get a copy of the set of executing node IDs.
|
||||
|
||||
Returns:
|
||||
Set of node IDs currently executing
|
||||
"""
|
||||
with self._lock:
|
||||
return self._executing_nodes.copy()
|
||||
|
||||
def clear_executing(self) -> None:
|
||||
"""Clear all executing nodes."""
|
||||
with self._lock:
|
||||
self._executing_nodes.clear()
|
||||
|
||||
# ============= Composite Operations =============
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
"""
|
||||
Check if graph execution is complete.
|
||||
|
||||
Execution is complete when:
|
||||
- Ready queue is empty
|
||||
- No nodes are executing
|
||||
|
||||
Returns:
|
||||
True if execution is complete
|
||||
"""
|
||||
with self._lock:
|
||||
return self._ready_queue.empty() and len(self._executing_nodes) == 0
|
||||
|
||||
def get_queue_depth(self) -> int:
|
||||
"""
|
||||
Get the current depth of the ready queue.
|
||||
|
||||
Returns:
|
||||
Number of nodes in the ready queue
|
||||
"""
|
||||
return self._ready_queue.qsize()
|
||||
|
||||
def get_execution_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get execution statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with execution statistics
|
||||
"""
|
||||
with self._lock:
|
||||
taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
|
||||
skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
|
||||
unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
|
||||
|
||||
return {
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"executing": len(self._executing_nodes),
|
||||
"taken_nodes": taken_nodes,
|
||||
"skipped_nodes": skipped_nodes,
|
||||
"unknown_nodes": unknown_nodes,
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Graph traversal subsystem for graph engine.
|
||||
|
||||
This package handles graph navigation, edge processing,
|
||||
and skip propagation logic.
|
||||
"""
|
||||
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
__all__ = [
|
||||
"EdgeProcessor",
|
||||
"SkipPropagator",
|
||||
]
|
||||
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Edge processing logic for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Edge, Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
class EdgeProcessor:
|
||||
"""
|
||||
Processes edges during graph execution.
|
||||
|
||||
This handles marking edges as taken or skipped, notifying
|
||||
the response coordinator, triggering downstream node execution,
|
||||
and managing branch node logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
state_manager: GraphStateManager,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
skip_propagator: "SkipPropagator",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the edge processor.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
response_coordinator: Response stream coordinator
|
||||
skip_propagator: Propagator for skip states
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
self._response_coordinator = response_coordinator
|
||||
self._skip_propagator = skip_propagator
|
||||
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges after a node succeeds.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the succeeded node
|
||||
selected_handle: For branch nodes, the selected edge handle
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
|
||||
"""
|
||||
node = self._graph.nodes[node_id]
|
||||
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
return self._process_branch_node_edges(node_id, selected_handle)
|
||||
else:
|
||||
return self._process_non_branch_node_edges(node_id)
|
||||
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for non-branch nodes (mark all as TAKEN).
|
||||
|
||||
Args:
|
||||
node_id: The ID of the succeeded node
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
"""
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
ready_nodes.extend(nodes)
|
||||
all_streaming_events.extend(events)
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_branch_node_edges(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for branch nodes.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no edge was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} did not select any edge")
|
||||
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Process unselected edges first (mark as skipped)
|
||||
for edge in unselected_edges:
|
||||
self._process_skipped_edge(edge)
|
||||
|
||||
# Process selected edges
|
||||
for edge in selected_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
ready_nodes.extend(nodes)
|
||||
all_streaming_events.extend(events)
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Mark edge as taken and check downstream node.
|
||||
|
||||
Args:
|
||||
edge: The edge to process
|
||||
|
||||
Returns:
|
||||
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
|
||||
"""
|
||||
# Mark edge as taken
|
||||
self._state_manager.mark_edge_taken(edge.id)
|
||||
|
||||
# Notify response coordinator and get streaming events
|
||||
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes: list[str] = []
|
||||
if self._state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, streaming_events
|
||||
|
||||
def _process_skipped_edge(self, edge: Edge) -> None:
|
||||
"""
|
||||
Mark edge as skipped.
|
||||
|
||||
Args:
|
||||
edge: The edge to skip
|
||||
"""
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected branch
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no branch was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self._skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.process_node_success(node_id, selected_handle)
|
||||
|
||||
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
|
||||
"""
|
||||
Validate that a branch selection is valid.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle to validate
|
||||
|
||||
Returns:
|
||||
True if the selection is valid
|
||||
"""
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
valid_handles = {edge.source_handle for edge in outgoing_edges}
|
||||
return selected_handle in valid_handles
|
||||
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Skip state propagation through the graph.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
|
||||
|
||||
@final
|
||||
class SkipPropagator:
|
||||
"""
|
||||
Propagates skip states through the graph.
|
||||
|
||||
When a node is skipped, this ensures all downstream nodes
|
||||
that depend solely on it are also skipped.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
state_manager: GraphStateManager,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the skip propagator.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
|
||||
def propagate_skip_from_edge(self, edge_id: str) -> None:
|
||||
"""
|
||||
Recursively propagate skip state from a skipped edge.
|
||||
|
||||
Rules:
|
||||
- If a node has any UNKNOWN incoming edges, stop processing
|
||||
- If all incoming edges are SKIPPED, skip the node and its edges
|
||||
- If any incoming edge is TAKEN, the node may still execute
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the skipped edge to start from
|
||||
"""
|
||||
downstream_node_id = self._graph.edges[edge_id].head
|
||||
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
|
||||
|
||||
# Analyze edge states
|
||||
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
|
||||
|
||||
# Stop if there are unknown edges (not yet processed)
|
||||
if edge_states["has_unknown"]:
|
||||
return
|
||||
|
||||
# If any edge is taken, node may still execute
|
||||
if edge_states["has_taken"]:
|
||||
# Enqueue node
|
||||
self._state_manager.enqueue_node(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
if edge_states["all_skipped"]:
|
||||
self._propagate_skip_to_node(downstream_node_id)
|
||||
|
||||
def _propagate_skip_to_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node and all its outgoing edges as skipped.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
# Mark node as skipped
|
||||
self._state_manager.mark_node_skipped(node_id)
|
||||
|
||||
# Mark all outgoing edges as skipped and propagate
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
for edge in outgoing_edges:
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
# Recursively propagate skip
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
|
||||
"""
|
||||
Skip all paths from unselected branch edges.
|
||||
|
||||
Args:
|
||||
unselected_edges: List of edges not taken by the branch
|
||||
"""
|
||||
for edge in unselected_edges:
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
52
dify/api/core/workflow/graph_engine/layers/README.md
Normal file
52
dify/api/core/workflow/graph_engine/layers/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Layers
|
||||
|
||||
Pluggable middleware for engine extensions.
|
||||
|
||||
## Components
|
||||
|
||||
### Layer (base)
|
||||
|
||||
Abstract base class for layers.
|
||||
|
||||
- `initialize()` - Receive runtime context
|
||||
- `on_graph_start()` - Execution start hook
|
||||
- `on_event()` - Process all events
|
||||
- `on_graph_end()` - Execution end hook
|
||||
|
||||
### DebugLoggingLayer
|
||||
|
||||
Comprehensive execution logging.
|
||||
|
||||
- Configurable detail levels
|
||||
- Tracks execution statistics
|
||||
- Truncates long values
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
debug_layer = DebugLoggingLayer(
|
||||
level="INFO",
|
||||
include_outputs=True
|
||||
)
|
||||
|
||||
engine = GraphEngine(graph)
|
||||
engine.layer(debug_layer)
|
||||
engine.run()
|
||||
```
|
||||
|
||||
## Custom Layers
|
||||
|
||||
```python
|
||||
class MetricsLayer(Layer):
|
||||
def on_event(self, event):
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
self.metrics[event.node_id] = event.elapsed_time
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
**DebugLoggingLayer Options:**
|
||||
|
||||
- `level` - Log level (INFO, DEBUG, ERROR)
|
||||
- `include_inputs/outputs` - Log data values
|
||||
- `max_value_length` - Truncate long values
|
||||
16
dify/api/core/workflow/graph_engine/layers/__init__.py
Normal file
16
dify/api/core/workflow/graph_engine/layers/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Layer system for GraphEngine extensibility.
|
||||
|
||||
This module provides the layer infrastructure for extending GraphEngine functionality
|
||||
with middleware-like components that can observe events and interact with execution.
|
||||
"""
|
||||
|
||||
from .base import GraphEngineLayer
|
||||
from .debug_logging import DebugLoggingLayer
|
||||
from .execution_limits import ExecutionLimitsLayer
|
||||
|
||||
__all__ = [
|
||||
"DebugLoggingLayer",
|
||||
"ExecutionLimitsLayer",
|
||||
"GraphEngineLayer",
|
||||
]
|
||||
85
dify/api/core/workflow/graph_engine/layers/base.py
Normal file
85
dify/api/core/workflow/graph_engine/layers/base.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Base layer class for GraphEngine extensions.
|
||||
|
||||
This module provides the abstract base class for implementing layers that can
|
||||
intercept and respond to GraphEngine events.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
from core.workflow.runtime import ReadOnlyGraphRuntimeState
|
||||
|
||||
|
||||
class GraphEngineLayer(ABC):
|
||||
"""
|
||||
Abstract base class for GraphEngine layers.
|
||||
|
||||
Layers are middleware-like components that can:
|
||||
- Observe all events emitted by the GraphEngine
|
||||
- Access the graph runtime state
|
||||
- Send commands to control execution
|
||||
|
||||
Subclasses should override the constructor to accept configuration parameters,
|
||||
then implement the three lifecycle methods.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self.command_channel: CommandChannel | None = None
|
||||
|
||||
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
"""
|
||||
Initialize the layer with engine dependencies.
|
||||
|
||||
Called by GraphEngine before execution starts to inject the read-only runtime state
|
||||
and command channel. This allows layers to observe engine context and send
|
||||
commands, but prevents direct state modification.
|
||||
|
||||
Args:
|
||||
graph_runtime_state: Read-only view of the runtime state
|
||||
command_channel: Channel for sending commands to the engine
|
||||
"""
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.command_channel = command_channel
|
||||
|
||||
@abstractmethod
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Called when graph execution starts.
|
||||
|
||||
This is called after the engine has been initialized but before any nodes
|
||||
are executed. Layers can use this to set up resources or log start information.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
This method receives all events generated during graph execution, including:
|
||||
- Graph lifecycle events (start, success, failure)
|
||||
- Node execution events (start, success, failure, retry)
|
||||
- Stream events for response nodes
|
||||
- Container events (iteration, loop)
|
||||
|
||||
Args:
|
||||
event: The event emitted by the engine
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
This is called after all nodes have been executed or when execution is
|
||||
aborted. Layers can use this to clean up resources or log final state.
|
||||
|
||||
Args:
|
||||
error: The exception that caused execution to fail, or None if successful
|
||||
"""
|
||||
pass
|
||||
250
dify/api/core/workflow/graph_engine/layers/debug_logging.py
Normal file
250
dify/api/core/workflow/graph_engine/layers/debug_logging.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Debug logging layer for GraphEngine.
|
||||
|
||||
This module provides a layer that logs all events and state changes during
|
||||
graph execution for debugging purposes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .base import GraphEngineLayer
|
||||
|
||||
|
||||
@final
|
||||
class DebugLoggingLayer(GraphEngineLayer):
|
||||
"""
|
||||
A layer that provides comprehensive logging of GraphEngine execution.
|
||||
|
||||
This layer logs all events with configurable detail levels, helping developers
|
||||
debug workflow execution and understand the flow of events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
level: str = "INFO",
|
||||
include_inputs: bool = False,
|
||||
include_outputs: bool = True,
|
||||
include_process_data: bool = False,
|
||||
logger_name: str = "GraphEngine.Debug",
|
||||
max_value_length: int = 500,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the debug logging layer.
|
||||
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
include_inputs: Whether to log node input values
|
||||
include_outputs: Whether to log node output values
|
||||
include_process_data: Whether to log node process data
|
||||
logger_name: Name of the logger to use
|
||||
max_value_length: Maximum length of logged values (truncated if longer)
|
||||
"""
|
||||
super().__init__()
|
||||
self.level = level
|
||||
self.include_inputs = include_inputs
|
||||
self.include_outputs = include_outputs
|
||||
self.include_process_data = include_process_data
|
||||
self.max_value_length = max_value_length
|
||||
|
||||
# Set up logger
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
log_level = getattr(logging, level.upper(), logging.INFO)
|
||||
self.logger.setLevel(log_level)
|
||||
|
||||
# Track execution stats
|
||||
self.node_count = 0
|
||||
self.success_count = 0
|
||||
self.failure_count = 0
|
||||
self.retry_count = 0
|
||||
|
||||
def _truncate_value(self, value: Any) -> str:
|
||||
"""Truncate long values for logging."""
|
||||
str_value = str(value)
|
||||
if len(str_value) > self.max_value_length:
|
||||
return str_value[: self.max_value_length] + "... (truncated)"
|
||||
return str_value
|
||||
|
||||
def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str:
|
||||
"""Format a dictionary or mapping for logging with truncation."""
|
||||
if not data:
|
||||
return "{}"
|
||||
|
||||
formatted_items: list[str] = []
|
||||
for key, value in data.items():
|
||||
formatted_value = self._truncate_value(value)
|
||||
formatted_items.append(f" {key}: {formatted_value}")
|
||||
|
||||
return "{\n" + ",\n".join(formatted_items) + "\n}"
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Log graph execution start."""
|
||||
self.logger.info("=" * 80)
|
||||
self.logger.info("🚀 GRAPH EXECUTION STARTED")
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
if self.graph_runtime_state:
|
||||
# Log initial state
|
||||
self.logger.info("Initial State:")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""Log individual events based on their type."""
|
||||
event_class = event.__class__.__name__
|
||||
|
||||
# Graph-level events
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self.logger.debug("Graph run started event")
|
||||
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self.logger.info("✅ Graph run succeeded")
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self.logger.warning("⚠️ Graph run partially succeeded")
|
||||
if event.exceptions_count > 0:
|
||||
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.logger.error("❌ Graph run failed: %s", event.error)
|
||||
if event.exceptions_count > 0:
|
||||
self.logger.error(" Total exceptions: %s", event.exceptions_count)
|
||||
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self.logger.warning("⚠️ Graph run aborted: %s", event.reason)
|
||||
if event.outputs:
|
||||
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
# Node-level events
|
||||
# Retry before Started because Retry subclasses Started;
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
self.retry_count += 1
|
||||
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
|
||||
self.logger.warning(" Previous error: %s", event.error)
|
||||
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self.node_count += 1
|
||||
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
|
||||
|
||||
if self.include_inputs and event.node_run_result.inputs:
|
||||
self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs))
|
||||
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self.success_count += 1
|
||||
self.logger.info("✅ Node succeeded: %s", event.node_id)
|
||||
|
||||
if self.include_outputs and event.node_run_result.outputs:
|
||||
self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs))
|
||||
|
||||
if self.include_process_data and event.node_run_result.process_data:
|
||||
self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data))
|
||||
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self.failure_count += 1
|
||||
self.logger.error("❌ Node failed: %s", event.node_id)
|
||||
self.logger.error(" Error: %s", event.error)
|
||||
|
||||
if event.node_run_result.error:
|
||||
self.logger.error(" Details: %s", event.node_run_result.error)
|
||||
|
||||
elif isinstance(event, NodeRunExceptionEvent):
|
||||
self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
|
||||
self.logger.warning(" Error: %s", event.error)
|
||||
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
# Log stream chunks at debug level to avoid spam
|
||||
final_indicator = " (FINAL)" if event.is_final else ""
|
||||
self.logger.debug(
|
||||
"📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk)
|
||||
)
|
||||
|
||||
# Iteration events
|
||||
elif isinstance(event, NodeRunIterationStartedEvent):
|
||||
self.logger.info("🔁 Iteration started: %s", event.node_id)
|
||||
|
||||
elif isinstance(event, NodeRunIterationNextEvent):
|
||||
self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index)
|
||||
|
||||
elif isinstance(event, NodeRunIterationSucceededEvent):
|
||||
self.logger.info("✅ Iteration succeeded: %s", event.node_id)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, NodeRunIterationFailedEvent):
|
||||
self.logger.error("❌ Iteration failed: %s", event.node_id)
|
||||
self.logger.error(" Error: %s", event.error)
|
||||
|
||||
# Loop events
|
||||
elif isinstance(event, NodeRunLoopStartedEvent):
|
||||
self.logger.info("🔄 Loop started: %s", event.node_id)
|
||||
|
||||
elif isinstance(event, NodeRunLoopNextEvent):
|
||||
self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index)
|
||||
|
||||
elif isinstance(event, NodeRunLoopSucceededEvent):
|
||||
self.logger.info("✅ Loop succeeded: %s", event.node_id)
|
||||
if self.include_outputs and event.outputs:
|
||||
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
|
||||
|
||||
elif isinstance(event, NodeRunLoopFailedEvent):
|
||||
self.logger.error("❌ Loop failed: %s", event.node_id)
|
||||
self.logger.error(" Error: %s", event.error)
|
||||
|
||||
else:
|
||||
# Log unknown events at debug level
|
||||
self.logger.debug("Event: %s", event_class)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Log graph execution end with summary statistics."""
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
if error:
|
||||
self.logger.error("🔴 GRAPH EXECUTION FAILED")
|
||||
self.logger.error(" Error: %s", error)
|
||||
else:
|
||||
self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY")
|
||||
|
||||
# Log execution statistics
|
||||
self.logger.info("Execution Statistics:")
|
||||
self.logger.info(" Total nodes executed: %s", self.node_count)
|
||||
self.logger.info(" Successful nodes: %s", self.success_count)
|
||||
self.logger.info(" Failed nodes: %s", self.failure_count)
|
||||
self.logger.info(" Node retries: %s", self.retry_count)
|
||||
|
||||
# Log final state if available
|
||||
if self.graph_runtime_state and self.include_outputs:
|
||||
if self.graph_runtime_state.outputs:
|
||||
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
|
||||
|
||||
self.logger.info("=" * 80)
|
||||
150
dify/api/core/workflow/graph_engine/layers/execution_limits.py
Normal file
150
dify/api/core/workflow/graph_engine/layers/execution_limits.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Execution limits layer for GraphEngine.
|
||||
|
||||
This layer monitors workflow execution to enforce limits on:
|
||||
- Maximum execution steps
|
||||
- Maximum execution time
|
||||
|
||||
When limits are exceeded, the layer automatically aborts execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from enum import StrEnum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
||||
|
||||
class LimitType(StrEnum):
|
||||
"""Types of execution limits that can be exceeded."""
|
||||
|
||||
STEP_LIMIT = "step_limit"
|
||||
TIME_LIMIT = "time_limit"
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionLimitsLayer(GraphEngineLayer):
|
||||
"""
|
||||
Layer that enforces execution limits for workflows.
|
||||
|
||||
Monitors:
|
||||
- Step count: Tracks number of node executions
|
||||
- Time limit: Monitors total execution time
|
||||
|
||||
Automatically aborts execution when limits are exceeded.
|
||||
"""
|
||||
|
||||
def __init__(self, max_steps: int, max_time: int) -> None:
|
||||
"""
|
||||
Initialize the execution limits layer.
|
||||
|
||||
Args:
|
||||
max_steps: Maximum number of execution steps allowed
|
||||
max_time: Maximum execution time in seconds allowed
|
||||
"""
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
self.max_time = max_time
|
||||
|
||||
# Runtime tracking
|
||||
self.start_time: float | None = None
|
||||
self.step_count = 0
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# State tracking
|
||||
self._execution_started = False
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False # Track if abort command has been sent
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Called when graph execution starts."""
|
||||
self.start_time = time.time()
|
||||
self.step_count = 0
|
||||
self._execution_started = True
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False
|
||||
|
||||
self.logger.debug("Execution limits monitoring started")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
Monitors execution progress and enforces limits.
|
||||
"""
|
||||
if not self._execution_started or self._execution_ended or self._abort_sent:
|
||||
return
|
||||
|
||||
# Track step count for node execution events
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self.step_count += 1
|
||||
self.logger.debug("Step %d started: %s", self.step_count, event.node_id)
|
||||
|
||||
# Check step limit when node execution completes
|
||||
if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent):
|
||||
if self._reached_step_limitation():
|
||||
self._send_abort_command(LimitType.STEP_LIMIT)
|
||||
|
||||
if self._reached_time_limitation():
|
||||
self._send_abort_command(LimitType.TIME_LIMIT)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Called when graph execution ends."""
|
||||
if self._execution_started and not self._execution_ended:
|
||||
self._execution_ended = True
|
||||
|
||||
if self.start_time:
|
||||
total_time = time.time() - self.start_time
|
||||
self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time)
|
||||
|
||||
def _reached_step_limitation(self) -> bool:
|
||||
"""Check if step count limit has been exceeded."""
|
||||
return self.step_count > self.max_steps
|
||||
|
||||
def _reached_time_limitation(self) -> bool:
|
||||
"""Check if time limit has been exceeded."""
|
||||
return self.start_time is not None and (time.time() - self.start_time) > self.max_time
|
||||
|
||||
def _send_abort_command(self, limit_type: LimitType) -> None:
|
||||
"""
|
||||
Send abort command due to limit violation.
|
||||
|
||||
Args:
|
||||
limit_type: Type of limit exceeded
|
||||
"""
|
||||
if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent:
|
||||
return
|
||||
|
||||
# Format detailed reason message
|
||||
if limit_type == LimitType.STEP_LIMIT:
|
||||
reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}"
|
||||
elif limit_type == LimitType.TIME_LIMIT:
|
||||
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
||||
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
|
||||
|
||||
self.logger.warning("Execution limit exceeded: %s", reason)
|
||||
|
||||
try:
|
||||
# Send abort command to the engine
|
||||
abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason)
|
||||
self.command_channel.send_command(abort_command)
|
||||
|
||||
# Mark that abort has been sent to prevent duplicate commands
|
||||
self._abort_sent = True
|
||||
|
||||
self.logger.debug("Abort command sent to engine")
|
||||
|
||||
except Exception:
|
||||
self.logger.exception("Failed to send abort command")
|
||||
409
dify/api/core/workflow/graph_engine/layers/persistence.py
Normal file
409
dify/api/core/workflow/graph_engine/layers/persistence.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""Workflow persistence layer for GraphEngine.
|
||||
|
||||
This layer mirrors the former ``WorkflowCycleManager`` responsibilities by
|
||||
listening to ``GraphEngineEvent`` instances directly and persisting workflow
|
||||
and node execution state via the injected repositories.
|
||||
|
||||
The design keeps domain persistence concerns inside the engine thread, while
|
||||
allowing presentation layers to remain read-only observers of repository
|
||||
state.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import (
|
||||
SystemVariableKey,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PersistenceWorkflowInfo:
|
||||
"""Static workflow metadata required for persistence."""
|
||||
|
||||
workflow_id: str
|
||||
workflow_type: WorkflowType
|
||||
version: str
|
||||
graph_data: Mapping[str, Any]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _NodeRuntimeSnapshot:
|
||||
"""Lightweight cache to keep node metadata across event phases."""
|
||||
|
||||
node_id: str
|
||||
title: str
|
||||
predecessor_node_id: str | None
|
||||
iteration_id: str | None
|
||||
loop_id: str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
"""GraphEngine layer that persists workflow and node execution state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_info: PersistenceWorkflowInfo,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_info = workflow_info
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._trace_manager = trace_manager
|
||||
|
||||
self._workflow_execution: WorkflowExecution | None = None
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
self._node_snapshots: dict[str, _NodeRuntimeSnapshot] = {}
|
||||
self._node_sequence: int = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GraphEngineLayer lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
def on_graph_start(self) -> None:
|
||||
self._workflow_execution = None
|
||||
self._node_execution_cache.clear()
|
||||
self._node_snapshots.clear()
|
||||
self._node_sequence = 0
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._handle_graph_run_started()
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunSucceededEvent):
|
||||
self._handle_graph_run_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self._handle_graph_run_partial_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
self._handle_graph_run_failed(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunAbortedEvent):
|
||||
self._handle_graph_run_aborted(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
self._handle_graph_run_paused(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self._handle_node_started(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunRetryEvent):
|
||||
self._handle_node_retry(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
self._handle_node_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunFailedEvent):
|
||||
self._handle_node_failed(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunExceptionEvent):
|
||||
self._handle_node_exception(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunPauseRequestedEvent):
|
||||
self._handle_node_pause_requested(event)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
return
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Graph-level handlers
|
||||
# ------------------------------------------------------------------
|
||||
def _handle_graph_run_started(self) -> None:
|
||||
execution_id = self._get_execution_id()
|
||||
workflow_execution = WorkflowExecution.new(
|
||||
id_=execution_id,
|
||||
workflow_id=self._workflow_info.workflow_id,
|
||||
workflow_type=self._workflow_info.workflow_type,
|
||||
workflow_version=self._workflow_info.version,
|
||||
graph=self._workflow_info.graph_data,
|
||||
inputs=self._prepare_workflow_inputs(),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
self._workflow_execution_repository.save(workflow_execution)
|
||||
self._workflow_execution = workflow_execution
|
||||
|
||||
def _handle_graph_run_succeeded(self, event: GraphRunSucceededEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.outputs = event.outputs
|
||||
execution.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.outputs = event.outputs
|
||||
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
|
||||
execution.exceptions_count = event.exceptions_count
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.status = WorkflowExecutionStatus.FAILED
|
||||
execution.error_message = event.error
|
||||
execution.exceptions_count = event.exceptions_count
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._fail_running_node_executions(error_message=event.error)
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.status = WorkflowExecutionStatus.STOPPED
|
||||
execution.error_message = event.reason or "Workflow execution aborted"
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._fail_running_node_executions(error_message=execution.error_message or "")
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.status = WorkflowExecutionStatus.PAUSED
|
||||
execution.outputs = event.outputs
|
||||
self._populate_completion_statistics(execution, update_finished=False)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Node-level handlers
|
||||
# ------------------------------------------------------------------
|
||||
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
|
||||
metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
|
||||
domain_execution = WorkflowNodeExecution(
|
||||
id=event.id,
|
||||
node_execution_id=event.id,
|
||||
workflow_id=execution.workflow_id,
|
||||
workflow_execution_id=execution.id_,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
index=self._next_node_sequence(),
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
title=event.node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
metadata=metadata,
|
||||
created_at=event.start_at,
|
||||
)
|
||||
|
||||
self._node_execution_cache[event.id] = domain_execution
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
|
||||
snapshot = _NodeRuntimeSnapshot(
|
||||
node_id=event.node_id,
|
||||
title=event.node_title,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
created_at=event.start_at,
|
||||
)
|
||||
self._node_snapshots[event.id] = snapshot
|
||||
|
||||
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
domain_execution.status = WorkflowNodeExecutionStatus.RETRY
|
||||
domain_execution.error = event.error
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.PAUSED,
|
||||
error="",
|
||||
update_outputs=False,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _get_execution_id(self) -> str:
|
||||
workflow_execution_id = self._system_variables().get(SystemVariableKey.WORKFLOW_EXECUTION_ID)
|
||||
if not workflow_execution_id:
|
||||
raise ValueError("workflow_execution_id must be provided in system variables for pause/resume flows")
|
||||
return str(workflow_execution_id)
|
||||
|
||||
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
for field_name, value in self._system_variables().items():
|
||||
if field_name == SystemVariableKey.CONVERSATION_ID.value:
|
||||
# Conversation IDs are tied to the current session; omit them so persisted
|
||||
# workflow inputs stay reusable without binding future runs to this conversation.
|
||||
continue
|
||||
inputs[f"sys.{field_name}"] = value
|
||||
handled = WorkflowEntry.handle_special_values(inputs)
|
||||
return handled or {}
|
||||
|
||||
def _get_workflow_execution(self) -> WorkflowExecution:
|
||||
if self._workflow_execution is None:
|
||||
raise ValueError("workflow execution not initialized")
|
||||
return self._workflow_execution
|
||||
|
||||
def _get_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
if node_execution_id not in self._node_execution_cache:
|
||||
raise ValueError(f"Node execution not found for id={node_execution_id}")
|
||||
return self._node_execution_cache[node_execution_id]
|
||||
|
||||
def _next_node_sequence(self) -> int:
|
||||
self._node_sequence += 1
|
||||
return self._node_sequence
|
||||
|
||||
def _populate_completion_statistics(self, execution: WorkflowExecution, *, update_finished: bool = True) -> None:
|
||||
if update_finished:
|
||||
execution.finished_at = naive_utc_now()
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
execution.exceptions_count = runtime_state.exceptions_count
|
||||
|
||||
def _update_node_execution(
|
||||
self,
|
||||
domain_execution: WorkflowNodeExecution,
|
||||
node_result: NodeRunResult,
|
||||
status: WorkflowNodeExecutionStatus,
|
||||
*,
|
||||
error: str | None = None,
|
||||
update_outputs: bool = True,
|
||||
) -> None:
|
||||
finished_at = naive_utc_now()
|
||||
snapshot = self._node_snapshots.get(domain_execution.id)
|
||||
start_at = snapshot.created_at if snapshot else domain_execution.created_at
|
||||
domain_execution.status = status
|
||||
domain_execution.finished_at = finished_at
|
||||
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
|
||||
|
||||
if error:
|
||||
domain_execution.error = error
|
||||
|
||||
if update_outputs:
|
||||
domain_execution.update_from_mapping(
|
||||
inputs=node_result.inputs,
|
||||
process_data=node_result.process_data,
|
||||
outputs=node_result.outputs,
|
||||
metadata=node_result.metadata,
|
||||
)
|
||||
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
|
||||
def _fail_running_node_executions(self, *, error_message: str) -> None:
|
||||
now = naive_utc_now()
|
||||
for execution in self._node_execution_cache.values():
|
||||
if execution.status == WorkflowNodeExecutionStatus.RUNNING:
|
||||
execution.status = WorkflowNodeExecutionStatus.FAILED
|
||||
execution.error = error_message
|
||||
execution.finished_at = now
|
||||
execution.elapsed_time = max((now - execution.created_at).total_seconds(), 0.0)
|
||||
self._workflow_node_execution_repository.save(execution)
|
||||
|
||||
def _enqueue_trace_task(self, execution: WorkflowExecution) -> None:
|
||||
if not self._trace_manager:
|
||||
return
|
||||
|
||||
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
|
||||
external_trace_id = None
|
||||
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
|
||||
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
|
||||
|
||||
trace_task = TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_execution=execution,
|
||||
conversation_id=conversation_id,
|
||||
user_id=self._trace_manager.user_id,
|
||||
external_trace_id=external_trace_id,
|
||||
)
|
||||
self._trace_manager.add_trace_task(trace_task)
|
||||
|
||||
def _system_variables(self) -> Mapping[str, Any]:
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return {}
|
||||
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
||||
60
dify/api/core/workflow/graph_engine/manager.py
Normal file
60
dify/api/core/workflow/graph_engine/manager.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngineManager:
|
||||
"""
|
||||
Manager for sending control commands to GraphEngine instances.
|
||||
|
||||
This class provides a simple interface for controlling workflow executions
|
||||
by sending commands through Redis channels, without user validation.
|
||||
Supports stop and pause operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def send_stop_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""
|
||||
Send a stop command to a running workflow.
|
||||
|
||||
Args:
|
||||
task_id: The task ID of the workflow to stop
|
||||
reason: Optional reason for stopping (defaults to "User requested stop")
|
||||
"""
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
GraphEngineManager._send_command(task_id, abort_command)
|
||||
|
||||
@staticmethod
|
||||
def send_pause_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""Send a pause command to a running workflow."""
|
||||
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
try:
|
||||
channel.send_command(command)
|
||||
except Exception:
|
||||
# Silently fail if Redis is unavailable
|
||||
# The legacy control mechanisms will still work
|
||||
pass
|
||||
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Orchestration subsystem for graph engine.
|
||||
|
||||
This package coordinates the overall execution flow between
|
||||
different subsystems.
|
||||
"""
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
__all__ = [
|
||||
"Dispatcher",
|
||||
"ExecutionCoordinator",
|
||||
]
|
||||
125
dify/api/core/workflow/graph_engine/orchestration/dispatcher.py
Normal file
125
dify/api/core/workflow/graph_engine/orchestration/dispatcher.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Main dispatcher for processing events from workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from ..event_management import EventManager
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class Dispatcher:
|
||||
"""
|
||||
Main dispatcher that processes events from the event queue.
|
||||
|
||||
This runs in a separate thread and coordinates event processing
|
||||
with timeout and completion detection.
|
||||
"""
|
||||
|
||||
_COMMAND_TRIGGER_EVENTS = (
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunExceptionEvent,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the dispatcher.
|
||||
|
||||
Args:
|
||||
event_queue: Queue of events from workers
|
||||
event_handler: Event handler registry for processing events
|
||||
execution_coordinator: Coordinator for execution flow
|
||||
event_emitter: Optional event manager to signal completion
|
||||
"""
|
||||
self._event_queue = event_queue
|
||||
self._event_handler = event_handler
|
||||
self._execution_coordinator = execution_coordinator
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the dispatcher thread."""
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=10.0)
|
||||
|
||||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
try:
|
||||
self._process_commands()
|
||||
while not self._stop_event.is_set():
|
||||
if (
|
||||
self._execution_coordinator.aborted
|
||||
or self._execution_coordinator.paused
|
||||
or self._execution_coordinator.execution_complete
|
||||
):
|
||||
break
|
||||
|
||||
self._execution_coordinator.check_scaling()
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
self._process_commands(event)
|
||||
except queue.Empty:
|
||||
time.sleep(0.1)
|
||||
|
||||
self._process_commands()
|
||||
while True:
|
||||
try:
|
||||
event = self._event_queue.get(block=False)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Dispatcher error")
|
||||
self._execution_coordinator.mark_failed(e)
|
||||
|
||||
finally:
|
||||
self._execution_coordinator.mark_complete()
|
||||
# Signal the event emitter that execution is complete
|
||||
if self._event_emitter:
|
||||
self._event_emitter.mark_complete()
|
||||
|
||||
def _process_commands(self, event: GraphNodeEventBase | None = None):
|
||||
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
|
||||
self._execution_coordinator.process_commands()
|
||||
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Execution coordinator for managing overall workflow execution.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..worker_management import WorkerPool
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionCoordinator:
|
||||
"""
|
||||
Coordinates overall execution flow between subsystems.
|
||||
|
||||
This provides high-level coordination methods used by the
|
||||
dispatcher to manage execution state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_execution: GraphExecution,
|
||||
state_manager: GraphStateManager,
|
||||
command_processor: CommandProcessor,
|
||||
worker_pool: WorkerPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the execution coordinator.
|
||||
|
||||
Args:
|
||||
graph_execution: Graph execution aggregate
|
||||
state_manager: Unified state manager
|
||||
command_processor: Processor for commands
|
||||
worker_pool: Pool of workers
|
||||
"""
|
||||
self._graph_execution = graph_execution
|
||||
self._state_manager = state_manager
|
||||
self._command_processor = command_processor
|
||||
self._worker_pool = worker_pool
|
||||
|
||||
def process_commands(self) -> None:
|
||||
"""Process any pending commands."""
|
||||
self._command_processor.process_commands()
|
||||
|
||||
def check_scaling(self) -> None:
|
||||
"""Check and perform worker scaling if needed."""
|
||||
self._worker_pool.check_and_scale()
|
||||
|
||||
@property
|
||||
def execution_complete(self):
|
||||
return self._state_manager.is_execution_complete()
|
||||
|
||||
@property
|
||||
def aborted(self):
|
||||
return self._graph_execution.aborted or self._graph_execution.has_error
|
||||
|
||||
@property
|
||||
def paused(self) -> bool:
|
||||
"""Expose whether the underlying graph execution is paused."""
|
||||
return self._graph_execution.is_paused
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete."""
|
||||
if self._graph_execution.is_paused:
|
||||
return
|
||||
if not self._graph_execution.completed:
|
||||
self._graph_execution.complete()
|
||||
|
||||
def mark_failed(self, error: Exception) -> None:
|
||||
"""
|
||||
Mark execution as failed.
|
||||
|
||||
Args:
|
||||
error: The error that caused failure
|
||||
"""
|
||||
self._graph_execution.fail(error)
|
||||
|
||||
def handle_pause_if_needed(self) -> None:
|
||||
"""If the execution has been paused, stop workers immediately."""
|
||||
|
||||
if not self._graph_execution.is_paused:
|
||||
return
|
||||
|
||||
self._worker_pool.stop()
|
||||
self._state_manager.clear_executing()
|
||||
|
||||
def handle_abort_if_needed(self) -> None:
|
||||
"""If the execution has been aborted, stop workers immediately."""
|
||||
|
||||
if not self._graph_execution.aborted:
|
||||
return
|
||||
|
||||
self._worker_pool.stop()
|
||||
self._state_manager.clear_executing()
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
CommandChannel protocol for GraphEngine command communication.
|
||||
|
||||
This protocol defines the interface for sending and receiving commands
|
||||
to/from a GraphEngine instance, supporting both local and distributed scenarios.
|
||||
"""
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
class CommandChannel(Protocol):
|
||||
"""
|
||||
Protocol for bidirectional command communication with GraphEngine.
|
||||
|
||||
Since each GraphEngine instance processes only one workflow execution,
|
||||
this channel is dedicated to that single execution.
|
||||
"""
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""
|
||||
Fetch pending commands for this GraphEngine instance.
|
||||
|
||||
Called by GraphEngine to poll for commands that need to be processed.
|
||||
|
||||
Returns:
|
||||
List of pending commands (may be empty)
|
||||
"""
|
||||
...
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""
|
||||
Send a command to be processed by this GraphEngine instance.
|
||||
|
||||
Called by external systems to send control commands to the running workflow.
|
||||
|
||||
Args:
|
||||
command: The command to send
|
||||
"""
|
||||
...
|
||||
12
dify/api/core/workflow/graph_engine/ready_queue/__init__.py
Normal file
12
dify/api/core/workflow/graph_engine/ready_queue/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Ready queue implementations for GraphEngine.
|
||||
|
||||
This package contains the protocol and implementations for managing
|
||||
the queue of nodes ready for execution.
|
||||
"""
|
||||
|
||||
from .factory import create_ready_queue_from_state
|
||||
from .in_memory import InMemoryReadyQueue
|
||||
from .protocol import ReadyQueue, ReadyQueueState
|
||||
|
||||
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"]
|
||||
35
dify/api/core/workflow/graph_engine/ready_queue/factory.py
Normal file
35
dify/api/core/workflow/graph_engine/ready_queue/factory.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Factory for creating ReadyQueue instances from serialized state.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .in_memory import InMemoryReadyQueue
|
||||
from .protocol import ReadyQueueState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .protocol import ReadyQueue
|
||||
|
||||
|
||||
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
|
||||
"""
|
||||
Create a ReadyQueue instance from a serialized state.
|
||||
|
||||
Args:
|
||||
state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue
|
||||
|
||||
Returns:
|
||||
A ReadyQueue instance initialized with the given state
|
||||
|
||||
Raises:
|
||||
ValueError: If the queue type is unknown or version is unsupported
|
||||
"""
|
||||
if state.type == "InMemoryReadyQueue":
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}")
|
||||
queue = InMemoryReadyQueue()
|
||||
# Always pass as JSON string to loads()
|
||||
queue.loads(state.model_dump_json())
|
||||
return queue
|
||||
else:
|
||||
raise ValueError(f"Unknown ready queue type: {state.type}")
|
||||
140
dify/api/core/workflow/graph_engine/ready_queue/in_memory.py
Normal file
140
dify/api/core/workflow/graph_engine/ready_queue/in_memory.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
In-memory implementation of the ReadyQueue protocol.
|
||||
|
||||
This implementation wraps Python's standard queue.Queue and adds
|
||||
serialization capabilities for state storage.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import final
|
||||
|
||||
from .protocol import ReadyQueue, ReadyQueueState
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryReadyQueue(ReadyQueue):
|
||||
"""
|
||||
In-memory ready queue implementation with serialization support.
|
||||
|
||||
This implementation uses Python's queue.Queue internally and provides
|
||||
methods to serialize and restore the queue state.
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int = 0) -> None:
|
||||
"""
|
||||
Initialize the in-memory ready queue.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum size of the queue (0 for unlimited)
|
||||
"""
|
||||
self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize)
|
||||
|
||||
def put(self, item: str) -> None:
|
||||
"""
|
||||
Add a node ID to the ready queue.
|
||||
|
||||
Args:
|
||||
item: The node ID to add to the queue
|
||||
"""
|
||||
self._queue.put(item)
|
||||
|
||||
def get(self, timeout: float | None = None) -> str:
|
||||
"""
|
||||
Retrieve and remove a node ID from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for an item (None for blocking)
|
||||
|
||||
Returns:
|
||||
The node ID retrieved from the queue
|
||||
|
||||
Raises:
|
||||
queue.Empty: If timeout expires and no item is available
|
||||
"""
|
||||
if timeout is None:
|
||||
return self._queue.get(block=True)
|
||||
return self._queue.get(timeout=timeout)
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""
|
||||
Indicate that a previously retrieved task is complete.
|
||||
|
||||
Used by worker threads to signal task completion for
|
||||
join() synchronization.
|
||||
"""
|
||||
self._queue.task_done()
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""
|
||||
Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
True if the queue has no items, False otherwise
|
||||
"""
|
||||
return self._queue.empty()
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""
|
||||
Get the approximate size of the queue.
|
||||
|
||||
Returns:
|
||||
The approximate number of items in the queue
|
||||
"""
|
||||
return self._queue.qsize()
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""
|
||||
Serialize the queue state to a JSON string for storage.
|
||||
|
||||
Returns:
|
||||
A JSON string containing the serialized queue state
|
||||
"""
|
||||
# Extract all items from the queue without removing them
|
||||
items: list[str] = []
|
||||
temp_items: list[str] = []
|
||||
|
||||
# Drain the queue temporarily to get all items
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
item = self._queue.get_nowait()
|
||||
temp_items.append(item)
|
||||
items.append(item)
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Put items back in the same order
|
||||
for item in temp_items:
|
||||
self._queue.put(item)
|
||||
|
||||
state = ReadyQueueState(
|
||||
type="InMemoryReadyQueue",
|
||||
version="1.0",
|
||||
items=items,
|
||||
)
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""
|
||||
Restore the queue state from a JSON string.
|
||||
|
||||
Args:
|
||||
data: The JSON string containing the serialized queue state to restore
|
||||
"""
|
||||
state = ReadyQueueState.model_validate_json(data)
|
||||
|
||||
if state.type != "InMemoryReadyQueue":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported version: {state.version}")
|
||||
|
||||
# Clear the current queue
|
||||
while not self._queue.empty():
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Restore items
|
||||
for item in state.items:
|
||||
self._queue.put(item)
|
||||
104
dify/api/core/workflow/graph_engine/ready_queue/protocol.py
Normal file
104
dify/api/core/workflow/graph_engine/ready_queue/protocol.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
ReadyQueue protocol for GraphEngine node execution queue.
|
||||
|
||||
This protocol defines the interface for managing the queue of nodes ready
|
||||
for execution, supporting both in-memory and persistent storage scenarios.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ReadyQueueState(BaseModel):
|
||||
"""
|
||||
Pydantic model for serialized ready queue state.
|
||||
|
||||
This defines the structure of the data returned by dumps()
|
||||
and expected by loads() for ready queue serialization.
|
||||
"""
|
||||
|
||||
type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')")
|
||||
version: str = Field(description="Serialization format version")
|
||||
items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue")
|
||||
|
||||
|
||||
class ReadyQueue(Protocol):
|
||||
"""
|
||||
Protocol for managing nodes ready for execution in GraphEngine.
|
||||
|
||||
This protocol defines the interface that any ready queue implementation
|
||||
must provide, enabling both in-memory queues and persistent queues
|
||||
that can be serialized for state storage.
|
||||
"""
|
||||
|
||||
def put(self, item: str) -> None:
|
||||
"""
|
||||
Add a node ID to the ready queue.
|
||||
|
||||
Args:
|
||||
item: The node ID to add to the queue
|
||||
"""
|
||||
...
|
||||
|
||||
def get(self, timeout: float | None = None) -> str:
|
||||
"""
|
||||
Retrieve and remove a node ID from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for an item (None for blocking)
|
||||
|
||||
Returns:
|
||||
The node ID retrieved from the queue
|
||||
|
||||
Raises:
|
||||
queue.Empty: If timeout expires and no item is available
|
||||
"""
|
||||
...
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""
|
||||
Indicate that a previously retrieved task is complete.
|
||||
|
||||
Used by worker threads to signal task completion for
|
||||
join() synchronization.
|
||||
"""
|
||||
...
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""
|
||||
Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
True if the queue has no items, False otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""
|
||||
Get the approximate size of the queue.
|
||||
|
||||
Returns:
|
||||
The approximate number of items in the queue
|
||||
"""
|
||||
...
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""
|
||||
Serialize the queue state to a JSON string for storage.
|
||||
|
||||
Returns:
|
||||
A JSON string containing the serialized queue state
|
||||
that can be persisted and later restored
|
||||
"""
|
||||
...
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""
|
||||
Restore the queue state from a JSON string.
|
||||
|
||||
Args:
|
||||
data: The JSON string containing the serialized queue state to restore
|
||||
"""
|
||||
...
|
||||
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
ResponseStreamCoordinator - Coordinates streaming output from response nodes
|
||||
|
||||
This component manages response streaming sessions and ensures ordered streaming
|
||||
of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
from .coordinator import ResponseStreamCoordinator
|
||||
|
||||
__all__ = ["ResponseStreamCoordinator"]
|
||||
@@ -0,0 +1,697 @@
|
||||
"""
|
||||
Main ResponseStreamCoordinator implementation.
|
||||
|
||||
This module contains the public ResponseStreamCoordinator class that manages
|
||||
response streaming sessions and ensures ordered streaming of responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import Literal, TypeAlias, final
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type definitions
|
||||
NodeID: TypeAlias = str
|
||||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
class ResponseSessionState(BaseModel):
|
||||
"""Serializable representation of a response session."""
|
||||
|
||||
node_id: str
|
||||
index: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class StreamBufferState(BaseModel):
|
||||
"""Serializable representation of buffered stream chunks."""
|
||||
|
||||
selector: tuple[str, ...]
|
||||
events: list[NodeRunStreamChunkEvent] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StreamPositionState(BaseModel):
|
||||
"""Serializable representation for stream read positions."""
|
||||
|
||||
selector: tuple[str, ...]
|
||||
position: int = Field(default=0, ge=0)
|
||||
|
||||
|
||||
class ResponseStreamCoordinatorState(BaseModel):
|
||||
"""Serialized snapshot of ResponseStreamCoordinator."""
|
||||
|
||||
type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator")
|
||||
version: str = Field(default="1.0")
|
||||
response_nodes: Sequence[str] = Field(default_factory=list)
|
||||
active_session: ResponseSessionState | None = None
|
||||
waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||
pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||
node_execution_ids: dict[str, str] = Field(default_factory=dict)
|
||||
paths_map: dict[str, list[list[str]]] = Field(default_factory=dict)
|
||||
stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list)
|
||||
stream_positions: Sequence[StreamPositionState] = Field(default_factory=list)
|
||||
closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list)
|
||||
|
||||
|
||||
@final
|
||||
class ResponseStreamCoordinator:
|
||||
"""
|
||||
Manages response streaming sessions without relying on global state.
|
||||
|
||||
Ensures ordered streaming of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
|
||||
"""
|
||||
Initialize coordinator with variable pool.
|
||||
|
||||
Args:
|
||||
variable_pool: VariablePool instance for accessing node variables
|
||||
graph: Graph instance for looking up node information
|
||||
"""
|
||||
self._variable_pool = variable_pool
|
||||
self._graph = graph
|
||||
self._active_session: ResponseSession | None = None
|
||||
self._waiting_sessions: deque[ResponseSession] = deque()
|
||||
self._lock = RLock()
|
||||
|
||||
# Internal stream management (replacing OutputRegistry)
|
||||
self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
|
||||
self._stream_positions: dict[tuple[str, ...], int] = {}
|
||||
self._closed_streams: set[tuple[str, ...]] = set()
|
||||
|
||||
# Track response nodes
|
||||
self._response_nodes: set[NodeID] = set()
|
||||
|
||||
# Store paths for each response node
|
||||
self._paths_maps: dict[NodeID, list[Path]] = {}
|
||||
|
||||
# Track node execution IDs and types for proper event forwarding
|
||||
self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id
|
||||
|
||||
# Track response sessions to ensure only one per node
|
||||
self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session
|
||||
|
||||
def register(self, response_node_id: NodeID) -> None:
|
||||
with self._lock:
|
||||
if response_node_id in self._response_nodes:
|
||||
return
|
||||
self._response_nodes.add(response_node_id)
|
||||
|
||||
# Build and save paths map for this response node
|
||||
paths_map = self._build_paths_map(response_node_id)
|
||||
self._paths_maps[response_node_id] = paths_map
|
||||
|
||||
# Create and store response session for this node
|
||||
response_node = self._graph.nodes[response_node_id]
|
||||
session = ResponseSession.from_node(response_node)
|
||||
self._response_sessions[response_node_id] = session
|
||||
|
||||
def track_node_execution(self, node_id: NodeID, execution_id: str) -> None:
|
||||
"""Track the execution ID for a node when it starts executing.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
execution_id: The execution ID from NodeRunStartedEvent
|
||||
"""
|
||||
with self._lock:
|
||||
self._node_execution_ids[node_id] = execution_id
|
||||
|
||||
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
|
||||
"""Get the execution ID for a node, creating one if it doesn't exist.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node
|
||||
|
||||
Returns:
|
||||
The execution ID for the node
|
||||
"""
|
||||
with self._lock:
|
||||
if node_id not in self._node_execution_ids:
|
||||
self._node_execution_ids[node_id] = str(uuid4())
|
||||
return self._node_execution_ids[node_id]
|
||||
|
||||
def _build_paths_map(self, response_node_id: NodeID) -> list[Path]:
|
||||
"""
|
||||
Build a paths map for a response node by finding all paths from root node
|
||||
to the response node, recording branch edges along each path.
|
||||
|
||||
Args:
|
||||
response_node_id: ID of the response node to analyze
|
||||
|
||||
Returns:
|
||||
List of Path objects, where each path contains branch edge IDs
|
||||
"""
|
||||
# Get root node ID
|
||||
root_node_id = self._graph.root_node.id
|
||||
|
||||
# If root is the response node, return empty path
|
||||
if root_node_id == response_node_id:
|
||||
return [Path()]
|
||||
|
||||
# Extract variable selectors from the response node's template
|
||||
response_node = self._graph.nodes[response_node_id]
|
||||
response_session = ResponseSession.from_node(response_node)
|
||||
template = response_session.template
|
||||
|
||||
# Collect all variable selectors from the template
|
||||
variable_selectors: set[tuple[str, ...]] = set()
|
||||
for segment in template.segments:
|
||||
if isinstance(segment, VariableSegment):
|
||||
variable_selectors.add(tuple(segment.selector[:2]))
|
||||
|
||||
# Step 1: Find all complete paths from root to response node
|
||||
all_complete_paths: list[list[EdgeID]] = []
|
||||
|
||||
def find_paths(
|
||||
current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID]
|
||||
) -> None:
|
||||
"""Recursively find all paths from current node to target node."""
|
||||
if current_node_id == target_node_id:
|
||||
# Found a complete path, store it
|
||||
all_complete_paths.append(current_path.copy())
|
||||
return
|
||||
|
||||
# Mark as visited to avoid cycles
|
||||
visited.add(current_node_id)
|
||||
|
||||
# Explore outgoing edges
|
||||
outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
|
||||
for edge in outgoing_edges:
|
||||
edge_id = edge.id
|
||||
next_node_id = edge.head
|
||||
|
||||
# Skip if already visited in this path
|
||||
if next_node_id not in visited:
|
||||
# Add edge to path and recurse
|
||||
new_path = current_path + [edge_id]
|
||||
find_paths(next_node_id, target_node_id, new_path, visited.copy())
|
||||
|
||||
# Start searching from root node
|
||||
find_paths(root_node_id, response_node_id, [], set())
|
||||
|
||||
# Step 2: For each complete path, filter edges based on node blocking behavior
|
||||
filtered_paths: list[Path] = []
|
||||
for path in all_complete_paths:
|
||||
blocking_edges: list[str] = []
|
||||
for edge_id in path:
|
||||
edge = self._graph.edges[edge_id]
|
||||
source_node = self._graph.nodes[edge.tail]
|
||||
|
||||
# Check if node is a branch, container, or response node
|
||||
if source_node.execution_type in {
|
||||
NodeExecutionType.BRANCH,
|
||||
NodeExecutionType.CONTAINER,
|
||||
NodeExecutionType.RESPONSE,
|
||||
} or source_node.blocks_variable_output(variable_selectors):
|
||||
blocking_edges.append(edge_id)
|
||||
|
||||
# Keep the path even if it's empty
|
||||
filtered_paths.append(Path(edges=blocking_edges))
|
||||
|
||||
return filtered_paths
|
||||
|
||||
def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""
|
||||
Handle when an edge is taken (selected by a branch node).
|
||||
|
||||
This method updates the paths for all response nodes by removing
|
||||
the taken edge. If any response node has an empty path after removal,
|
||||
it means the node is now deterministically reachable and should start.
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the edge that was taken
|
||||
|
||||
Returns:
|
||||
List of events to emit from starting new sessions
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
with self._lock:
|
||||
# Check each response node in order
|
||||
for response_node_id in self._response_nodes:
|
||||
if response_node_id not in self._paths_maps:
|
||||
continue
|
||||
|
||||
paths = self._paths_maps[response_node_id]
|
||||
has_reachable_path = False
|
||||
|
||||
# Update each path by removing the taken edge
|
||||
for path in paths:
|
||||
# Remove the taken edge from this path
|
||||
path.remove_edge(edge_id)
|
||||
|
||||
# Check if this path is now empty (node is reachable)
|
||||
if path.is_empty():
|
||||
has_reachable_path = True
|
||||
|
||||
# If node is now reachable (has empty path), start/queue session
|
||||
if has_reachable_path:
|
||||
# Pass the node_id to the activation method
|
||||
# The method will handle checking and removing from map
|
||||
events.extend(self._active_or_queue_session(response_node_id))
|
||||
return events
|
||||
|
||||
def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""
|
||||
Start a session immediately if no active session, otherwise queue it.
|
||||
Only activates sessions that exist in the _response_sessions map.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the response node to activate
|
||||
|
||||
Returns:
|
||||
List of events from flush attempt if session started immediately
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Get the session from our map (only activate if it exists)
|
||||
session = self._response_sessions.get(node_id)
|
||||
if not session:
|
||||
return events
|
||||
|
||||
# Remove from map to ensure it won't be activated again
|
||||
del self._response_sessions[node_id]
|
||||
|
||||
if self._active_session is None:
|
||||
self._active_session = session
|
||||
|
||||
# Try to flush immediately
|
||||
events.extend(self.try_flush())
|
||||
else:
|
||||
# Queue the session if another is active
|
||||
self._waiting_sessions.append(session)
|
||||
|
||||
return events
|
||||
|
||||
def intercept_event(
|
||||
self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
|
||||
) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
with self._lock:
|
||||
if isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._append_stream_chunk(event.selector, event)
|
||||
if event.is_final:
|
||||
self._close_stream(event.selector)
|
||||
return self.try_flush()
|
||||
else:
|
||||
# Skip cause we share the same variable pool.
|
||||
#
|
||||
# for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
# self._variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
return self.try_flush()
|
||||
|
||||
def _create_stream_chunk_event(
|
||||
self,
|
||||
node_id: str,
|
||||
execution_id: str,
|
||||
selector: Sequence[str],
|
||||
chunk: str,
|
||||
is_final: bool = False,
|
||||
) -> NodeRunStreamChunkEvent:
|
||||
"""Create a stream chunk event with consistent structure.
|
||||
|
||||
For selectors with special prefixes (sys, env, conversation), we use the
|
||||
active response node's information since these are not actual node IDs.
|
||||
"""
|
||||
# Check if this is a special selector that doesn't correspond to a node
|
||||
if selector and selector[0] not in self._graph.nodes and self._active_session:
|
||||
# Use the active response node for special selectors
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
node_type=response_node.node_type,
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
# Standard case: selector refers to an actual node
|
||||
node = self._graph.nodes[node_id]
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
|
||||
"""Process a variable segment. Returns (events, is_complete).
|
||||
|
||||
Handles both regular node selectors and special system selectors (sys, env, conversation).
|
||||
For special selectors, we attribute the output to the active response node.
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
is_complete = False
|
||||
|
||||
# Determine which node to attribute the output to
|
||||
# For special selectors (sys, env, conversation), use the active response node
|
||||
# For regular selectors, use the source node
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
# Special selector - use active response node
|
||||
output_node_id = self._active_session.node_id
|
||||
else:
|
||||
# Regular node selector
|
||||
output_node_id = source_selector_prefix
|
||||
execution_id = self._get_or_create_execution_id(output_node_id)
|
||||
|
||||
# Stream all available chunks
|
||||
while self._has_unread_stream(segment.selector):
|
||||
if event := self._pop_stream_chunk(segment.selector):
|
||||
# For special selectors, we need to update the event to use
|
||||
# the active response node's information
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
# Create a new event with the response node's information
|
||||
# but keep the original selector
|
||||
updated_event = NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
node_type=response_node.node_type,
|
||||
selector=event.selector, # Keep original selector
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
)
|
||||
events.append(updated_event)
|
||||
else:
|
||||
# Regular node selector - use event as is
|
||||
events.append(event)
|
||||
|
||||
# Check if this is the last chunk by looking ahead
|
||||
stream_closed = self._is_stream_closed(segment.selector)
|
||||
# Check if stream is closed to determine if segment is complete
|
||||
if stream_closed:
|
||||
is_complete = True
|
||||
|
||||
elif value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=segment.selector,
|
||||
chunk=value.markdown,
|
||||
is_final=is_last_segment,
|
||||
)
|
||||
)
|
||||
is_complete = True
|
||||
|
||||
return events, is_complete
|
||||
|
||||
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""Process a text segment. Returns (events, is_complete)."""
|
||||
assert self._active_session is not None
|
||||
current_response_node = self._graph.nodes[self._active_session.node_id]
|
||||
|
||||
# Use get_or_create_execution_id to ensure we have a consistent ID
|
||||
execution_id = self._get_or_create_execution_id(current_response_node.id)
|
||||
|
||||
is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
event = self._create_stream_chunk_event(
|
||||
node_id=current_response_node.id,
|
||||
execution_id=execution_id,
|
||||
selector=[current_response_node.id, "answer"], # FIXME(-LAN-)
|
||||
chunk=segment.text,
|
||||
is_final=is_last_segment,
|
||||
)
|
||||
return [event]
|
||||
|
||||
def try_flush(self) -> list[NodeRunStreamChunkEvent]:
|
||||
with self._lock:
|
||||
if not self._active_session:
|
||||
return []
|
||||
|
||||
template = self._active_session.template
|
||||
response_node_id = self._active_session.node_id
|
||||
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Process segments sequentially from current index
|
||||
while self._active_session.index < len(template.segments):
|
||||
segment = template.segments[self._active_session.index]
|
||||
|
||||
if isinstance(segment, VariableSegment):
|
||||
# Check if the source node for this variable is skipped
|
||||
# Only check for actual nodes, not special selectors (sys, env, conversation)
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
if source_selector_prefix in self._graph.nodes:
|
||||
source_node = self._graph.nodes[source_selector_prefix]
|
||||
|
||||
if source_node.state == NodeState.SKIPPED:
|
||||
# Skip this variable segment if the source node is skipped
|
||||
self._active_session.index += 1
|
||||
continue
|
||||
|
||||
segment_events, is_complete = self._process_variable_segment(segment)
|
||||
events.extend(segment_events)
|
||||
|
||||
# Only advance index if this variable segment is complete
|
||||
if is_complete:
|
||||
self._active_session.index += 1
|
||||
else:
|
||||
# Wait for more data
|
||||
break
|
||||
|
||||
else:
|
||||
segment_events = self._process_text_segment(segment)
|
||||
events.extend(segment_events)
|
||||
self._active_session.index += 1
|
||||
|
||||
if self._active_session.is_complete():
|
||||
# End current session and get events from starting next session
|
||||
next_session_events = self.end_session(response_node_id)
|
||||
events.extend(next_session_events)
|
||||
|
||||
return events
|
||||
|
||||
def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
|
||||
"""
|
||||
End the active session for a response node.
|
||||
Automatically starts the next waiting session if available.
|
||||
|
||||
Args:
|
||||
node_id: ID of the response node ending its session
|
||||
|
||||
Returns:
|
||||
List of events from starting the next session
|
||||
"""
|
||||
with self._lock:
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
if self._active_session and self._active_session.node_id == node_id:
|
||||
self._active_session = None
|
||||
|
||||
# Try to start next waiting session
|
||||
if self._waiting_sessions:
|
||||
next_session = self._waiting_sessions.popleft()
|
||||
self._active_session = next_session
|
||||
|
||||
# Immediately try to flush any available segments
|
||||
events = self.try_flush()
|
||||
|
||||
return events
|
||||
|
||||
# ============= Internal Stream Management Methods =============
|
||||
|
||||
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Append a stream chunk to the internal buffer.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
event: The NodeRunStreamChunkEvent to append
|
||||
|
||||
Raises:
|
||||
ValueError: If the stream is already closed
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key in self._closed_streams:
|
||||
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
self._stream_buffers[key] = []
|
||||
self._stream_positions[key] = 0
|
||||
|
||||
self._stream_buffers[key].append(event)
|
||||
|
||||
def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
|
||||
"""
|
||||
Pop the next unread stream chunk from the buffer.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
The next event, or None if no unread events available
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
return None
|
||||
|
||||
position = self._stream_positions.get(key, 0)
|
||||
buffer = self._stream_buffers[key]
|
||||
|
||||
if position >= len(buffer):
|
||||
return None
|
||||
|
||||
event = buffer[position]
|
||||
self._stream_positions[key] = position + 1
|
||||
return event
|
||||
|
||||
def _has_unread_stream(self, selector: Sequence[str]) -> bool:
|
||||
"""
|
||||
Check if the stream has unread events.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
True if there are unread events, False otherwise
|
||||
"""
|
||||
key = tuple(selector)
|
||||
|
||||
if key not in self._stream_buffers:
|
||||
return False
|
||||
|
||||
position = self._stream_positions.get(key, 0)
|
||||
return position < len(self._stream_buffers[key])
|
||||
|
||||
def _close_stream(self, selector: Sequence[str]) -> None:
|
||||
"""
|
||||
Mark a stream as closed (no more chunks can be appended).
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
"""
|
||||
key = tuple(selector)
|
||||
self._closed_streams.add(key)
|
||||
|
||||
def _is_stream_closed(self, selector: Sequence[str]) -> bool:
|
||||
"""
|
||||
Check if a stream is closed.
|
||||
|
||||
Args:
|
||||
selector: List of strings identifying the stream location
|
||||
|
||||
Returns:
|
||||
True if the stream is closed, False otherwise
|
||||
"""
|
||||
key = tuple(selector)
|
||||
return key in self._closed_streams
|
||||
|
||||
def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None:
|
||||
"""Convert an in-memory session into its serializable form."""
|
||||
|
||||
if session is None:
|
||||
return None
|
||||
return ResponseSessionState(node_id=session.node_id, index=session.index)
|
||||
|
||||
def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession:
|
||||
"""Rebuild a response session from serialized data."""
|
||||
|
||||
node = self._graph.nodes.get(session_state.node_id)
|
||||
if node is None:
|
||||
raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state")
|
||||
|
||||
session = ResponseSession.from_node(node)
|
||||
session.index = session_state.index
|
||||
return session
|
||||
|
||||
def dumps(self) -> str:
|
||||
"""Serialize coordinator state to JSON."""
|
||||
|
||||
with self._lock:
|
||||
state = ResponseStreamCoordinatorState(
|
||||
response_nodes=sorted(self._response_nodes),
|
||||
active_session=self._serialize_session(self._active_session),
|
||||
waiting_sessions=[
|
||||
session_state
|
||||
for session in list(self._waiting_sessions)
|
||||
if (session_state := self._serialize_session(session)) is not None
|
||||
],
|
||||
pending_sessions=[
|
||||
session_state
|
||||
for _, session in sorted(self._response_sessions.items())
|
||||
if (session_state := self._serialize_session(session)) is not None
|
||||
],
|
||||
node_execution_ids=dict(sorted(self._node_execution_ids.items())),
|
||||
paths_map={
|
||||
node_id: [path.edges.copy() for path in paths]
|
||||
for node_id, paths in sorted(self._paths_maps.items())
|
||||
},
|
||||
stream_buffers=[
|
||||
StreamBufferState(
|
||||
selector=selector,
|
||||
events=[event.model_copy(deep=True) for event in events],
|
||||
)
|
||||
for selector, events in sorted(self._stream_buffers.items())
|
||||
],
|
||||
stream_positions=[
|
||||
StreamPositionState(selector=selector, position=position)
|
||||
for selector, position in sorted(self._stream_positions.items())
|
||||
],
|
||||
closed_streams=sorted(self._closed_streams),
|
||||
)
|
||||
return state.model_dump_json()
|
||||
|
||||
def loads(self, data: str) -> None:
|
||||
"""Restore coordinator state from JSON."""
|
||||
|
||||
state = ResponseStreamCoordinatorState.model_validate_json(data)
|
||||
|
||||
if state.type != "ResponseStreamCoordinator":
|
||||
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||
|
||||
if state.version != "1.0":
|
||||
raise ValueError(f"Unsupported serialized version: {state.version}")
|
||||
|
||||
with self._lock:
|
||||
self._response_nodes = set(state.response_nodes)
|
||||
self._paths_maps = {
|
||||
node_id: [Path(edges=list(path_edges)) for path_edges in paths]
|
||||
for node_id, paths in state.paths_map.items()
|
||||
}
|
||||
self._node_execution_ids = dict(state.node_execution_ids)
|
||||
|
||||
self._stream_buffers = {
|
||||
tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events]
|
||||
for buffer in state.stream_buffers
|
||||
}
|
||||
self._stream_positions = {
|
||||
tuple(position.selector): position.position for position in state.stream_positions
|
||||
}
|
||||
for selector in self._stream_buffers:
|
||||
self._stream_positions.setdefault(selector, 0)
|
||||
|
||||
self._closed_streams = {tuple(selector) for selector in state.closed_streams}
|
||||
|
||||
self._waiting_sessions = deque(
|
||||
self._session_from_state(session_state) for session_state in state.waiting_sessions
|
||||
)
|
||||
self._response_sessions = {
|
||||
session_state.node_id: self._session_from_state(session_state)
|
||||
for session_state in state.pending_sessions
|
||||
}
|
||||
self._active_session = self._session_from_state(state.active_session) if state.active_session else None
|
||||
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Internal path representation for response coordinator.
|
||||
|
||||
This module contains the private Path class used internally by ResponseStreamCoordinator
|
||||
to track execution paths to response nodes.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeAlias
|
||||
|
||||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Path:
|
||||
"""
|
||||
Represents a path of branch edges that must be taken to reach a response node.
|
||||
|
||||
Note: This is an internal class not exposed in the public API.
|
||||
"""
|
||||
|
||||
edges: list[EdgeID] = field(default_factory=list[EdgeID])
|
||||
|
||||
def contains_edge(self, edge_id: EdgeID) -> bool:
|
||||
"""Check if this path contains the given edge."""
|
||||
return edge_id in self.edges
|
||||
|
||||
def remove_edge(self, edge_id: EdgeID) -> None:
|
||||
"""Remove the given edge from this path in place."""
|
||||
if self.contains_edge(edge_id):
|
||||
self.edges.remove(edge_id)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if the path has no edges (node is reachable)."""
|
||||
return len(self.edges) == 0
|
||||
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Internal response session management for response coordinator.
|
||||
|
||||
This module contains the private ResponseSession class used internally
|
||||
by ResponseStreamCoordinator to manage streaming sessions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseSession:
|
||||
"""
|
||||
Represents an active response streaming session.
|
||||
|
||||
Note: This is an internal class not exposed in the public API.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
template: Template # Template object from the response node
|
||||
index: int = 0 # Current position in the template segments
|
||||
|
||||
@classmethod
|
||||
def from_node(cls, node: Node) -> "ResponseSession":
|
||||
"""
|
||||
Create a ResponseSession from an AnswerNode or EndNode.
|
||||
|
||||
Args:
|
||||
node: Must be either an AnswerNode or EndNode instance
|
||||
|
||||
Returns:
|
||||
ResponseSession configured with the node's streaming template
|
||||
|
||||
Raises:
|
||||
TypeError: If node is not an AnswerNode or EndNode
|
||||
"""
|
||||
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
|
||||
raise TypeError
|
||||
return cls(
|
||||
node_id=node.id,
|
||||
template=node.get_streaming_template(),
|
||||
)
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if all segments in the template have been processed."""
|
||||
return self.index >= len(self.template.segments)
|
||||
141
dify/api/core/workflow/graph_engine/worker.py
Normal file
141
dify/api/core/workflow/graph_engine/worker.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Worker - Thread implementation for queue-based node execution
|
||||
|
||||
Workers pull node IDs from the ready_queue, execute nodes, and push events
|
||||
to the event_queue for the dispatcher to process.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import final
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import Flask
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
|
||||
@final
|
||||
class Worker(threading.Thread):
|
||||
"""
|
||||
Worker thread that executes nodes from the ready queue.
|
||||
|
||||
Workers continuously pull node IDs from the ready_queue, execute the
|
||||
corresponding nodes, and push the resulting events to the event_queue
|
||||
for the dispatcher to process.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ready_queue: ReadyQueue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
worker_id: int = 0,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize worker thread.
|
||||
|
||||
Args:
|
||||
ready_queue: Ready queue containing node IDs ready for execution
|
||||
event_queue: Queue for pushing execution events
|
||||
graph: Graph containing nodes to execute
|
||||
worker_id: Unique identifier for this worker
|
||||
flask_app: Optional Flask application for context preservation
|
||||
context_vars: Optional context variables to preserve in worker thread
|
||||
"""
|
||||
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._worker_id = worker_id
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._stop_event = threading.Event()
|
||||
self._last_task_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
"""Check if the worker is currently idle."""
|
||||
# Worker is idle if it hasn't processed a task recently (within 0.2 seconds)
|
||||
return (time.time() - self._last_task_time) > 0.2
|
||||
|
||||
@property
|
||||
def idle_duration(self) -> float:
|
||||
"""Get the duration in seconds since the worker last processed a task."""
|
||||
return time.time() - self._last_task_time
|
||||
|
||||
@property
|
||||
def worker_id(self) -> int:
|
||||
"""Get the worker's ID."""
|
||||
return self._worker_id
|
||||
|
||||
@override
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Main worker loop.
|
||||
|
||||
Continuously pulls node IDs from ready_queue, executes them,
|
||||
and pushes events to event_queue until stopped.
|
||||
"""
|
||||
while not self._stop_event.is_set():
|
||||
# Try to get a node ID from the ready queue (with timeout)
|
||||
try:
|
||||
node_id = self._ready_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
self._last_task_time = time.time()
|
||||
node = self._graph.nodes[node_id]
|
||||
try:
|
||||
self._execute_node(node)
|
||||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=str(uuid4()),
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
error=str(e),
|
||||
start_at=datetime.now(),
|
||||
)
|
||||
self._event_queue.put(error_event)
|
||||
|
||||
def _execute_node(self, node: Node) -> None:
|
||||
"""
|
||||
Execute a single node and handle its events.
|
||||
|
||||
Args:
|
||||
node: The node instance to execute
|
||||
"""
|
||||
# Execute the node with preserved context if Flask app is provided
|
||||
if self._flask_app and self._context_vars:
|
||||
with preserve_flask_contexts(
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
):
|
||||
# Execute the node
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
# Forward event to dispatcher immediately for streaming
|
||||
self._event_queue.put(event)
|
||||
else:
|
||||
# Execute without context preservation
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
# Forward event to dispatcher immediately for streaming
|
||||
self._event_queue.put(event)
|
||||
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Worker management subsystem for graph engine.
|
||||
|
||||
This package manages the worker pool, including creation,
|
||||
scaling, and activity tracking.
|
||||
"""
|
||||
|
||||
from .worker_pool import WorkerPool
|
||||
|
||||
__all__ = [
|
||||
"WorkerPool",
|
||||
]
|
||||
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Simple worker pool that consolidates functionality.
|
||||
|
||||
This is a simpler implementation that merges WorkerPool, ActivityTracker,
|
||||
DynamicScaler, and WorkerFactory into a single class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
|
||||
from ..ready_queue import ReadyQueue
|
||||
from ..worker import Worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from contextvars import Context
|
||||
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@final
|
||||
class WorkerPool:
|
||||
"""
|
||||
Simple worker pool with integrated management.
|
||||
|
||||
This class consolidates all worker management functionality into
|
||||
a single, simpler implementation without excessive abstraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ready_queue: ReadyQueue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
flask_app: "Flask | None" = None,
|
||||
context_vars: "Context | None" = None,
|
||||
min_workers: int | None = None,
|
||||
max_workers: int | None = None,
|
||||
scale_up_threshold: int | None = None,
|
||||
scale_down_idle_time: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the simple worker pool.
|
||||
|
||||
Args:
|
||||
ready_queue: Ready queue for nodes ready for execution
|
||||
event_queue: Queue for worker events
|
||||
graph: The workflow graph
|
||||
flask_app: Optional Flask app for context preservation
|
||||
context_vars: Optional context variables
|
||||
min_workers: Minimum number of workers
|
||||
max_workers: Maximum number of workers
|
||||
scale_up_threshold: Queue depth to trigger scale up
|
||||
scale_down_idle_time: Seconds before scaling down idle workers
|
||||
"""
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
|
||||
# Scaling parameters with defaults
|
||||
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
|
||||
self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
|
||||
self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
|
||||
self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
|
||||
|
||||
# Worker management
|
||||
self._workers: list[Worker] = []
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
def start(self, initial_count: int | None = None) -> None:
|
||||
"""
|
||||
Start the worker pool.
|
||||
|
||||
Args:
|
||||
initial_count: Number of workers to start with (auto-calculated if None)
|
||||
"""
|
||||
with self._lock:
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Calculate initial worker count
|
||||
if initial_count is None:
|
||||
node_count = len(self._graph.nodes)
|
||||
if node_count < 10:
|
||||
initial_count = self._min_workers
|
||||
elif node_count < 50:
|
||||
initial_count = min(self._min_workers + 1, self._max_workers)
|
||||
else:
|
||||
initial_count = min(self._min_workers + 2, self._max_workers)
|
||||
|
||||
logger.debug(
|
||||
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
|
||||
initial_count,
|
||||
node_count,
|
||||
self._min_workers,
|
||||
self._max_workers,
|
||||
)
|
||||
|
||||
# Create initial workers
|
||||
for _ in range(initial_count):
|
||||
self._create_worker()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop all workers in the pool."""
|
||||
with self._lock:
|
||||
self._running = False
|
||||
worker_count = len(self._workers)
|
||||
|
||||
if worker_count > 0:
|
||||
logger.debug("Stopping worker pool: %d workers", worker_count)
|
||||
|
||||
# Stop all workers
|
||||
for worker in self._workers:
|
||||
worker.stop()
|
||||
|
||||
# Wait for workers to finish
|
||||
for worker in self._workers:
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=10.0)
|
||||
|
||||
self._workers.clear()
|
||||
|
||||
def _create_worker(self) -> None:
|
||||
"""Create and start a new worker."""
|
||||
worker_id = self._worker_counter
|
||||
self._worker_counter += 1
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
worker_id=worker_id,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
self._workers.append(worker)
|
||||
|
||||
def _remove_worker(self, worker: Worker, worker_id: int) -> None:
|
||||
"""Remove a specific worker from the pool."""
|
||||
# Stop the worker
|
||||
worker.stop()
|
||||
|
||||
# Wait for it to finish
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=2.0)
|
||||
|
||||
# Remove from list
|
||||
if worker in self._workers:
|
||||
self._workers.remove(worker)
|
||||
|
||||
def _try_scale_up(self, queue_depth: int, current_count: int) -> bool:
|
||||
"""
|
||||
Try to scale up workers if needed.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth
|
||||
current_count: Current number of workers
|
||||
|
||||
Returns:
|
||||
True if scaled up, False otherwise
|
||||
"""
|
||||
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
|
||||
old_count = current_count
|
||||
self._create_worker()
|
||||
|
||||
logger.debug(
|
||||
"Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)",
|
||||
old_count,
|
||||
len(self._workers),
|
||||
queue_depth,
|
||||
self._scale_up_threshold,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool:
|
||||
"""
|
||||
Try to scale down workers if we have excess capacity.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth
|
||||
current_count: Current number of workers
|
||||
active_count: Number of active workers
|
||||
idle_count: Number of idle workers
|
||||
|
||||
Returns:
|
||||
True if scaled down, False otherwise
|
||||
"""
|
||||
# Skip if we're at minimum or have no idle workers
|
||||
if current_count <= self._min_workers or idle_count == 0:
|
||||
return False
|
||||
|
||||
# Check if we have excess capacity
|
||||
has_excess_capacity = (
|
||||
queue_depth <= active_count # Active workers can handle current queue
|
||||
or idle_count > active_count # More idle than active workers
|
||||
or (queue_depth == 0 and idle_count > 0) # No work and have idle workers
|
||||
)
|
||||
|
||||
if not has_excess_capacity:
|
||||
return False
|
||||
|
||||
# Find and remove idle workers that have been idle long enough
|
||||
workers_to_remove: list[tuple[Worker, int]] = []
|
||||
|
||||
for worker in self._workers:
|
||||
# Check if worker is idle and has exceeded idle time threshold
|
||||
if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time:
|
||||
# Don't remove if it would leave us unable to handle the queue
|
||||
remaining_workers = current_count - len(workers_to_remove) - 1
|
||||
if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2):
|
||||
workers_to_remove.append((worker, worker.worker_id))
|
||||
# Only remove one worker per check to avoid aggressive scaling
|
||||
break
|
||||
|
||||
# Remove idle workers if any found
|
||||
if workers_to_remove:
|
||||
old_count = current_count
|
||||
for worker, worker_id in workers_to_remove:
|
||||
self._remove_worker(worker, worker_id)
|
||||
|
||||
logger.debug(
|
||||
"Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, "
|
||||
"queue_depth=%d, active=%d, idle=%d)",
|
||||
old_count,
|
||||
len(self._workers),
|
||||
len(workers_to_remove),
|
||||
self._scale_down_idle_time,
|
||||
queue_depth,
|
||||
active_count,
|
||||
idle_count - len(workers_to_remove),
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def check_and_scale(self) -> None:
|
||||
"""Check and perform scaling if needed."""
|
||||
with self._lock:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
current_count = len(self._workers)
|
||||
queue_depth = self._ready_queue.qsize()
|
||||
|
||||
# Count active vs idle workers by querying their state directly
|
||||
idle_count = sum(1 for worker in self._workers if worker.is_idle)
|
||||
active_count = current_count - idle_count
|
||||
|
||||
# Try to scale up if queue is backing up
|
||||
self._try_scale_up(queue_depth, current_count)
|
||||
|
||||
# Try to scale down if we have excess capacity
|
||||
self._try_scale_down(queue_depth, current_count, active_count, idle_count)
|
||||
|
||||
def get_worker_count(self) -> int:
|
||||
"""Get current number of workers."""
|
||||
with self._lock:
|
||||
return len(self._workers)
|
||||
|
||||
def get_status(self) -> dict[str, int]:
|
||||
"""
|
||||
Get pool status information.
|
||||
|
||||
Returns:
|
||||
Dictionary with status information
|
||||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
"total_workers": len(self._workers),
|
||||
"queue_depth": self._ready_queue.qsize(),
|
||||
"min_workers": self._min_workers,
|
||||
"max_workers": self._max_workers,
|
||||
}
|
||||
76
dify/api/core/workflow/graph_events/__init__.py
Normal file
76
dify/api/core/workflow/graph_events/__init__.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Agent events
|
||||
from .agent import NodeRunAgentLogEvent
|
||||
|
||||
# Base events
|
||||
from .base import (
|
||||
BaseGraphEvent,
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
)
|
||||
|
||||
# Graph events
|
||||
from .graph import (
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
|
||||
# Iteration events
|
||||
from .iteration import (
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
)
|
||||
|
||||
# Loop events
|
||||
from .loop import (
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
)
|
||||
|
||||
# Node events
|
||||
from .node import (
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseGraphEvent",
|
||||
"GraphEngineEvent",
|
||||
"GraphNodeEventBase",
|
||||
"GraphRunAbortedEvent",
|
||||
"GraphRunFailedEvent",
|
||||
"GraphRunPartialSucceededEvent",
|
||||
"GraphRunPausedEvent",
|
||||
"GraphRunStartedEvent",
|
||||
"GraphRunSucceededEvent",
|
||||
"NodeRunAgentLogEvent",
|
||||
"NodeRunExceptionEvent",
|
||||
"NodeRunFailedEvent",
|
||||
"NodeRunIterationFailedEvent",
|
||||
"NodeRunIterationNextEvent",
|
||||
"NodeRunIterationStartedEvent",
|
||||
"NodeRunIterationSucceededEvent",
|
||||
"NodeRunLoopFailedEvent",
|
||||
"NodeRunLoopNextEvent",
|
||||
"NodeRunLoopStartedEvent",
|
||||
"NodeRunLoopSucceededEvent",
|
||||
"NodeRunPauseRequestedEvent",
|
||||
"NodeRunRetrieverResourceEvent",
|
||||
"NodeRunRetryEvent",
|
||||
"NodeRunStartedEvent",
|
||||
"NodeRunStreamChunkEvent",
|
||||
"NodeRunSucceededEvent",
|
||||
]
|
||||
17
dify/api/core/workflow/graph_events/agent.py
Normal file
17
dify/api/core/workflow/graph_events/agent.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import GraphAgentNodeEventBase
|
||||
|
||||
|
||||
class NodeRunAgentLogEvent(GraphAgentNodeEventBase):
|
||||
message_id: str = Field(..., description="message id")
|
||||
label: str = Field(..., description="label")
|
||||
node_execution_id: str = Field(..., description="node execution id")
|
||||
parent_id: str | None = Field(..., description="parent id")
|
||||
error: str | None = Field(..., description="error")
|
||||
status: str = Field(..., description="status")
|
||||
data: Mapping[str, Any] = Field(..., description="data")
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
31
dify/api/core/workflow/graph_events/base.py
Normal file
31
dify/api/core/workflow/graph_events/base.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
class GraphEngineEvent(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class BaseGraphEvent(GraphEngineEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphNodeEventBase(GraphEngineEvent):
|
||||
id: str = Field(..., description="node execution id")
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
# The version of the node, or "1" if not specified.
|
||||
node_version: str = "1"
|
||||
node_run_result: NodeRunResult = Field(default_factory=NodeRunResult)
|
||||
|
||||
|
||||
class GraphAgentNodeEventBase(GraphNodeEventBase):
|
||||
pass
|
||||
53
dify/api/core/workflow/graph_events/graph.py
Normal file
53
dify/api/core/workflow/graph_events/graph.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from pydantic import Field
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.graph_events import BaseGraphEvent
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
"""Event emitted when a run completes successfully with final outputs."""
|
||||
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Final workflow outputs keyed by output selector.",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
exceptions_count: int = Field(description="exception count", default=0)
|
||||
|
||||
|
||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
||||
"""Event emitted when a run finishes with partial success and failures."""
|
||||
|
||||
exceptions_count: int = Field(..., description="exception count")
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs that were materialised before failures occurred.",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunAbortedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is aborted by user command."""
|
||||
|
||||
reason: str | None = Field(default=None, description="reason for abort")
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs produced before the abort was requested.",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunPausedEvent(BaseGraphEvent):
|
||||
"""Event emitted when a graph run is paused by user command."""
|
||||
|
||||
# reason: str | None = Field(default=None, description="reason for pause")
|
||||
reason: PauseReason = Field(..., description="reason for pause")
|
||||
outputs: dict[str, object] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs available to the client while the run is paused.",
|
||||
)
|
||||
40
dify/api/core/workflow/graph_events/iteration.py
Normal file
40
dify/api/core/workflow/graph_events/iteration.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
|
||||
|
||||
class NodeRunIterationStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class NodeRunIterationNextEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
index: int = Field(..., description="index")
|
||||
pre_iteration_output: Any = None
|
||||
|
||||
|
||||
class NodeRunIterationSucceededEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class NodeRunIterationFailedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
40
dify/api/core/workflow/graph_events/loop.py
Normal file
40
dify/api/core/workflow/graph_events/loop.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
|
||||
|
||||
class NodeRunLoopStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class NodeRunLoopNextEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
index: int = Field(..., description="index")
|
||||
pre_loop_output: Any = None
|
||||
|
||||
|
||||
class NodeRunLoopSucceededEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class NodeRunLoopFailedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
58
dify/api/core/workflow/graph_events/node.py
Normal file
58
dify/api/core/workflow/graph_events/node.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
|
||||
|
||||
class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
predecessor_node_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode
|
||||
provider_type: str = ""
|
||||
provider_id: str = ""
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(GraphNodeEventBase):
|
||||
# Spec-compliant fields
|
||||
selector: Sequence[str] = Field(
|
||||
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
|
||||
)
|
||||
chunk: str = Field(..., description="the actual chunk content")
|
||||
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(GraphNodeEventBase):
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
|
||||
|
||||
class NodeRunFailedEvent(GraphNodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
|
||||
|
||||
class NodeRunExceptionEvent(GraphNodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
|
||||
|
||||
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
||||
reason: PauseReason = Field(..., description="pause reason")
|
||||
42
dify/api/core/workflow/node_events/__init__.py
Normal file
42
dify/api/core/workflow/node_events/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from .agent import AgentLogEvent
|
||||
from .base import NodeEventBase, NodeRunResult
|
||||
from .iteration import (
|
||||
IterationFailedEvent,
|
||||
IterationNextEvent,
|
||||
IterationStartedEvent,
|
||||
IterationSucceededEvent,
|
||||
)
|
||||
from .loop import (
|
||||
LoopFailedEvent,
|
||||
LoopNextEvent,
|
||||
LoopStartedEvent,
|
||||
LoopSucceededEvent,
|
||||
)
|
||||
from .node import (
|
||||
ModelInvokeCompletedEvent,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunRetryEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentLogEvent",
|
||||
"IterationFailedEvent",
|
||||
"IterationNextEvent",
|
||||
"IterationStartedEvent",
|
||||
"IterationSucceededEvent",
|
||||
"LoopFailedEvent",
|
||||
"LoopNextEvent",
|
||||
"LoopStartedEvent",
|
||||
"LoopSucceededEvent",
|
||||
"ModelInvokeCompletedEvent",
|
||||
"NodeEventBase",
|
||||
"NodeRunResult",
|
||||
"PauseRequestedEvent",
|
||||
"RunRetrieverResourceEvent",
|
||||
"RunRetryEvent",
|
||||
"StreamChunkEvent",
|
||||
"StreamCompletedEvent",
|
||||
]
|
||||
18
dify/api/core/workflow/node_events/agent.py
Normal file
18
dify/api/core/workflow/node_events/agent.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import NodeEventBase
|
||||
|
||||
|
||||
class AgentLogEvent(NodeEventBase):
|
||||
message_id: str = Field(..., description="id")
|
||||
label: str = Field(..., description="label")
|
||||
node_execution_id: str = Field(..., description="node execution id")
|
||||
parent_id: str | None = Field(..., description="parent id")
|
||||
error: str | None = Field(..., description="error")
|
||||
status: str = Field(..., description="status")
|
||||
data: Mapping[str, Any] = Field(..., description="data")
|
||||
metadata: Mapping[str, Any] = Field(default_factory=dict, description="metadata")
|
||||
node_id: str = Field(..., description="node id")
|
||||
40
dify/api/core/workflow/node_events/base.py
Normal file
40
dify/api/core/workflow/node_events/base.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeEventBase(BaseModel):
|
||||
"""Base class for all node events"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _default_metadata():
|
||||
v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
return v
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
"""
|
||||
Node Run Result.
|
||||
"""
|
||||
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.PENDING
|
||||
|
||||
inputs: Mapping[str, Any] = Field(default_factory=dict)
|
||||
process_data: Mapping[str, Any] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, Any] = Field(default_factory=dict)
|
||||
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata)
|
||||
llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
|
||||
|
||||
edge_source_handle: str = "source" # source handle id of node with multiple branches
|
||||
|
||||
error: str = ""
|
||||
error_type: str = ""
|
||||
|
||||
# single step node run retry
|
||||
retry_index: int = 0
|
||||
36
dify/api/core/workflow/node_events/iteration.py
Normal file
36
dify/api/core/workflow/node_events/iteration.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import NodeEventBase
|
||||
|
||||
|
||||
class IterationStartedEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class IterationNextEvent(NodeEventBase):
|
||||
index: int = Field(..., description="index")
|
||||
pre_iteration_output: Any = None
|
||||
|
||||
|
||||
class IterationSucceededEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class IterationFailedEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
36
dify/api/core/workflow/node_events/loop.py
Normal file
36
dify/api/core/workflow/node_events/loop.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import NodeEventBase
|
||||
|
||||
|
||||
class LoopStartedEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
|
||||
class LoopNextEvent(NodeEventBase):
|
||||
index: int = Field(..., description="index")
|
||||
pre_loop_output: Any = None
|
||||
|
||||
|
||||
class LoopSucceededEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class LoopFailedEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
47
dify/api/core/workflow/node_events/node.py
Normal file
47
dify/api/core/workflow/node_events/node.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
from .base import NodeEventBase
|
||||
|
||||
|
||||
class RunRetrieverResourceEvent(NodeEventBase):
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
class ModelInvokeCompletedEvent(NodeEventBase):
|
||||
text: str
|
||||
usage: LLMUsage
|
||||
finish_reason: str | None = None
|
||||
reasoning_content: str | None = None
|
||||
structured_output: dict | None = None
|
||||
|
||||
|
||||
class RunRetryEvent(NodeEventBase):
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="Retry attempt number")
|
||||
start_at: datetime = Field(..., description="Retry start time")
|
||||
|
||||
|
||||
class StreamChunkEvent(NodeEventBase):
|
||||
# Spec-compliant fields
|
||||
selector: Sequence[str] = Field(
|
||||
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
|
||||
)
|
||||
chunk: str = Field(..., description="the actual chunk content")
|
||||
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
|
||||
|
||||
|
||||
class StreamCompletedEvent(NodeEventBase):
|
||||
node_run_result: NodeRunResult = Field(..., description="run result")
|
||||
|
||||
|
||||
class PauseRequestedEvent(NodeEventBase):
|
||||
reason: PauseReason = Field(..., description="pause reason")
|
||||
3
dify/api/core/workflow/nodes/__init__.py
Normal file
3
dify/api/core/workflow/nodes/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from core.workflow.enums import NodeType
|
||||
|
||||
__all__ = ["NodeType"]
|
||||
3
dify/api/core/workflow/nodes/agent/__init__.py
Normal file
3
dify/api/core/workflow/nodes/agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .agent_node import AgentNode
|
||||
|
||||
__all__ = ["AgentNode"]
|
||||
756
dify/api/core/workflow/nodes/agent/agent_node.py
Normal file
756
dify/api/core/workflow/nodes/agent/agent_node.py
Normal file
@@ -0,0 +1,756 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayFileSegment, StringSegment
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models import ToolFile
|
||||
from models.model import Conversation
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exc import (
|
||||
AgentInputTypeError,
|
||||
AgentInvocationError,
|
||||
AgentMessageTransformError,
|
||||
AgentNodeError,
|
||||
AgentVariableNotFoundError,
|
||||
AgentVariableTypeError,
|
||||
ToolFileNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
|
||||
class AgentNode(Node):
|
||||
"""
|
||||
Agent Node
|
||||
"""
|
||||
|
||||
node_type = NodeType.AGENT
|
||||
_node_data: AgentNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AgentNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=self.tenant_id,
|
||||
agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self._node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
agent_parameters = strategy.get_parameters()
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
strategy=strategy,
|
||||
)
|
||||
parameters_for_log = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
credentials = self._generate_credentials(parameters=parameters)
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": self._node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_agent_parameters(
|
||||
self,
|
||||
*,
|
||||
agent_parameters: Sequence[AgentStrategyParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: "PluginAgentStrategy",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
agent_parameters (Sequence[AgentParameter]): The list of agent parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (AgentNodeData): The data associated with the agent node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
if agent_input.type == "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
elif agent_input.type in {"mixed", "constant"}:
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
else:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use the node_data.version field for judgment
|
||||
# But for backward compatibility with historical data
|
||||
# this version field judgment is still preserved here.
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if tool_runtime_params.name in manual_input_params
|
||||
else tool_runtime_params.form
|
||||
)
|
||||
manual_input_value = {}
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance, model_schema = self._fetch_model(value)
|
||||
# memory config
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
memory = self._fetch_memory(model_instance)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
]
|
||||
value["history_prompt_messages"] = history_prompt_messages
|
||||
if model_schema:
|
||||
# remove structured output feature to support old version agent plugin
|
||||
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
|
||||
value["entity"] = model_schema.model_dump(mode="json")
|
||||
else:
|
||||
value["entity"] = None
|
||||
result[parameter_name] = value
|
||||
|
||||
return result
|
||||
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> "InvokeCredentials":
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
credentials = InvokeCredentials()
|
||||
|
||||
# generate credentials for tools selector
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
if tool.get("credential_id"):
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
except ValidationError:
|
||||
continue
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AgentNodeData.model_validate(node_data)
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
if input.type in ["mixed", "constant"]:
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def agent_strategy_icon(self) -> str | None:
|
||||
"""
|
||||
Get agent strategy icon
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(self.tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
return memory
|
||||
|
||||
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM, model=model_name
|
||||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_instance, model_schema
|
||||
|
||||
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(
|
||||
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
:param tool: tool
|
||||
:return: filtered tool dict
|
||||
"""
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
else:
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise AgentVariableTypeError(
|
||||
"When 'stream' is True, 'variable_value' must be a string.",
|
||||
variable_name=variable_name,
|
||||
expected_type="str",
|
||||
actual_type=type(variable_value).__name__,
|
||||
)
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
# Validate that meta contains a 'file' key
|
||||
if "file" not in message.meta:
|
||||
raise AgentNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
# Validate that the file is an instance of File
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
# check if the agent log is already in the list
|
||||
for log in agent_logs:
|
||||
if log.message_id == agent_log.message_id:
|
||||
# update the log
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
45
dify/api/core/workflow/nodes/agent/entities.py
Normal file
45
dify/api/core/workflow/nodes/agent/entities.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
agent_strategy_provider_name: str # redundancy
|
||||
agent_strategy_name: str
|
||||
agent_strategy_label: str # redundancy
|
||||
memory: MemoryConfig | None = None
|
||||
# The version of the tool parameter.
|
||||
# If this value is None, it indicates this is a previous version
|
||||
# and requires using the legacy parameter parsing rules.
|
||||
tool_node_version: str | None = None
|
||||
|
||||
class AgentInput(BaseModel):
|
||||
value: Union[list[str], list[ToolSelector], Any]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
class ParamsAutoGenerated(IntEnum):
|
||||
CLOSE = 0
|
||||
OPEN = 1
|
||||
|
||||
|
||||
class AgentOldVersionModelFeatures(StrEnum):
|
||||
"""
|
||||
Enum class for old SDK version llm feature.
|
||||
"""
|
||||
|
||||
TOOL_CALL = "tool-call"
|
||||
MULTI_TOOL_CALL = "multi-tool-call"
|
||||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = auto()
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = auto()
|
||||
VIDEO = auto()
|
||||
AUDIO = auto()
|
||||
121
dify/api/core/workflow/nodes/agent/exc.py
Normal file
121
dify/api/core/workflow/nodes/agent/exc.py
Normal file
@@ -0,0 +1,121 @@
|
||||
class AgentNodeError(Exception):
|
||||
"""Base exception for all agent node errors."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class AgentStrategyError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the agent strategy."""
|
||||
|
||||
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
|
||||
self.strategy_name = strategy_name
|
||||
self.provider_name = provider_name
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentStrategyNotFoundError(AgentStrategyError):
|
||||
"""Exception raised when the specified agent strategy is not found."""
|
||||
|
||||
def __init__(self, strategy_name: str, provider_name: str | None = None):
|
||||
super().__init__(
|
||||
f"Agent strategy '{strategy_name}' not found"
|
||||
+ (f" for provider '{provider_name}'" if provider_name else ""),
|
||||
strategy_name,
|
||||
provider_name,
|
||||
)
|
||||
|
||||
|
||||
class AgentInvocationError(AgentNodeError):
|
||||
"""Exception raised when there's an error invoking the agent."""
|
||||
|
||||
def __init__(self, message: str, original_error: Exception | None = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentParameterError(AgentNodeError):
|
||||
"""Exception raised when there's an error with agent parameters."""
|
||||
|
||||
def __init__(self, message: str, parameter_name: str | None = None):
|
||||
self.parameter_name = parameter_name
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentVariableError(AgentNodeError):
|
||||
"""Exception raised when there's an error with variables in the agent node."""
|
||||
|
||||
def __init__(self, message: str, variable_name: str | None = None):
|
||||
self.variable_name = variable_name
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentVariableNotFoundError(AgentVariableError):
|
||||
"""Exception raised when a variable is not found in the variable pool."""
|
||||
|
||||
def __init__(self, variable_name: str):
|
||||
super().__init__(f"Variable '{variable_name}' does not exist", variable_name)
|
||||
|
||||
|
||||
class AgentInputTypeError(AgentNodeError):
|
||||
"""Exception raised when an unknown agent input type is encountered."""
|
||||
|
||||
def __init__(self, input_type: str):
|
||||
super().__init__(f"Unknown agent input type '{input_type}'")
|
||||
|
||||
|
||||
class ToolFileError(AgentNodeError):
|
||||
"""Exception raised when there's an error with a tool file."""
|
||||
|
||||
def __init__(self, message: str, file_id: str | None = None):
|
||||
self.file_id = file_id
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ToolFileNotFoundError(ToolFileError):
|
||||
"""Exception raised when a tool file is not found."""
|
||||
|
||||
def __init__(self, file_id: str):
|
||||
super().__init__(f"Tool file '{file_id}' does not exist", file_id)
|
||||
|
||||
|
||||
class AgentMessageTransformError(AgentNodeError):
|
||||
"""Exception raised when there's an error transforming agent messages."""
|
||||
|
||||
def __init__(self, message: str, original_error: Exception | None = None):
|
||||
self.original_error = original_error
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentModelError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the model used by the agent."""
|
||||
|
||||
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentMemoryError(AgentNodeError):
|
||||
"""Exception raised when there's an error with the agent's memory."""
|
||||
|
||||
def __init__(self, message: str, conversation_id: str | None = None):
|
||||
self.conversation_id = conversation_id
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentVariableTypeError(AgentNodeError):
|
||||
"""Exception raised when a variable has an unexpected type."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
variable_name: str | None = None,
|
||||
expected_type: str | None = None,
|
||||
actual_type: str | None = None,
|
||||
):
|
||||
self.variable_name = variable_name
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
super().__init__(message)
|
||||
0
dify/api/core/workflow/nodes/answer/__init__.py
Normal file
0
dify/api/core/workflow/nodes/answer/__init__.py
Normal file
96
dify/api/core/workflow/nodes/answer/answer_node.py
Normal file
96
dify/api/core/workflow/nodes/answer/answer_node.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.variables import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.answer.entities import AnswerNodeData
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class AnswerNode(Node):
|
||||
node_type = NodeType.ANSWER
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
_node_data: AnswerNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = AnswerNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
|
||||
files = self._extract_files_from_segments(segments.value)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)},
|
||||
)
|
||||
|
||||
def _extract_files_from_segments(self, segments: Sequence[Segment]):
|
||||
"""Extract all files from segments containing FileSegment or ArrayFileSegment instances.
|
||||
|
||||
FileSegment contains a single file, while ArrayFileSegment contains multiple files.
|
||||
This method flattens all files into a single list.
|
||||
"""
|
||||
files = []
|
||||
for segment in segments:
|
||||
if isinstance(segment, FileSegment):
|
||||
# Single file - wrap in list for consistency
|
||||
files.append(segment.value)
|
||||
elif isinstance(segment, ArrayFileSegment):
|
||||
# Multiple files - extend the list
|
||||
files.extend(segment.value)
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AnswerNodeData.model_validate(node_data)
|
||||
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def get_streaming_template(self) -> Template:
|
||||
"""
|
||||
Get the template for streaming.
|
||||
|
||||
Returns:
|
||||
Template instance for this Answer node
|
||||
"""
|
||||
return Template.from_answer_template(self._node_data.answer)
|
||||
65
dify/api/core/workflow/nodes/answer/entities.py
Normal file
65
dify/api/core/workflow/nodes/answer/entities.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class AnswerNodeData(BaseNodeData):
|
||||
"""
|
||||
Answer Node Data.
|
||||
"""
|
||||
|
||||
answer: str = Field(..., description="answer template string")
|
||||
|
||||
|
||||
class GenerateRouteChunk(BaseModel):
|
||||
"""
|
||||
Generate Route Chunk.
|
||||
"""
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
VAR = auto()
|
||||
TEXT = auto()
|
||||
|
||||
type: ChunkType = Field(..., description="generate route chunk type")
|
||||
|
||||
|
||||
class VarGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Var Generate Route Chunk.
|
||||
"""
|
||||
|
||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
|
||||
"""generate route chunk type"""
|
||||
value_selector: Sequence[str] = Field(..., description="value selector")
|
||||
|
||||
|
||||
class TextGenerateRouteChunk(GenerateRouteChunk):
|
||||
"""
|
||||
Text Generate Route Chunk.
|
||||
"""
|
||||
|
||||
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
|
||||
"""generate route chunk type"""
|
||||
text: str = Field(..., description="text")
|
||||
|
||||
|
||||
class AnswerNodeDoubleLink(BaseModel):
|
||||
node_id: str = Field(..., description="node id")
|
||||
source_node_ids: list[str] = Field(..., description="source node ids")
|
||||
target_node_ids: list[str] = Field(..., description="target node ids")
|
||||
|
||||
|
||||
class AnswerStreamGenerateRoute(BaseModel):
|
||||
"""
|
||||
AnswerStreamGenerateRoute entity
|
||||
"""
|
||||
|
||||
answer_dependencies: dict[str, list[str]] = Field(
|
||||
..., description="answer dependencies (answer node id -> dependent answer node ids)"
|
||||
)
|
||||
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
|
||||
..., description="answer generate route (answer node id -> generate route chunks)"
|
||||
)
|
||||
11
dify/api/core/workflow/nodes/base/__init__.py
Normal file
11
dify/api/core/workflow/nodes/base/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||
|
||||
__all__ = [
|
||||
"BaseIterationNodeData",
|
||||
"BaseIterationState",
|
||||
"BaseLoopNodeData",
|
||||
"BaseLoopState",
|
||||
"BaseNodeData",
|
||||
"LLMUsageTrackingMixin",
|
||||
]
|
||||
171
dify/api/core/workflow/nodes/base/entities.py
Normal file
171
dify/api/core/workflow/nodes/base/entities.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.workflow.enums import ErrorStrategy
|
||||
|
||||
from .exc import DefaultValueTypeError
|
||||
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
"""
|
||||
Variable Selector.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_selector: Sequence[str]
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> "DefaultValue":
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
iteration_node_id: str
|
||||
index: int
|
||||
inputs: dict
|
||||
|
||||
class MetaData(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
||||
|
||||
|
||||
class BaseLoopNodeData(BaseNodeData):
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
class BaseLoopState(BaseModel):
|
||||
loop_node_id: str
|
||||
index: int
|
||||
inputs: dict
|
||||
|
||||
class MetaData(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
||||
10
dify/api/core/workflow/nodes/base/exc.py
Normal file
10
dify/api/core/workflow/nodes/base/exc.py
Normal file
@@ -0,0 +1,10 @@
|
||||
class BaseNodeError(ValueError):
|
||||
"""Base class for node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DefaultValueTypeError(BaseNodeError):
|
||||
"""Raised when the default value type is invalid."""
|
||||
|
||||
pass
|
||||
538
dify/api/core/workflow/nodes/base/node.py
Normal file
538
dify/api/core/workflow/nodes/base/node.py
Normal file
@@ -0,0 +1,538 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from typing import Any, ClassVar
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
IterationFailedEvent,
|
||||
IterationNextEvent,
|
||||
IterationStartedEvent,
|
||||
IterationSucceededEvent,
|
||||
LoopFailedEvent,
|
||||
LoopNextEvent,
|
||||
LoopStartedEvent,
|
||||
LoopSucceededEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Node:
|
||||
node_type: ClassVar["NodeType"]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
self.workflow_id = graph_init_params.workflow_id
|
||||
self.graph_config = graph_init_params.graph_config
|
||||
self.user_id = graph_init_params.user_id
|
||||
self.user_from = UserFrom(graph_init_params.user_from)
|
||||
self.invoke_from = InvokeFrom(graph_init_params.invoke_from)
|
||||
self.workflow_call_depth = graph_init_params.call_depth
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
@abstractmethod
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run node
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
# Generate a single node execution ID to use for all events
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
# Create and push start event with required fields
|
||||
start_event = NodeRunStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.title,
|
||||
in_iteration_id=None,
|
||||
start_at=self._start_at,
|
||||
)
|
||||
|
||||
# === FIXME(-LAN-): Needs to refactor.
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
|
||||
if isinstance(self, ToolNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
if isinstance(self, DatasourceNode):
|
||||
plugin_id = getattr(self.get_base_node_data(), "plugin_id", "")
|
||||
provider_name = getattr(self.get_base_node_data(), "provider_name", "")
|
||||
|
||||
start_event.provider_id = f"{plugin_id}/{provider_name}"
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
|
||||
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
|
||||
if isinstance(self, TriggerEventNode):
|
||||
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
|
||||
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData
|
||||
|
||||
if isinstance(self, AgentNode):
|
||||
start_event.agent_strategy = AgentNodeStrategyInit(
|
||||
name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
|
||||
icon=self.agent_strategy_icon,
|
||||
)
|
||||
|
||||
# ===
|
||||
yield start_event
|
||||
|
||||
try:
|
||||
result = self._run()
|
||||
|
||||
# Handle NodeRunResult
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield self._convert_node_run_result_to_graph_node_event(result)
|
||||
return
|
||||
|
||||
# Handle event stream
|
||||
for event in result:
|
||||
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
|
||||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
yield self._dispatch(event)
|
||||
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
event.id = self._node_execution_id
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
error_type="WorkflowNodeError",
|
||||
)
|
||||
yield NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""Extracts references variable selectors from node configuration.
|
||||
|
||||
The `config` parameter represents the configuration for a specific node type and corresponds
|
||||
to the `data` field in the node definition object.
|
||||
|
||||
The returned mapping has the following structure:
|
||||
|
||||
{'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}
|
||||
|
||||
For loop and iteration nodes, the mapping may look like this:
|
||||
|
||||
{
|
||||
"1748332301644.input_selector": ["1748332363630", "result"],
|
||||
"1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"],
|
||||
}
|
||||
|
||||
where `1748332301644` is the ID of the loop / iteration node,
|
||||
and `1748332325079` is the ID of the node inside the loop or iteration node.
|
||||
|
||||
Here, the key consists of two parts: the current node ID (provided as the `node_id`
|
||||
parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
|
||||
enclosed in `#` symbols. These two parts are separated by a dot (`.`).
|
||||
|
||||
The value is a list of string representing the variable selector, where the first element is the node ID
|
||||
of the referenced variable, and the second element is the variable name within that node.
|
||||
|
||||
The meaning of the above response is:
|
||||
|
||||
The node with ID `1747829548239` references the variable `result` from the node with
|
||||
ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a
|
||||
reference to the `result` output variable of node `1747829667553`.
|
||||
|
||||
:param graph_config: graph config
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
# Pass raw dict data instead of creating NodeData instance
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
|
||||
)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {}
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this node blocks the output of specific variables.
|
||||
|
||||
This method is used to determine if a node must complete execution before
|
||||
the specified variables can be used in streaming output.
|
||||
|
||||
:param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str'))
|
||||
:return: True if this node blocks output of any of the specified variables, False otherwise
|
||||
"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def version(cls) -> str:
|
||||
"""`node_version` returns the version of current node type."""
|
||||
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
|
||||
#
|
||||
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
|
||||
# in `api/core/workflow/nodes/__init__.py`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return False
|
||||
|
||||
# Abstract methods that subclasses must implement to provide access
|
||||
# to BaseNodeData properties in a type-safe way
|
||||
|
||||
@abstractmethod
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_title(self) -> str:
|
||||
"""Get the node title."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
"""Get the BaseNodeData object for this node."""
|
||||
...
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
def error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
return self._get_error_strategy()
|
||||
|
||||
@property
|
||||
def retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
return self._get_retry_config()
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
"""Get the node title."""
|
||||
return self._get_title()
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
return self._get_description()
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
return self._get_default_value_dict()
|
||||
|
||||
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
||||
match result.status:
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
error=result.error,
|
||||
)
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=result,
|
||||
)
|
||||
case _:
|
||||
raise Exception(f"result status {result.status} not supported")
|
||||
|
||||
@singledispatchmethod
|
||||
def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase:
|
||||
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
match event.node_run_result.status:
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=event.node_run_result,
|
||||
)
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=event.node_run_result,
|
||||
error=event.node_run_result.error,
|
||||
)
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f"Node {self._node_id} does not support status {event.node_run_result.status}"
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
|
||||
return NodeRunPauseRequestedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
|
||||
reason=event.reason,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
||||
return NodeRunAgentLogEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
message_id=event.message_id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
return NodeRunLoopStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
|
||||
return NodeRunLoopNextEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.pre_loop_output,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
|
||||
return NodeRunLoopSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
|
||||
return NodeRunLoopFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
|
||||
return NodeRunIterationStartedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
|
||||
return NodeRunIterationNextEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.pre_iteration_output,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
|
||||
return NodeRunIterationSucceededEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
|
||||
return NodeRunIterationFailedEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=self.get_base_node_data().title,
|
||||
start_at=event.start_at,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||
return NodeRunRetrieverResourceEvent(
|
||||
id=self._node_execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
retriever_resources=event.retriever_resources,
|
||||
context=event.context,
|
||||
node_version=self.version(),
|
||||
)
|
||||
148
dify/api/core/workflow/nodes/base/template.py
Normal file
148
dify/api/core/workflow/nodes/base/template.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Template structures for Response nodes (Answer and End).
|
||||
|
||||
This module provides a unified template structure for both Answer and End nodes,
|
||||
similar to SegmentGroup but focused on template representation without values.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union
|
||||
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TemplateSegment(ABC):
|
||||
"""Base class for template segments."""
|
||||
|
||||
@abstractmethod
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the segment."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextSegment(TemplateSegment):
|
||||
"""A text segment in a template."""
|
||||
|
||||
text: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.text
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VariableSegment(TemplateSegment):
|
||||
"""A variable reference segment in a template."""
|
||||
|
||||
selector: Sequence[str]
|
||||
variable_name: str | None = None # Optional variable name for End nodes
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{{#" + ".".join(self.selector) + "#}}"
|
||||
|
||||
|
||||
# Type alias for segments
|
||||
TemplateSegmentUnion = Union[TextSegment, VariableSegment]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Template:
|
||||
"""Unified template structure for Response nodes.
|
||||
|
||||
Similar to SegmentGroup, but represents the template structure
|
||||
without variable values - only marking variable selectors.
|
||||
"""
|
||||
|
||||
segments: list[TemplateSegmentUnion]
|
||||
|
||||
@classmethod
|
||||
def from_answer_template(cls, template_str: str) -> "Template":
|
||||
"""Create a Template from an Answer node template string.
|
||||
|
||||
Example:
|
||||
"Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])]
|
||||
|
||||
Args:
|
||||
template_str: The answer template string
|
||||
|
||||
Returns:
|
||||
Template instance
|
||||
"""
|
||||
parser = VariableTemplateParser(template_str)
|
||||
segments: list[TemplateSegmentUnion] = []
|
||||
|
||||
# Extract variable selectors to find all variables
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
var_map = {var.variable: var.value_selector for var in variable_selectors}
|
||||
|
||||
# Parse template to get ordered segments
|
||||
# We need to split the template by variable placeholders while preserving order
|
||||
import re
|
||||
|
||||
# Create a regex pattern that matches variable placeholders
|
||||
pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}"
|
||||
|
||||
# Split template while keeping the delimiters (variable placeholders)
|
||||
parts = re.split(pattern, template_str)
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
if not part:
|
||||
continue
|
||||
|
||||
# Check if this part is a variable reference (odd indices after split)
|
||||
if i % 2 == 1: # Odd indices are variable keys
|
||||
# Remove the # symbols from the variable key
|
||||
var_key = part
|
||||
if var_key in var_map:
|
||||
segments.append(VariableSegment(selector=list(var_map[var_key])))
|
||||
else:
|
||||
# This shouldn't happen with valid templates
|
||||
segments.append(TextSegment(text="{{" + part + "}}"))
|
||||
else:
|
||||
# Even indices are text segments
|
||||
segments.append(TextSegment(text=part))
|
||||
|
||||
return cls(segments=segments)
|
||||
|
||||
@classmethod
|
||||
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
|
||||
"""Create a Template from an End node outputs configuration.
|
||||
|
||||
End nodes are treated as templates of concatenated variables with newlines.
|
||||
|
||||
Example:
|
||||
[{"variable": "text", "value_selector": ["node1", "text"]},
|
||||
{"variable": "result", "value_selector": ["node2", "result"]}]
|
||||
->
|
||||
[VariableSegment(["node1", "text"]),
|
||||
TextSegment("\n"),
|
||||
VariableSegment(["node2", "result"])]
|
||||
|
||||
Args:
|
||||
outputs_config: List of output configurations with variable and value_selector
|
||||
|
||||
Returns:
|
||||
Template instance
|
||||
"""
|
||||
segments: list[TemplateSegmentUnion] = []
|
||||
|
||||
for i, output in enumerate(outputs_config):
|
||||
if i > 0:
|
||||
# Add newline separator between variables
|
||||
segments.append(TextSegment(text="\n"))
|
||||
|
||||
value_selector = output.get("value_selector", [])
|
||||
variable_name = output.get("variable", "")
|
||||
if value_selector:
|
||||
segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name))
|
||||
|
||||
if len(segments) > 0 and isinstance(segments[-1], TextSegment):
|
||||
segments = segments[:-1]
|
||||
|
||||
return cls(segments=segments)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the template."""
|
||||
return "".join(str(segment) for segment in self.segments)
|
||||
28
dify/api/core/workflow/nodes/base/usage_tracking_mixin.py
Normal file
28
dify/api/core/workflow/nodes/base/usage_tracking_mixin.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class LLMUsageTrackingMixin:
|
||||
"""Provides shared helpers for merging and recording LLM usage within workflow nodes."""
|
||||
|
||||
graph_runtime_state: GraphRuntimeState
|
||||
|
||||
@staticmethod
|
||||
def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage:
|
||||
"""Return a combined usage snapshot, preserving zero-value inputs."""
|
||||
if new_usage is None or new_usage.total_tokens <= 0:
|
||||
return current
|
||||
if current.total_tokens == 0:
|
||||
return new_usage
|
||||
return current.plus(new_usage)
|
||||
|
||||
def _accumulate_usage(self, usage: LLMUsage) -> None:
|
||||
"""Push usage into the graph runtime accumulator for downstream reporting."""
|
||||
if usage.total_tokens <= 0:
|
||||
return
|
||||
|
||||
current_usage = self.graph_runtime_state.llm_usage
|
||||
if current_usage.total_tokens == 0:
|
||||
self.graph_runtime_state.llm_usage = usage.model_copy()
|
||||
else:
|
||||
self.graph_runtime_state.llm_usage = current_usage.plus(usage)
|
||||
130
dify/api/core/workflow/nodes/base/variable_template_parser.py
Normal file
130
dify/api/core/workflow/nodes/base/variable_template_parser.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from .entities import VariableSelector
|
||||
|
||||
REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
|
||||
|
||||
SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
|
||||
|
||||
|
||||
def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]:
|
||||
parts = SELECTOR_PATTERN.split(template)
|
||||
selectors = []
|
||||
for part in filter(lambda x: x, parts):
|
||||
if "." in part and part[0] == "#" and part[-1] == "#":
|
||||
selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split(".")))
|
||||
return selectors
|
||||
|
||||
|
||||
class VariableTemplateParser:
|
||||
"""
|
||||
!NOTE: Consider to use the new `segments` module instead of this class.
|
||||
|
||||
A class for parsing and manipulating template variables in a string.
|
||||
|
||||
Rules:
|
||||
|
||||
1. Template variables must be enclosed in `{{}}`.
|
||||
2. The template variable Key can only be: #node_id.var1.var2#.
|
||||
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
|
||||
|
||||
Example usage:
|
||||
|
||||
template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}."
|
||||
parser = VariableTemplateParser(template)
|
||||
|
||||
# Extract template variable keys
|
||||
variable_keys = parser.extract()
|
||||
print(variable_keys)
|
||||
# Output: ['#node_id.query.name#', '#node_id.query.age#']
|
||||
|
||||
# Extract variable selectors
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
print(variable_selectors)
|
||||
# Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']),
|
||||
# VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])]
|
||||
|
||||
# Format the template string
|
||||
inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}}
|
||||
formatted_string = parser.format(inputs)
|
||||
print(formatted_string)
|
||||
# Output: "Hello, John! Your age is 25."
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.template = template
|
||||
self.variable_keys = self.extract()
|
||||
|
||||
def extract(self):
|
||||
"""
|
||||
Extracts all the template variable keys from the template string.
|
||||
|
||||
Returns:
|
||||
A list of template variable keys.
|
||||
"""
|
||||
# Regular expression to match the template rules
|
||||
matches = re.findall(REGEX, self.template)
|
||||
|
||||
first_group_matches = [match[0] for match in matches]
|
||||
|
||||
return list(set(first_group_matches))
|
||||
|
||||
def extract_variable_selectors(self) -> list[VariableSelector]:
|
||||
"""
|
||||
Extracts the variable selectors from the template variable keys.
|
||||
|
||||
Returns:
|
||||
A list of VariableSelector objects representing the variable selectors.
|
||||
"""
|
||||
variable_selectors = []
|
||||
for variable_key in self.variable_keys:
|
||||
remove_hash = variable_key.replace("#", "")
|
||||
split_result = remove_hash.split(".")
|
||||
if len(split_result) < 2:
|
||||
continue
|
||||
|
||||
variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result))
|
||||
|
||||
return variable_selectors
|
||||
|
||||
def format(self, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Formats the template string by replacing the template variables with their corresponding values.
|
||||
|
||||
Args:
|
||||
inputs: A dictionary containing the values for the template variables.
|
||||
|
||||
Returns:
|
||||
The formatted string with template variables replaced by their values.
|
||||
"""
|
||||
|
||||
def replacer(match):
|
||||
key = match.group(1)
|
||||
value = inputs.get(key, match.group(0)) # return original matched string if key not found
|
||||
|
||||
if value is None:
|
||||
value = ""
|
||||
# convert the value to string
|
||||
if isinstance(value, list | dict | bool | int | float):
|
||||
value = str(value)
|
||||
|
||||
# remove template variables if required
|
||||
return VariableTemplateParser.remove_template_variables(value)
|
||||
|
||||
prompt = re.sub(REGEX, replacer, self.template)
|
||||
return re.sub(r"<\|.*?\|>", "", prompt)
|
||||
|
||||
@classmethod
|
||||
def remove_template_variables(cls, text: str):
|
||||
"""
|
||||
Removes the template variables from the given text.
|
||||
|
||||
Args:
|
||||
text: The text from which to remove the template variables.
|
||||
|
||||
Returns:
|
||||
The text with template variables removed.
|
||||
"""
|
||||
return re.sub(REGEX, r"{\1}", text)
|
||||
3
dify/api/core/workflow/nodes/code/__init__.py
Normal file
3
dify/api/core/workflow/nodes/code/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .code_node import CodeNode
|
||||
|
||||
__all__ = ["CodeNode"]
|
||||
444
dify/api/core/workflow/nodes/code/code_node.py
Normal file
444
dify/api/core/workflow/nodes/code/code_node.py
Normal file
@@ -0,0 +1,444 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
DepthLimitError,
|
||||
OutputValidationError,
|
||||
)
|
||||
|
||||
|
||||
class CodeNode(Node):
|
||||
node_type = NodeType.CODE
|
||||
|
||||
_node_data: CodeNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
self._node_data = CodeNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
code_language = CodeLanguage.PYTHON3
|
||||
if filters:
|
||||
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get code language
|
||||
code_language = self._node_data.code_language
|
||||
code = self._node_data.code
|
||||
|
||||
# Get variables
|
||||
variables = {}
|
||||
for variable_selector in self._node_data.variables:
|
||||
variable_name = variable_selector.variable
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None
|
||||
else:
|
||||
variables[variable_name] = variable.to_object() if variable else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=code_language,
|
||||
code=code,
|
||||
inputs=variables,
|
||||
)
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
Check string
|
||||
:param value: value
|
||||
:param variable: variable
|
||||
:return:
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{variable}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
|
||||
)
|
||||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
return value
|
||||
|
||||
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
|
||||
"""
|
||||
Check number
|
||||
:param value: value
|
||||
:param variable: variable
|
||||
:return:
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` is out of range,"
|
||||
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
|
||||
)
|
||||
|
||||
if isinstance(value, float):
|
||||
decimal_value = Decimal(str(value)).normalize()
|
||||
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
|
||||
# raise error if precision is too high
|
||||
if precision > dify_config.CODE_MAX_PRECISION:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` has too high precision,"
|
||||
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
def _transform_result(
|
||||
self,
|
||||
result: Mapping[str, Any],
|
||||
output_schema: dict[str, CodeNodeData.Output] | None,
|
||||
prefix: str = "",
|
||||
depth: int = 1,
|
||||
):
|
||||
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
|
||||
# Note that `_transform_result` may produce lists containing `None` values,
|
||||
# which don't conform to the type requirements of `Array*Segment` classes.
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
transformed_result: dict[str, Any] = {}
|
||||
if output_schema is None:
|
||||
# validate output thought instance type
|
||||
for output_name, output_value in result.items():
|
||||
if isinstance(output_value, dict):
|
||||
self._transform_result(
|
||||
result=output_value,
|
||||
output_schema=None,
|
||||
prefix=f"{prefix}.{output_name}" if prefix else output_name,
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif isinstance(output_value, bool):
|
||||
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
|
||||
elif isinstance(output_value, int | float):
|
||||
self._check_number(
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
)
|
||||
elif isinstance(output_value, str):
|
||||
self._check_string(
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
)
|
||||
elif isinstance(output_value, list):
|
||||
first_element = output_value[0] if len(output_value) > 0 else None
|
||||
if first_element is not None:
|
||||
if isinstance(first_element, int | float) and all(
|
||||
value is None or isinstance(value, int | float) for value in output_value
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
self._check_number(
|
||||
value=value,
|
||||
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
)
|
||||
elif isinstance(first_element, str) and all(
|
||||
value is None or isinstance(value, str) for value in output_value
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
self._check_string(
|
||||
value=value,
|
||||
variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
)
|
||||
elif (
|
||||
isinstance(first_element, dict)
|
||||
and all(value is None or isinstance(value, dict) for value in output_value)
|
||||
or isinstance(first_element, list)
|
||||
and all(value is None or isinstance(value, list) for value in output_value)
|
||||
):
|
||||
for i, value in enumerate(output_value):
|
||||
if value is not None:
|
||||
self._transform_result(
|
||||
result=value,
|
||||
output_schema=None,
|
||||
prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]",
|
||||
depth=depth + 1,
|
||||
)
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}.{output_name} is not a valid array."
|
||||
f" make sure all elements are of the same type."
|
||||
)
|
||||
elif output_value is None:
|
||||
pass
|
||||
else:
|
||||
raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.")
|
||||
|
||||
return result
|
||||
|
||||
parameters_validated = {}
|
||||
for output_name, output_config in output_schema.items():
|
||||
dot = "." if prefix else ""
|
||||
if output_name not in result:
|
||||
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
|
||||
if output_config.type == SegmentType.OBJECT:
|
||||
# check if output is object
|
||||
if not isinstance(result.get(output_name), dict):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an object,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
transformed_result[output_name] = self._transform_result(
|
||||
result=result[output_name],
|
||||
output_schema=output_config.children,
|
||||
prefix=f"{prefix}.{output_name}",
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif output_config.type == SegmentType.NUMBER:
|
||||
# check if number available
|
||||
value = result.get(output_name)
|
||||
if value is not None and not isinstance(value, (int, float)):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not a number,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}")
|
||||
# If the output is a boolean and the output schema specifies a NUMBER type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
transformed_result[output_name] = self._convert_boolean_to_int(checked)
|
||||
|
||||
elif output_config.type == SegmentType.STRING:
|
||||
# check if string available
|
||||
value = result.get(output_name)
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} must be a string, got {type(value).__name__} instead"
|
||||
)
|
||||
transformed_result[output_name] = self._check_string(
|
||||
value=value,
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.BOOLEAN:
|
||||
transformed_result[output_name] = self._check_boolean(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.ARRAY_NUMBER:
|
||||
# check if array of number available
|
||||
value = result[output_name]
|
||||
if not isinstance(value, list):
|
||||
if value is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
for i, inner_value in enumerate(value):
|
||||
if not isinstance(inner_value, (int, float)):
|
||||
raise OutputValidationError(
|
||||
f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" a number."
|
||||
)
|
||||
_ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
transformed_result[output_name] = [
|
||||
# If the element is a boolean and the output schema specifies a `array[number]` type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
self._convert_boolean_to_int(v)
|
||||
for v in value
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_STRING:
|
||||
# check if array of string available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_OBJECT:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
for i, value in enumerate(result[output_name]):
|
||||
if not isinstance(value, dict):
|
||||
if value is None:
|
||||
pass
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name}[{i}] is not an object,"
|
||||
f" got {type(value)} instead at index {i}."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
None
|
||||
if value is None
|
||||
else self._transform_result(
|
||||
result=value,
|
||||
output_schema=output_config.children,
|
||||
prefix=f"{prefix}{dot}{output_name}[{i}]",
|
||||
depth=depth + 1,
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
|
||||
# check if array of object available
|
||||
value = result[output_name]
|
||||
if not isinstance(value, list):
|
||||
if value is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
for i, inner_value in enumerate(value):
|
||||
if inner_value is not None and not isinstance(inner_value, bool):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name}[{i}] is not a boolean,"
|
||||
f" got {type(inner_value)} instead."
|
||||
)
|
||||
_ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
transformed_result[output_name] = value
|
||||
|
||||
else:
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
parameters_validated[output_name] = True
|
||||
|
||||
# check if all output parameters are validated
|
||||
if len(parameters_validated) != len(result):
|
||||
raise CodeNodeError("Not all output parameters are validated.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = CodeNodeData.model_validate(node_data)
|
||||
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in typed_node_data.variables
|
||||
}
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
||||
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
|
||||
|
||||
This ensures compatibility with existing workflows that may use
|
||||
`True` and `False` as values for NUMBER type outputs.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
return value
|
||||
47
dify/api/core/workflow/nodes/code/entities.py
Normal file
47
dify/api/core/workflow/nodes/code/entities.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Annotated, Literal, Self
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(segment_type: SegmentType) -> SegmentType:
|
||||
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
|
||||
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
|
||||
return segment_type
|
||||
|
||||
|
||||
class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, Self] | None = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
name: str
|
||||
version: str
|
||||
|
||||
variables: list[VariableSelector]
|
||||
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
|
||||
code: str
|
||||
outputs: dict[str, Output]
|
||||
dependencies: list[Dependency] | None = None
|
||||
16
dify/api/core/workflow/nodes/code/exc.py
Normal file
16
dify/api/core/workflow/nodes/code/exc.py
Normal file
@@ -0,0 +1,16 @@
|
||||
class CodeNodeError(ValueError):
|
||||
"""Base class for code node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OutputValidationError(CodeNodeError):
|
||||
"""Raised when there is an output validation error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DepthLimitError(CodeNodeError):
|
||||
"""Raised when the depth limit is reached."""
|
||||
|
||||
pass
|
||||
3
dify/api/core/workflow/nodes/datasource/__init__.py
Normal file
3
dify/api/core/workflow/nodes/datasource/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .datasource_node import DatasourceNode
|
||||
|
||||
__all__ = ["DatasourceNode"]
|
||||
502
dify/api/core/workflow/nodes/datasource/datasource_node.py
Normal file
502
dify/api/core/workflow/nodes/datasource/datasource_node.py
Normal file
@@ -0,0 +1,502 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from core.file import File
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.tool.exc import ToolFileError
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from .entities import DatasourceNodeData
|
||||
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||
|
||||
|
||||
class DatasourceNode(Node):
|
||||
"""
|
||||
Datasource Node
|
||||
"""
|
||||
|
||||
_node_data: DatasourceNodeData
|
||||
node_type = NodeType.DATASOURCE
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = DatasourceNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
"""
|
||||
|
||||
node_data = self._node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
if not datasource_type_segement:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
|
||||
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
|
||||
if not datasource_info_segement:
|
||||
raise DatasourceNodeError("Datasource info is not set")
|
||||
datasource_info_value = datasource_info_segement.value
|
||||
if not isinstance(datasource_info_value, dict):
|
||||
raise DatasourceNodeError("Invalid datasource info format")
|
||||
datasource_info: dict[str, Any] = datasource_info_value
|
||||
# get datasource runtime
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
|
||||
parameters_for_log = datasource_info
|
||||
|
||||
try:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=datasource_info.get("credential_id", ""),
|
||||
)
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_document_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||
workspace_id=datasource_info.get("workspace_id", ""),
|
||||
page_id=datasource_info.get("page", {}).get("page_id", ""),
|
||||
type=datasource_info.get("page", {}).get("type", ""),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
)
|
||||
yield from self._transform_message(
|
||||
messages=online_document_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
)
|
||||
case DatasourceProviderType.ONLINE_DRIVE:
|
||||
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.online_drive_download_file(
|
||||
user_id=self.user_id,
|
||||
request=OnlineDriveDownloadFileRequest(
|
||||
id=datasource_info.get("id", ""),
|
||||
bucket=datasource_info.get("bucket"),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
)
|
||||
yield from self._transform_datasource_file_message(
|
||||
messages=online_drive_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
variable_pool=variable_pool,
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
**datasource_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
related_id = datasource_info.get("related_id")
|
||||
if not related_id:
|
||||
raise DatasourceNodeError("File is not exist")
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
|
||||
if not upload_file:
|
||||
raise ValueError("Invalid upload file Info")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=upload_file.source_url,
|
||||
)
|
||||
variable_pool.add([self._node_id, "file"], file_info)
|
||||
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": file_info,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to transform datasource message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to invoke datasource: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
datasource_parameters: Sequence[DatasourceParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: DatasourceNodeData,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
parameter = datasource_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
datasource_input = node_data.datasource_parameters[parameter_name]
|
||||
if datasource_input.type == "variable":
|
||||
variable = variable_pool.get(datasource_input.value)
|
||||
if variable is None:
|
||||
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif datasource_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(datasource_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
typed_node_data = DatasourceNodeData.model_validate(node_data)
|
||||
result = {}
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
input = typed_node_data.datasource_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceMessage.MessageType.BINARY_LINK,
|
||||
DatasourceMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
# mark the end of the stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={**variables},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _transform_datasource_file_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
datasource_type: DatasourceProviderType,
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
file = None
|
||||
for message in message_stream:
|
||||
if message.type == DatasourceMessage.MessageType.BINARY_LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
if file:
|
||||
variable_pool.add([self._node_id, "file"], file)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": file,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
41
dify/api/core/workflow/nodes/datasource/entities.py
Normal file
41
dify/api/core/workflow/nodes/datasource/entities.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
plugin_id: str
|
||||
provider_name: str # redundancy
|
||||
provider_type: str
|
||||
datasource_name: str | None = "local_file"
|
||||
datasource_configurations: dict[str, Any] | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"] | None = None
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
datasource_parameters: dict[str, DatasourceInput] | None = None
|
||||
16
dify/api/core/workflow/nodes/datasource/exc.py
Normal file
16
dify/api/core/workflow/nodes/datasource/exc.py
Normal file
@@ -0,0 +1,16 @@
|
||||
class DatasourceNodeError(ValueError):
|
||||
"""Base exception for datasource node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceParameterError(DatasourceNodeError):
|
||||
"""Exception raised for errors in datasource parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceFileError(DatasourceNodeError):
|
||||
"""Exception raised for errors related to datasource files."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,4 @@
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .node import DocumentExtractorNode
|
||||
|
||||
__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"]
|
||||
@@ -0,0 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class DocumentExtractorNodeData(BaseNodeData):
|
||||
variable_selector: Sequence[str]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user