dify/api/core/workflow/runtime/__init__.py (Normal file, 14 lines)
@@ -0,0 +1,14 @@
from .graph_runtime_state import GraphRuntimeState
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
from .variable_pool import VariablePool, VariableValue

__all__ = [
    "GraphRuntimeState",
    "ReadOnlyGraphRuntimeState",
    "ReadOnlyGraphRuntimeStateWrapper",
    "ReadOnlyVariablePool",
    "ReadOnlyVariablePoolWrapper",
    "VariablePool",
    "VariableValue",
]
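The package re-exports its whole public surface, so callers import from core.workflow.runtime rather than from the submodules. A minimal sanity check of the exports above (assuming the api directory is on the import path, as it is in the Dify backend):

    from core.workflow.runtime import GraphRuntimeState, VariablePool

    pool = VariablePool.empty()
    state = GraphRuntimeState(variable_pool=pool, start_at=0.0)
    assert state.total_tokens == 0  # scalar state defaults to zero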
dify/api/core/workflow/runtime/graph_runtime_state.py (Normal file, 472 lines)
@@ -0,0 +1,472 @@
from __future__ import annotations

import importlib
import json
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol

from pydantic.json import pydantic_encoder

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime.variable_pool import VariablePool


class ReadyQueueProtocol(Protocol):
    """Structural interface required from ready queue implementations."""

    def put(self, item: str) -> None:
        """Enqueue the identifier of a node that is ready to run."""
        ...

    def get(self, timeout: float | None = None) -> str:
        """Return the next node identifier, blocking until available or timeout expires."""
        ...

    def task_done(self) -> None:
        """Signal that the most recently dequeued node has completed processing."""
        ...

    def empty(self) -> bool:
        """Return True when the queue contains no pending nodes."""
        ...

    def qsize(self) -> int:
        """Approximate the number of pending nodes awaiting execution."""
        ...

    def dumps(self) -> str:
        """Serialize the queue contents for persistence."""
        ...

    def loads(self, data: str) -> None:
        """Restore the queue contents from a serialized payload."""
        ...


class GraphExecutionProtocol(Protocol):
    """Structural interface for the graph execution aggregate."""

    workflow_id: str
    started: bool
    completed: bool
    aborted: bool
    error: Exception | None
    exceptions_count: int

    def start(self) -> None:
        """Transition execution into the running state."""
        ...

    def complete(self) -> None:
        """Mark execution as successfully completed."""
        ...

    def abort(self, reason: str) -> None:
        """Abort execution in response to an external stop request."""
        ...

    def fail(self, error: Exception) -> None:
        """Record an unrecoverable error and end execution."""
        ...

    def dumps(self) -> str:
        """Serialize execution state into a JSON payload."""
        ...

    def loads(self, data: str) -> None:
        """Restore execution state from a previously serialized payload."""
        ...


class ResponseStreamCoordinatorProtocol(Protocol):
    """Structural interface for the response stream coordinator."""

    def register(self, response_node_id: str) -> None:
        """Register a response node so its outputs can be streamed."""
        ...

    def loads(self, data: str) -> None:
        """Restore coordinator state from a serialized payload."""
        ...

    def dumps(self) -> str:
        """Serialize coordinator state for persistence."""
        ...


class GraphProtocol(Protocol):
    """Structural interface required from graph instances attached to the runtime state."""

    nodes: Mapping[str, object]
    edges: Mapping[str, object]
    root_node: object

    def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...


@dataclass(slots=True)
class _GraphRuntimeStateSnapshot:
    """Immutable view of a serialized runtime state snapshot."""

    start_at: float
    total_tokens: int
    node_run_steps: int
    llm_usage: LLMUsage
    outputs: dict[str, Any]
    variable_pool: VariablePool
    has_variable_pool: bool
    ready_queue_dump: str | None
    graph_execution_dump: str | None
    response_coordinator_dump: str | None
    paused_nodes: tuple[str, ...]


class GraphRuntimeState:
    """Mutable runtime state shared across graph execution components."""

    def __init__(
        self,
        *,
        variable_pool: VariablePool,
        start_at: float,
        total_tokens: int = 0,
        llm_usage: LLMUsage | None = None,
        outputs: dict[str, object] | None = None,
        node_run_steps: int = 0,
        ready_queue: ReadyQueueProtocol | None = None,
        graph_execution: GraphExecutionProtocol | None = None,
        response_coordinator: ResponseStreamCoordinatorProtocol | None = None,
        graph: GraphProtocol | None = None,
    ) -> None:
        self._variable_pool = variable_pool
        self._start_at = start_at

        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens

        self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy()
        self._outputs = deepcopy(outputs) if outputs is not None else {}

        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps

        self._graph: GraphProtocol | None = None

        self._ready_queue = ready_queue
        self._graph_execution = graph_execution
        self._response_coordinator = response_coordinator
        self._pending_response_coordinator_dump: str | None = None
        self._pending_graph_execution_workflow_id: str | None = None
        self._paused_nodes: set[str] = set()

        if graph is not None:
            self.attach_graph(graph)

    # ------------------------------------------------------------------
    # Context binding helpers
    # ------------------------------------------------------------------
    def attach_graph(self, graph: GraphProtocol) -> None:
        """Attach the materialized graph to the runtime state."""
        if self._graph is not None and self._graph is not graph:
            raise ValueError("GraphRuntimeState already attached to a different graph instance")

        self._graph = graph

        if self._response_coordinator is None:
            self._response_coordinator = self._build_response_coordinator(graph)

        if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
            self._response_coordinator.loads(self._pending_response_coordinator_dump)
            self._pending_response_coordinator_dump = None

    def configure(self, *, graph: GraphProtocol | None = None) -> None:
        """Ensure core collaborators are initialized with the provided context."""
        if graph is not None:
            self.attach_graph(graph)

        # Ensure collaborators are instantiated
        _ = self.ready_queue
        _ = self.graph_execution
        if self._graph is not None:
            _ = self.response_coordinator

    # ------------------------------------------------------------------
    # Primary collaborators
    # ------------------------------------------------------------------
    @property
    def variable_pool(self) -> VariablePool:
        return self._variable_pool

    @property
    def ready_queue(self) -> ReadyQueueProtocol:
        if self._ready_queue is None:
            self._ready_queue = self._build_ready_queue()
        return self._ready_queue

    @property
    def graph_execution(self) -> GraphExecutionProtocol:
        if self._graph_execution is None:
            self._graph_execution = self._build_graph_execution()
        return self._graph_execution

    @property
    def response_coordinator(self) -> ResponseStreamCoordinatorProtocol:
        if self._response_coordinator is None:
            if self._graph is None:
                raise ValueError("Graph must be attached before accessing response coordinator")
            self._response_coordinator = self._build_response_coordinator(self._graph)
        return self._response_coordinator

    # ------------------------------------------------------------------
    # Scalar state
    # ------------------------------------------------------------------
    @property
    def start_at(self) -> float:
        return self._start_at

    @start_at.setter
    def start_at(self, value: float) -> None:
        self._start_at = value

    @property
    def total_tokens(self) -> int:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: int) -> None:
        if value < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = value

    @property
    def llm_usage(self) -> LLMUsage:
        return self._llm_usage.model_copy()

    @llm_usage.setter
    def llm_usage(self, value: LLMUsage) -> None:
        self._llm_usage = value.model_copy()

    @property
    def outputs(self) -> dict[str, Any]:
        return deepcopy(self._outputs)

    @outputs.setter
    def outputs(self, value: dict[str, Any]) -> None:
        self._outputs = deepcopy(value)

    def set_output(self, key: str, value: object) -> None:
        self._outputs[key] = deepcopy(value)

    def get_output(self, key: str, default: object = None) -> object:
        return deepcopy(self._outputs.get(key, default))

    def update_outputs(self, updates: dict[str, object]) -> None:
        for key, value in updates.items():
            self._outputs[key] = deepcopy(value)

    @property
    def node_run_steps(self) -> int:
        return self._node_run_steps

    @node_run_steps.setter
    def node_run_steps(self, value: int) -> None:
        if value < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = value

    def increment_node_run_steps(self) -> None:
        self._node_run_steps += 1

    def add_tokens(self, tokens: int) -> None:
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        self._total_tokens += tokens

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------
    def dumps(self) -> str:
        """Serialize runtime state into a JSON string."""

        snapshot: dict[str, Any] = {
            "version": "1.0",
            "start_at": self._start_at,
            "total_tokens": self._total_tokens,
            "node_run_steps": self._node_run_steps,
            "llm_usage": self._llm_usage.model_dump(mode="json"),
            "outputs": self.outputs,
            "variable_pool": self.variable_pool.model_dump(mode="json"),
            "ready_queue": self.ready_queue.dumps(),
            "graph_execution": self.graph_execution.dumps(),
            "paused_nodes": list(self._paused_nodes),
        }

        if self._response_coordinator is not None and self._graph is not None:
            snapshot["response_coordinator"] = self._response_coordinator.dumps()

        return json.dumps(snapshot, default=pydantic_encoder)

    @classmethod
    def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState:
        """Restore runtime state from a serialized snapshot."""

        snapshot = cls._parse_snapshot_payload(data)

        state = cls(
            variable_pool=snapshot.variable_pool,
            start_at=snapshot.start_at,
            total_tokens=snapshot.total_tokens,
            llm_usage=snapshot.llm_usage,
            outputs=snapshot.outputs,
            node_run_steps=snapshot.node_run_steps,
        )
        state._apply_snapshot(snapshot)
        return state

    def loads(self, data: str | Mapping[str, Any]) -> None:
        """Restore runtime state from a serialized snapshot (legacy API)."""

        snapshot = self._parse_snapshot_payload(data)
        self._apply_snapshot(snapshot)

    def register_paused_node(self, node_id: str) -> None:
        """Record a node that should resume when execution is continued."""

        self._paused_nodes.add(node_id)

    def consume_paused_nodes(self) -> list[str]:
        """Retrieve and clear the list of paused nodes awaiting resume."""

        nodes = list(self._paused_nodes)
        self._paused_nodes.clear()
        return nodes

    # ------------------------------------------------------------------
    # Builders
    # ------------------------------------------------------------------
    def _build_ready_queue(self) -> ReadyQueueProtocol:
        # Import lazily to avoid breaching architecture boundaries enforced by import-linter.
        module = importlib.import_module("core.workflow.graph_engine.ready_queue")
        in_memory_cls = module.InMemoryReadyQueue
        return in_memory_cls()

    def _build_graph_execution(self) -> GraphExecutionProtocol:
        # Lazily import to keep the runtime domain decoupled from graph_engine modules.
        module = importlib.import_module("core.workflow.graph_engine.domain.graph_execution")
        graph_execution_cls = module.GraphExecution
        workflow_id = self._pending_graph_execution_workflow_id or ""
        self._pending_graph_execution_workflow_id = None
        return graph_execution_cls(workflow_id=workflow_id)

    def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
        # Lazily import to keep the runtime domain decoupled from graph_engine modules.
        module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
        coordinator_cls = module.ResponseStreamCoordinator
        return coordinator_cls(variable_pool=self.variable_pool, graph=graph)

    # ------------------------------------------------------------------
    # Snapshot helpers
    # ------------------------------------------------------------------
    @classmethod
    def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot:
        payload: dict[str, Any]
        if isinstance(data, str):
            payload = json.loads(data)
        else:
            payload = dict(data)

        version = payload.get("version")
        if version != "1.0":
            raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")

        start_at = float(payload.get("start_at", 0.0))

        total_tokens = int(payload.get("total_tokens", 0))
        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")

        node_run_steps = int(payload.get("node_run_steps", 0))
        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")

        llm_usage_payload = payload.get("llm_usage", {})
        llm_usage = LLMUsage.model_validate(llm_usage_payload)

        outputs_payload = deepcopy(payload.get("outputs", {}))

        variable_pool_payload = payload.get("variable_pool")
        has_variable_pool = variable_pool_payload is not None
        variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool()

        ready_queue_payload = payload.get("ready_queue")
        graph_execution_payload = payload.get("graph_execution")
        response_payload = payload.get("response_coordinator")
        paused_nodes_payload = payload.get("paused_nodes", [])

        return _GraphRuntimeStateSnapshot(
            start_at=start_at,
            total_tokens=total_tokens,
            node_run_steps=node_run_steps,
            llm_usage=llm_usage,
            outputs=outputs_payload,
            variable_pool=variable_pool,
            has_variable_pool=has_variable_pool,
            ready_queue_dump=ready_queue_payload,
            graph_execution_dump=graph_execution_payload,
            response_coordinator_dump=response_payload,
            paused_nodes=tuple(map(str, paused_nodes_payload)),
        )

    def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
        self._start_at = snapshot.start_at
        self._total_tokens = snapshot.total_tokens
        self._node_run_steps = snapshot.node_run_steps
        self._llm_usage = snapshot.llm_usage.model_copy()
        self._outputs = deepcopy(snapshot.outputs)
        if snapshot.has_variable_pool or self._variable_pool is None:
            self._variable_pool = snapshot.variable_pool

        self._restore_ready_queue(snapshot.ready_queue_dump)
        self._restore_graph_execution(snapshot.graph_execution_dump)
        self._restore_response_coordinator(snapshot.response_coordinator_dump)
        self._paused_nodes = set(snapshot.paused_nodes)

    def _restore_ready_queue(self, payload: str | None) -> None:
        if payload is not None:
            self._ready_queue = self._build_ready_queue()
            self._ready_queue.loads(payload)
        else:
            self._ready_queue = None

    def _restore_graph_execution(self, payload: str | None) -> None:
        self._graph_execution = None
        self._pending_graph_execution_workflow_id = None

        if payload is None:
            return

        try:
            execution_payload = json.loads(payload)
            self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
        except (json.JSONDecodeError, TypeError, AttributeError):
            self._pending_graph_execution_workflow_id = None

        self.graph_execution.loads(payload)

    def _restore_response_coordinator(self, payload: str | None) -> None:
        if payload is None:
            self._pending_response_coordinator_dump = None
            self._response_coordinator = None
            return

        if self._graph is not None:
            self.response_coordinator.loads(payload)
            self._pending_response_coordinator_dump = None
            return

        self._pending_response_coordinator_dump = payload
        self._response_coordinator = None
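A sketch of the snapshot round-trip implemented above. It assumes the lazy graph_engine imports in _build_ready_queue and _build_graph_execution resolve, which they do inside the Dify API process; the values are invented:

    import time

    from core.workflow.runtime import GraphRuntimeState, VariablePool

    state = GraphRuntimeState(variable_pool=VariablePool.empty(), start_at=time.perf_counter())
    state.add_tokens(128)
    state.set_output("answer", "hello")

    payload = state.dumps()  # JSON string carrying "version": "1.0"
    restored = GraphRuntimeState.from_snapshot(payload)
    assert restored.total_tokens == 128
    assert restored.get_output("answer") == "hello"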
dify/api/core/workflow/runtime/graph_runtime_state_protocol.py (Normal file, 83 lines)
@@ -0,0 +1,83 @@
from collections.abc import Mapping
from typing import Any, Protocol

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView


class ReadOnlyVariablePool(Protocol):
    """Read-only interface for VariablePool."""

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Get a variable value (read-only)."""
        ...

    def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
        """Get all variables for a node (read-only)."""
        ...

    def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
        """Get all variables stored under a given node prefix (read-only)."""
        ...


class ReadOnlyGraphRuntimeState(Protocol):
    """
    Read-only view of GraphRuntimeState for layers.

    This protocol defines a read-only interface that prevents layers from
    modifying the graph runtime state while still allowing observation.
    All methods return defensive copies to ensure immutability.
    """

    @property
    def system_variable(self) -> SystemVariableReadOnlyView: ...

    @property
    def variable_pool(self) -> ReadOnlyVariablePool:
        """Get read-only access to the variable pool."""
        ...

    @property
    def start_at(self) -> float:
        """Get the start time (read-only)."""
        ...

    @property
    def total_tokens(self) -> int:
        """Get the total tokens count (read-only)."""
        ...

    @property
    def llm_usage(self) -> LLMUsage:
        """Get a copy of LLM usage info (read-only)."""
        ...

    @property
    def outputs(self) -> dict[str, Any]:
        """Get a defensive copy of outputs (read-only)."""
        ...

    @property
    def node_run_steps(self) -> int:
        """Get the node run steps count (read-only)."""
        ...

    @property
    def ready_queue_size(self) -> int:
        """Get the number of nodes currently in the ready queue."""
        ...

    @property
    def exceptions_count(self) -> int:
        """Get the number of node execution exceptions recorded."""
        ...

    def get_output(self, key: str, default: Any = None) -> Any:
        """Get a single output value (returns a copy)."""
        ...

    def dumps(self) -> str:
        """Serialize the runtime state into a JSON snapshot (read-only)."""
        ...
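Because these are typing.Protocol classes, conformance is structural: no inheritance is needed, and the wrappers in the next file satisfy these interfaces without naming them. A hypothetical layer-side consumer (the function is illustrative, not part of this commit):

    def log_progress(state: ReadOnlyGraphRuntimeState) -> None:
        # Static duck typing: any structurally conforming object is accepted,
        # including ReadOnlyGraphRuntimeStateWrapper from read_only_wrappers.py.
        print(f"steps={state.node_run_steps} tokens={state.total_tokens} queued={state.ready_queue_size}")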
dify/api/core/workflow/runtime/read_only_wrappers.py (Normal file, 87 lines)
@@ -0,0 +1,87 @@
from __future__ import annotations

from collections.abc import Mapping
from copy import deepcopy
from typing import Any

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView

from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool


class ReadOnlyVariablePoolWrapper:
    """Provide defensive, read-only access to ``VariablePool``."""

    def __init__(self, variable_pool: VariablePool) -> None:
        self._variable_pool = variable_pool

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Return a copy of a variable value if present."""
        value = self._variable_pool.get([node_id, variable_key])
        return deepcopy(value) if value is not None else None

    def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
        """Return a copy of all variables for the specified node."""
        variables: dict[str, object] = {}
        if node_id in self._variable_pool.variable_dictionary:
            for key, variable in self._variable_pool.variable_dictionary[node_id].items():
                variables[key] = deepcopy(variable.value)
        return variables

    def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
        """Return a copy of all variables stored under the given prefix."""
        return self._variable_pool.get_by_prefix(prefix)


class ReadOnlyGraphRuntimeStateWrapper:
    """Expose a defensive, read-only view of ``GraphRuntimeState``."""

    def __init__(self, state: GraphRuntimeState) -> None:
        self._state = state
        self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)

    @property
    def system_variable(self) -> SystemVariableReadOnlyView:
        return self._state.variable_pool.system_variables.as_view()

    @property
    def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
        return self._variable_pool_wrapper

    @property
    def start_at(self) -> float:
        return self._state.start_at

    @property
    def total_tokens(self) -> int:
        return self._state.total_tokens

    @property
    def llm_usage(self) -> LLMUsage:
        return self._state.llm_usage.model_copy()

    @property
    def outputs(self) -> dict[str, Any]:
        return deepcopy(self._state.outputs)

    @property
    def node_run_steps(self) -> int:
        return self._state.node_run_steps

    @property
    def ready_queue_size(self) -> int:
        return self._state.ready_queue.qsize()

    @property
    def exceptions_count(self) -> int:
        return self._state.graph_execution.exceptions_count

    def get_output(self, key: str, default: Any = None) -> Any:
        return self._state.get_output(key, default)

    def dumps(self) -> str:
        """Serialize the underlying runtime state for external persistence."""
        return self._state.dumps()
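A sketch of how the wrapper enforces the read-only contract; same in-process assumptions as the earlier example, and the output key is made up:

    state = GraphRuntimeState(variable_pool=VariablePool.empty(), start_at=0.0)
    state.set_output("greeting", "hi")

    view = ReadOnlyGraphRuntimeStateWrapper(state)
    outputs = view.outputs           # defensive copy
    outputs["greeting"] = "mutated"  # mutates the copy only
    assert state.get_output("greeting") == "hi"

Note that view.ready_queue_size and view.exceptions_count still instantiate the underlying collaborators lazily, so even the read-only view can trigger the graph_engine imports.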
dify/api/core/workflow/runtime/variable_pool.py (Normal file, 272 lines)
@@ -0,0 +1,272 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Annotated, Any, Union, cast

from pydantic import BaseModel, Field

from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
from core.workflow.constants import (
    CONVERSATION_VARIABLE_NODE_ID,
    ENVIRONMENT_VARIABLE_NODE_ID,
    RAG_PIPELINE_VARIABLE_NODE_ID,
    SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.system_variable import SystemVariable
from factories import variable_factory

VariableValue = Union[str, int, float, dict[str, object], list[object], File]

VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")


class VariablePool(BaseModel):
    # The variable dictionary looks up variables by their selector.
    # The first element of the selector is the node id; it is the first-level key in the dictionary.
    # The second element of the selector is the variable name; it is the key in the second-level dictionary.
    variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
        description="Variables mapping",
        default=defaultdict(dict),
    )

    # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
    user_inputs: Mapping[str, Any] = Field(
        description="User inputs",
        default_factory=dict,
    )
    system_variables: SystemVariable = Field(
        description="System variables",
        default_factory=SystemVariable.empty,
    )
    environment_variables: Sequence[VariableUnion] = Field(
        description="Environment variables.",
        default_factory=list[VariableUnion],
    )
    conversation_variables: Sequence[VariableUnion] = Field(
        description="Conversation variables.",
        default_factory=list[VariableUnion],
    )
    rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
        description="RAG pipeline variables.",
        default_factory=list,
    )

    def model_post_init(self, context: Any, /):
        # Create a mapping from field names to SystemVariableKey enum values
        self._add_system_variables(self.system_variables)
        # Add environment variables to the variable pool
        for var in self.environment_variables:
            self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
        # Add conversation variables to the variable pool
        for var in self.conversation_variables:
            self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
        # Add rag pipeline variables to the variable pool
        if self.rag_pipeline_variables:
            rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
            for rag_var in self.rag_pipeline_variables:
                node_id = rag_var.variable.belong_to_node_id
                key = rag_var.variable.variable
                value = rag_var.value
                rag_pipeline_variables_map[node_id][key] = value
            for key, value in rag_pipeline_variables_map.items():
                self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)

    def add(self, selector: Sequence[str], value: Any, /):
        """
        Add a variable to the variable pool.

        This method accepts a selector path and a value, converting the value
        to a Variable object if necessary before storing it in the pool.

        Args:
            selector: A two-element sequence containing [node_id, variable_name].
                The selector must have exactly 2 elements to be valid.
            value: The value to store. Can be a Variable, Segment, or any value
                that can be converted to a Segment (str, int, float, dict, list, File).

        Raises:
            ValueError: If selector length is not exactly 2 elements.

        Note:
            While non-Segment values are currently accepted and automatically
            converted, it's recommended to pass Segment or Variable objects directly.
        """
        if len(selector) != SELECTORS_LENGTH:
            raise ValueError(
                f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
                f"got {len(selector)} elements"
            )

        if isinstance(value, Variable):
            variable = value
        elif isinstance(value, Segment):
            variable = variable_factory.segment_to_variable(segment=value, selector=selector)
        else:
            segment = variable_factory.build_segment(value)
            variable = variable_factory.segment_to_variable(segment=segment, selector=selector)

        node_id, name = self._selector_to_keys(selector)
        # Based on the definition of `VariableUnion`,
        # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
        self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)

    @classmethod
    def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
        return selector[0], selector[1]

    def _has(self, selector: Sequence[str]) -> bool:
        node_id, name = self._selector_to_keys(selector)
        if node_id not in self.variable_dictionary:
            return False
        if name not in self.variable_dictionary[node_id]:
            return False
        return True

    def get(self, selector: Sequence[str], /) -> Segment | None:
        """
        Retrieve a variable's value from the pool as a Segment.

        This method supports both simple selectors [node_id, variable_name] and
        extended selectors that include attribute access for FileSegment and
        ObjectSegment types.

        Args:
            selector: A sequence with at least 2 elements:
                - [node_id, variable_name]: Returns the full segment
                - [node_id, variable_name, attr, ...]: Returns a nested value
                  from FileSegment (e.g., 'url', 'name') or ObjectSegment

        Returns:
            The Segment associated with the selector, or None if not found.
            Returns None if selector has fewer than 2 elements.

        Raises:
            ValueError: If attempting to access an invalid FileAttribute.
        """
        if len(selector) < SELECTORS_LENGTH:
            return None

        node_id, name = self._selector_to_keys(selector)
        node_map = self.variable_dictionary.get(node_id)
        if node_map is None:
            return None

        segment: Segment | None = node_map.get(name)

        if segment is None:
            return None

        if len(selector) == 2:
            return segment

        if isinstance(segment, FileSegment):
            attr = selector[2]
            # `attr in FileAttribute` is only supported from Python 3.12, so compare against the values.
            if attr not in {item.value for item in FileAttribute}:
                return None
            attr = FileAttribute(attr)
            attr_value = file_manager.get_attr(file=segment.value, attr=attr)
            return variable_factory.build_segment(attr_value)

        # Navigate through nested attributes
        result: Any = segment
        for attr in selector[2:]:
            result = self._extract_value(result)
            result = self._get_nested_attribute(result, attr)
            if result is None:
                return None

        # Return result as Segment
        return result if isinstance(result, Segment) else variable_factory.build_segment(result)

    def _extract_value(self, obj: Any):
        """Extract the actual value from an ObjectSegment."""
        return obj.value if isinstance(obj, ObjectSegment) else obj

    def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None:
        """
        Get a nested attribute from a dictionary-like object.

        Args:
            obj: The dictionary-like object to search.
            attr: The key to look up.

        Returns:
            Segment | None:
                The corresponding Segment built from the attribute value if the key exists,
                otherwise None.
        """
        if not isinstance(obj, dict) or attr not in obj:
            return None
        return variable_factory.build_segment(obj.get(attr))

    def remove(self, selector: Sequence[str], /):
        """
        Remove variables from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): A sequence of strings representing the selector.

        Returns:
            None
        """
        if not selector:
            return
        if len(selector) == 1:
            self.variable_dictionary[selector[0]] = {}
            return
        key, hash_key = self._selector_to_keys(selector)
        self.variable_dictionary[key].pop(hash_key, None)

    def convert_template(self, template: str, /):
        parts = VARIABLE_PATTERN.split(template)
        segments: list[Segment] = []
        for part in filter(lambda x: x, parts):
            if "." in part and (variable := self.get(part.split("."))):
                segments.append(variable)
            else:
                segments.append(variable_factory.build_segment(part))
        return SegmentGroup(value=segments)

    def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
        segment = self.get(selector)
        if isinstance(segment, FileSegment):
            return segment
        return None

    def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
        """Return a copy of all variables stored under the given node prefix."""

        nodes = self.variable_dictionary.get(prefix)
        if not nodes:
            return {}

        result: dict[str, object] = {}
        for key, variable in nodes.items():
            value = variable.value
            result[key] = deepcopy(value)

        return result

    def _add_system_variables(self, system_variable: SystemVariable):
        sys_var_mapping = system_variable.to_dict()
        for key, value in sys_var_mapping.items():
            if value is None:
                continue
            selector = (SYSTEM_VARIABLE_NODE_ID, key)
            # If the system variable already exists, do not add it again.
            # This ensures that we can keep the id of the system variables intact.
            if self._has(selector):
                continue
            self.add(selector, value)

    @classmethod
    def empty(cls) -> "VariablePool":
        """Create an empty variable pool."""
        return cls(system_variables=SystemVariable.empty())
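A sketch of the selector-based API; the node id and values are invented, and the .text accessor is assumed from dify's Segment/SegmentGroup interface rather than shown in this commit:

    pool = VariablePool.empty()
    pool.add(("node_1", "answer"), "42")  # plain values are auto-converted to Variables

    segment = pool.get(("node_1", "answer"))
    assert segment is not None and segment.value == "42"

    # Templates reference pool entries as {{#node_id.variable_name#}}.
    group = pool.convert_template("The answer is {{#node_1.answer#}}.")
    print(group.text)  # expected: "The answer is 42."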