commit fab8c13cb3 (parent 32fee2b8ab)
2025-12-01 17:21:38 +08:00
7511 changed files with 996300 additions and 0 deletions


@@ -0,0 +1,14 @@
from .graph_runtime_state import GraphRuntimeState
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
from .variable_pool import VariablePool, VariableValue
__all__ = [
"GraphRuntimeState",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",
"ReadOnlyVariablePool",
"ReadOnlyVariablePoolWrapper",
"VariablePool",
"VariableValue",
]
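
A quick consumer-side sketch (the package path core.workflow.runtime is assumed from the absolute imports used elsewhere in this commit): everything above is re-exported here so callers never import the submodules directly.

    import time
    from core.workflow.runtime import (
        GraphRuntimeState,
        ReadOnlyGraphRuntimeStateWrapper,
        VariablePool,
    )

    state = GraphRuntimeState(variable_pool=VariablePool.empty(), start_at=time.perf_counter())
    read_only = ReadOnlyGraphRuntimeStateWrapper(state)  # safe to hand to observers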


@@ -0,0 +1,472 @@
from __future__ import annotations
import importlib
import json
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime.variable_pool import VariablePool
class ReadyQueueProtocol(Protocol):
"""Structural interface required from ready queue implementations."""
def put(self, item: str) -> None:
"""Enqueue the identifier of a node that is ready to run."""
...
def get(self, timeout: float | None = None) -> str:
"""Return the next node identifier, blocking until available or timeout expires."""
...
def task_done(self) -> None:
"""Signal that the most recently dequeued node has completed processing."""
...
def empty(self) -> bool:
"""Return True when the queue contains no pending nodes."""
...
def qsize(self) -> int:
"""Approximate the number of pending nodes awaiting execution."""
...
def dumps(self) -> str:
"""Serialize the queue contents for persistence."""
...
def loads(self, data: str) -> None:
"""Restore the queue contents from a serialized payload."""
...
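# Illustrative sketch, not part of this commit: ReadyQueueProtocol is structural,
# so any class providing these methods satisfies it without inheriting from it.
# The production implementation lives in core.workflow.graph_engine.ready_queue.
class _ExampleListReadyQueue:
    """Minimal list-backed queue that satisfies ReadyQueueProtocol (non-blocking)."""

    def __init__(self) -> None:
        self._items: list[str] = []

    def put(self, item: str) -> None:
        self._items.append(item)

    def get(self, timeout: float | None = None) -> str:
        # A real implementation would block until an item arrives or the timeout expires.
        return self._items.pop(0)

    def task_done(self) -> None:
        pass  # nothing to track in this sketch

    def empty(self) -> bool:
        return not self._items

    def qsize(self) -> int:
        return len(self._items)

    def dumps(self) -> str:
        return json.dumps(self._items)  # json is imported at the top of this module

    def loads(self, data: str) -> None:
        self._items = list(json.loads(data))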
class GraphExecutionProtocol(Protocol):
"""Structural interface for graph execution aggregate."""
workflow_id: str
started: bool
completed: bool
aborted: bool
error: Exception | None
exceptions_count: int
def start(self) -> None:
"""Transition execution into the running state."""
...
def complete(self) -> None:
"""Mark execution as successfully completed."""
...
def abort(self, reason: str) -> None:
"""Abort execution in response to an external stop request."""
...
def fail(self, error: Exception) -> None:
"""Record an unrecoverable error and end execution."""
...
def dumps(self) -> str:
"""Serialize execution state into a JSON payload."""
...
def loads(self, data: str) -> None:
"""Restore execution state from a previously serialized payload."""
...
class ResponseStreamCoordinatorProtocol(Protocol):
"""Structural interface for response stream coordinator."""
def register(self, response_node_id: str) -> None:
"""Register a response node so its outputs can be streamed."""
...
def loads(self, data: str) -> None:
"""Restore coordinator state from a serialized payload."""
...
def dumps(self) -> str:
"""Serialize coordinator state for persistence."""
...
class GraphProtocol(Protocol):
"""Structural interface required from graph instances attached to the runtime state."""
nodes: Mapping[str, object]
edges: Mapping[str, object]
root_node: object
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
@dataclass(slots=True)
class _GraphRuntimeStateSnapshot:
"""Immutable view of a serialized runtime state snapshot."""
start_at: float
total_tokens: int
node_run_steps: int
llm_usage: LLMUsage
outputs: dict[str, Any]
variable_pool: VariablePool
has_variable_pool: bool
ready_queue_dump: str | None
graph_execution_dump: str | None
response_coordinator_dump: str | None
paused_nodes: tuple[str, ...]
class GraphRuntimeState:
"""Mutable runtime state shared across graph execution components."""
def __init__(
self,
*,
variable_pool: VariablePool,
start_at: float,
total_tokens: int = 0,
llm_usage: LLMUsage | None = None,
outputs: dict[str, object] | None = None,
node_run_steps: int = 0,
ready_queue: ReadyQueueProtocol | None = None,
graph_execution: GraphExecutionProtocol | None = None,
response_coordinator: ResponseStreamCoordinatorProtocol | None = None,
graph: GraphProtocol | None = None,
) -> None:
self._variable_pool = variable_pool
self._start_at = start_at
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = total_tokens
self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy()
self._outputs = deepcopy(outputs) if outputs is not None else {}
if node_run_steps < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
self._graph: GraphProtocol | None = None
self._ready_queue = ready_queue
self._graph_execution = graph_execution
self._response_coordinator = response_coordinator
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
if graph is not None:
self.attach_graph(graph)
# ------------------------------------------------------------------
# Context binding helpers
# ------------------------------------------------------------------
def attach_graph(self, graph: GraphProtocol) -> None:
"""Attach the materialized graph to the runtime state."""
if self._graph is not None and self._graph is not graph:
raise ValueError("GraphRuntimeState already attached to a different graph instance")
self._graph = graph
if self._response_coordinator is None:
self._response_coordinator = self._build_response_coordinator(graph)
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
self._response_coordinator.loads(self._pending_response_coordinator_dump)
self._pending_response_coordinator_dump = None
def configure(self, *, graph: GraphProtocol | None = None) -> None:
"""Ensure core collaborators are initialized with the provided context."""
if graph is not None:
self.attach_graph(graph)
# Ensure collaborators are instantiated
_ = self.ready_queue
_ = self.graph_execution
if self._graph is not None:
_ = self.response_coordinator
# ------------------------------------------------------------------
# Primary collaborators
# ------------------------------------------------------------------
@property
def variable_pool(self) -> VariablePool:
return self._variable_pool
@property
def ready_queue(self) -> ReadyQueueProtocol:
if self._ready_queue is None:
self._ready_queue = self._build_ready_queue()
return self._ready_queue
@property
def graph_execution(self) -> GraphExecutionProtocol:
if self._graph_execution is None:
self._graph_execution = self._build_graph_execution()
return self._graph_execution
@property
def response_coordinator(self) -> ResponseStreamCoordinatorProtocol:
if self._response_coordinator is None:
if self._graph is None:
raise ValueError("Graph must be attached before accessing response coordinator")
self._response_coordinator = self._build_response_coordinator(self._graph)
return self._response_coordinator
# ------------------------------------------------------------------
# Scalar state
# ------------------------------------------------------------------
@property
def start_at(self) -> float:
return self._start_at
@start_at.setter
def start_at(self, value: float) -> None:
self._start_at = value
@property
def total_tokens(self) -> int:
return self._total_tokens
@total_tokens.setter
def total_tokens(self, value: int) -> None:
if value < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = value
@property
def llm_usage(self) -> LLMUsage:
return self._llm_usage.model_copy()
@llm_usage.setter
def llm_usage(self, value: LLMUsage) -> None:
self._llm_usage = value.model_copy()
@property
def outputs(self) -> dict[str, Any]:
return deepcopy(self._outputs)
@outputs.setter
def outputs(self, value: dict[str, Any]) -> None:
self._outputs = deepcopy(value)
def set_output(self, key: str, value: object) -> None:
self._outputs[key] = deepcopy(value)
def get_output(self, key: str, default: object = None) -> object:
return deepcopy(self._outputs.get(key, default))
def update_outputs(self, updates: dict[str, object]) -> None:
for key, value in updates.items():
self._outputs[key] = deepcopy(value)
@property
def node_run_steps(self) -> int:
return self._node_run_steps
@node_run_steps.setter
def node_run_steps(self, value: int) -> None:
if value < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = value
def increment_node_run_steps(self) -> None:
self._node_run_steps += 1
def add_tokens(self, tokens: int) -> None:
if tokens < 0:
raise ValueError("tokens must be non-negative")
self._total_tokens += tokens
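    # Illustrative sketch (not part of this commit): the copy-out/copy-in
    # discipline above means mutations on returned values never leak back in.
    #
    #     state.set_output("answer", {"text": "hi"})
    #     view = state.outputs                     # deep copy
    #     view["answer"]["text"] = "edited"        # mutates the copy only
    #     assert state.get_output("answer") == {"text": "hi"}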
# ------------------------------------------------------------------
# Serialization
# ------------------------------------------------------------------
def dumps(self) -> str:
"""Serialize runtime state into a JSON string."""
snapshot: dict[str, Any] = {
"version": "1.0",
"start_at": self._start_at,
"total_tokens": self._total_tokens,
"node_run_steps": self._node_run_steps,
"llm_usage": self._llm_usage.model_dump(mode="json"),
"outputs": self.outputs,
"variable_pool": self.variable_pool.model_dump(mode="json"),
"ready_queue": self.ready_queue.dumps(),
"graph_execution": self.graph_execution.dumps(),
"paused_nodes": list(self._paused_nodes),
}
if self._response_coordinator is not None and self._graph is not None:
snapshot["response_coordinator"] = self._response_coordinator.dumps()
return json.dumps(snapshot, default=pydantic_encoder)
@classmethod
def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState:
"""Restore runtime state from a serialized snapshot."""
snapshot = cls._parse_snapshot_payload(data)
state = cls(
variable_pool=snapshot.variable_pool,
start_at=snapshot.start_at,
total_tokens=snapshot.total_tokens,
llm_usage=snapshot.llm_usage,
outputs=snapshot.outputs,
node_run_steps=snapshot.node_run_steps,
)
state._apply_snapshot(snapshot)
return state
def loads(self, data: str | Mapping[str, Any]) -> None:
"""Restore runtime state from a serialized snapshot (legacy API)."""
snapshot = self._parse_snapshot_payload(data)
self._apply_snapshot(snapshot)
def register_paused_node(self, node_id: str) -> None:
"""Record a node that should resume when execution is continued."""
self._paused_nodes.add(node_id)
def consume_paused_nodes(self) -> list[str]:
"""Retrieve and clear the list of paused nodes awaiting resume."""
nodes = list(self._paused_nodes)
self._paused_nodes.clear()
return nodes
# ------------------------------------------------------------------
# Builders
# ------------------------------------------------------------------
def _build_ready_queue(self) -> ReadyQueueProtocol:
# Import lazily to avoid breaching architecture boundaries enforced by import-linter.
module = importlib.import_module("core.workflow.graph_engine.ready_queue")
in_memory_cls = module.InMemoryReadyQueue
return in_memory_cls()
def _build_graph_execution(self) -> GraphExecutionProtocol:
# Lazily import to keep the runtime domain decoupled from graph_engine modules.
module = importlib.import_module("core.workflow.graph_engine.domain.graph_execution")
graph_execution_cls = module.GraphExecution
workflow_id = self._pending_graph_execution_workflow_id or ""
self._pending_graph_execution_workflow_id = None
return graph_execution_cls(workflow_id=workflow_id)
def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
# Lazily import to keep the runtime domain decoupled from graph_engine modules.
module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
coordinator_cls = module.ResponseStreamCoordinator
return coordinator_cls(variable_pool=self.variable_pool, graph=graph)
# ------------------------------------------------------------------
# Snapshot helpers
# ------------------------------------------------------------------
@classmethod
def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot:
payload: dict[str, Any]
if isinstance(data, str):
payload = json.loads(data)
else:
payload = dict(data)
version = payload.get("version")
if version != "1.0":
raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")
start_at = float(payload.get("start_at", 0.0))
total_tokens = int(payload.get("total_tokens", 0))
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
node_run_steps = int(payload.get("node_run_steps", 0))
if node_run_steps < 0:
raise ValueError("node_run_steps must be non-negative")
llm_usage_payload = payload.get("llm_usage", {})
llm_usage = LLMUsage.model_validate(llm_usage_payload)
outputs_payload = deepcopy(payload.get("outputs", {}))
variable_pool_payload = payload.get("variable_pool")
has_variable_pool = variable_pool_payload is not None
variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool()
ready_queue_payload = payload.get("ready_queue")
graph_execution_payload = payload.get("graph_execution")
response_payload = payload.get("response_coordinator")
paused_nodes_payload = payload.get("paused_nodes", [])
return _GraphRuntimeStateSnapshot(
start_at=start_at,
total_tokens=total_tokens,
node_run_steps=node_run_steps,
llm_usage=llm_usage,
outputs=outputs_payload,
variable_pool=variable_pool,
has_variable_pool=has_variable_pool,
ready_queue_dump=ready_queue_payload,
graph_execution_dump=graph_execution_payload,
response_coordinator_dump=response_payload,
paused_nodes=tuple(map(str, paused_nodes_payload)),
)
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
self._start_at = snapshot.start_at
self._total_tokens = snapshot.total_tokens
self._node_run_steps = snapshot.node_run_steps
self._llm_usage = snapshot.llm_usage.model_copy()
self._outputs = deepcopy(snapshot.outputs)
if snapshot.has_variable_pool or self._variable_pool is None:
self._variable_pool = snapshot.variable_pool
self._restore_ready_queue(snapshot.ready_queue_dump)
self._restore_graph_execution(snapshot.graph_execution_dump)
self._restore_response_coordinator(snapshot.response_coordinator_dump)
self._paused_nodes = set(snapshot.paused_nodes)
def _restore_ready_queue(self, payload: str | None) -> None:
if payload is not None:
self._ready_queue = self._build_ready_queue()
self._ready_queue.loads(payload)
else:
self._ready_queue = None
def _restore_graph_execution(self, payload: str | None) -> None:
self._graph_execution = None
self._pending_graph_execution_workflow_id = None
if payload is None:
return
try:
execution_payload = json.loads(payload)
self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
except (json.JSONDecodeError, TypeError, AttributeError):
self._pending_graph_execution_workflow_id = None
self.graph_execution.loads(payload)
def _restore_response_coordinator(self, payload: str | None) -> None:
if payload is None:
self._pending_response_coordinator_dump = None
self._response_coordinator = None
return
if self._graph is not None:
self.response_coordinator.loads(payload)
self._pending_response_coordinator_dump = None
return
self._pending_response_coordinator_dump = payload
self._response_coordinator = None
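
A minimal persistence round-trip under the "1.0" snapshot schema above, as a sketch (the node id is invented; building the lazy collaborators assumes the graph_engine modules are importable):

    import time
    from core.workflow.runtime import GraphRuntimeState, VariablePool

    state = GraphRuntimeState(variable_pool=VariablePool.empty(), start_at=time.perf_counter())
    state.add_tokens(42)
    state.set_output("answer", "hello")
    state.register_paused_node("node_7")  # hypothetical node id

    payload = state.dumps()  # JSON string carrying "version": "1.0"
    restored = GraphRuntimeState.from_snapshot(payload)
    assert restored.total_tokens == 42
    assert restored.get_output("answer") == "hello"
    assert restored.consume_paused_nodes() == ["node_7"]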


@@ -0,0 +1,83 @@
from collections.abc import Mapping
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (read-only)."""
...
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (read-only)."""
...
def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
"""Get all variables stored under a given node prefix (read-only)."""
...
class ReadOnlyGraphRuntimeState(Protocol):
"""
Read-only view of GraphRuntimeState for layers.
This protocol defines a read-only interface that prevents layers from
modifying the graph runtime state while still allowing observation.
All methods return defensive copies to ensure immutability.
"""
@property
def system_variable(self) -> SystemVariableReadOnlyView: ...
@property
def variable_pool(self) -> ReadOnlyVariablePool:
"""Get read-only access to the variable pool."""
...
@property
def start_at(self) -> float:
"""Get the start time (read-only)."""
...
@property
def total_tokens(self) -> int:
"""Get the total tokens count (read-only)."""
...
@property
def llm_usage(self) -> LLMUsage:
"""Get a copy of LLM usage info (read-only)."""
...
@property
def outputs(self) -> dict[str, Any]:
"""Get a defensive copy of outputs (read-only)."""
...
@property
def node_run_steps(self) -> int:
"""Get the node run steps count (read-only)."""
...
@property
def ready_queue_size(self) -> int:
"""Get the number of nodes currently in the ready queue."""
...
@property
def exceptions_count(self) -> int:
"""Get the number of node execution exceptions recorded."""
...
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value (returns a copy)."""
...
def dumps(self) -> str:
"""Serialize the runtime state into a JSON snapshot (read-only)."""
...
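
Because the view is a Protocol, observers can be typed against it without importing any concrete wrapper; a sketch (the layer class and its hook are invented for illustration):

    from core.workflow.runtime import ReadOnlyGraphRuntimeState

    class TokenBudgetLayer:
        """Hypothetical observer that can read but never mutate the state."""

        def __init__(self, budget: int) -> None:
            self._budget = budget

        def on_step(self, state: ReadOnlyGraphRuntimeState) -> None:
            if state.total_tokens > self._budget:
                print(f"over budget: {state.total_tokens}/{self._budget} tokens")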


@@ -0,0 +1,87 @@
from __future__ import annotations
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView
from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool
class ReadOnlyVariablePoolWrapper:
"""Provide defensive, read-only access to ``VariablePool``."""
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Return a copy of a variable value if present."""
value = self._variable_pool.get([node_id, variable_key])
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Return a copy of all variables for the specified node."""
variables: dict[str, object] = {}
if node_id in self._variable_pool.variable_dictionary:
for key, variable in self._variable_pool.variable_dictionary[node_id].items():
variables[key] = deepcopy(variable.value)
return variables
def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
"""Return a copy of all variables stored under the given prefix."""
return self._variable_pool.get_by_prefix(prefix)
class ReadOnlyGraphRuntimeStateWrapper:
"""Expose a defensive, read-only view of ``GraphRuntimeState``."""
def __init__(self, state: GraphRuntimeState) -> None:
self._state = state
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
@property
def system_variable(self) -> SystemVariableReadOnlyView:
return self._state.variable_pool.system_variables.as_view()
@property
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
return self._variable_pool_wrapper
@property
def start_at(self) -> float:
return self._state.start_at
@property
def total_tokens(self) -> int:
return self._state.total_tokens
@property
def llm_usage(self) -> LLMUsage:
return self._state.llm_usage.model_copy()
@property
def outputs(self) -> dict[str, Any]:
return deepcopy(self._state.outputs)
@property
def node_run_steps(self) -> int:
return self._state.node_run_steps
@property
def ready_queue_size(self) -> int:
return self._state.ready_queue.qsize()
@property
def exceptions_count(self) -> int:
return self._state.graph_execution.exceptions_count
def get_output(self, key: str, default: Any = None) -> Any:
return self._state.get_output(key, default)
def dumps(self) -> str:
"""Serialize the underlying runtime state for external persistence."""
return self._state.dumps()
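
Wiring the wrapper to a live state, as a sketch: reads pass through, and because every getter copies, local mutation of a returned object cannot reach the underlying state (the wrapper also exposes no setters, so direct writes fail type checking).

    import time
    from core.workflow.runtime import (
        GraphRuntimeState,
        ReadOnlyGraphRuntimeStateWrapper,
        VariablePool,
    )

    state = GraphRuntimeState(variable_pool=VariablePool.empty(), start_at=time.perf_counter())
    state.set_output("greeting", "hi")

    view = ReadOnlyGraphRuntimeStateWrapper(state)
    outputs = view.outputs        # deep copy
    outputs["greeting"] = "edited"
    assert state.get_output("greeting") == "hi"  # original untouched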


@@ -0,0 +1,272 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
RAG_PIPELINE_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.system_variable import SystemVariable
from factories import variable_factory
VariableValue = Union[str, int, float, dict[str, object], list[object], File]
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
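# What the pattern accepts, by example (illustrative, not part of this commit):
# a node id of 1-50 word characters, then 1-10 dot-separated identifier parts
# of up to 30 characters each, wrapped in {{# ... #}} markers.
#
#     VARIABLE_PATTERN.findall("Hi {{#start.name#}} ({{#sys.user_id#}})")
#     # -> ["start.name", "sys.user_id"]
#     VARIABLE_PATTERN.findall("{{#bad-node.x#}}")
#     # -> []  (hyphens are not allowed in the node id)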
class VariablePool(BaseModel):
# The variable dictionary looks up variables by their selector.
# The first element of the selector is the node id and serves as the first-level key;
# the second element is the variable name and serves as the second-level key
# (see _selector_to_keys below).
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
# The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
default_factory=dict,
)
system_variables: SystemVariable = Field(
description="System variables",
default_factory=SystemVariable.empty,
)
environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
default_factory=list[VariableUnion],
)
conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.",
default_factory=list[VariableUnion],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
default_factory=list,
)
def model_post_init(self, context: Any, /):
# Create a mapping from field names to SystemVariableKey enum values
self._add_system_variables(self.system_variables)
# Add environment variables to the variable pool
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
# Add rag pipeline variables to the variable pool
if self.rag_pipeline_variables:
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
for rag_var in self.rag_pipeline_variables:
node_id = rag_var.variable.belong_to_node_id
key = rag_var.variable.variable
value = rag_var.value
rag_pipeline_variables_map[node_id][key] = value
for key, value in rag_pipeline_variables_map.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
def add(self, selector: Sequence[str], value: Any, /):
"""
Add a variable to the variable pool.
This method accepts a selector path and a value, converting the value
to a Variable object if necessary before storing it in the pool.
Args:
selector: A two-element sequence containing [node_id, variable_name].
The selector must have exactly 2 elements to be valid.
value: The value to store. Can be a Variable, Segment, or any value
that can be converted to a Segment (str, int, float, dict, list, File).
Raises:
ValueError: If selector length is not exactly 2 elements.
Note:
While non-Segment values are currently accepted and automatically
converted, it's recommended to pass Segment or Variable objects directly.
"""
if len(selector) != SELECTORS_LENGTH:
raise ValueError(
f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), "
f"got {len(selector)} elements"
)
if isinstance(value, Variable):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else:
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
return selector[0], selector[1]
def _has(self, selector: Sequence[str]) -> bool:
node_id, name = self._selector_to_keys(selector)
if node_id not in self.variable_dictionary:
return False
if name not in self.variable_dictionary[node_id]:
return False
return True
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
Retrieve a variable's value from the pool as a Segment.
This method supports both simple selectors [node_id, variable_name] and
extended selectors that include attribute access for FileSegment and
ObjectSegment types.
Args:
selector: A sequence with at least 2 elements:
- [node_id, variable_name]: Returns the full segment
- [node_id, variable_name, attr, ...]: Returns a nested value
from FileSegment (e.g., 'url', 'name') or ObjectSegment
Returns:
The Segment associated with the selector, or None if not found.
Returns None if selector has fewer than 2 elements.
Raises:
ValueError: If attempting to access an invalid FileAttribute.
"""
if len(selector) < SELECTORS_LENGTH:
return None
node_id, name = self._selector_to_keys(selector)
node_map = self.variable_dictionary.get(node_id)
if node_map is None:
return None
segment: Segment | None = node_map.get(name)
if segment is None:
return None
if len(selector) == 2:
return segment
if isinstance(segment, FileSegment):
attr = selector[2]
# Membership tests like `attr in FileAttribute` require Python 3.12+, so compare against the value set instead.
if attr not in {item.value for item in FileAttribute}:
return None
attr = FileAttribute(attr)
attr_value = file_manager.get_attr(file=segment.value, attr=attr)
return variable_factory.build_segment(attr_value)
# Navigate through nested attributes
result: Any = segment
for attr in selector[2:]:
result = self._extract_value(result)
result = self._get_nested_attribute(result, attr)
if result is None:
return None
# Return result as Segment
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
def _extract_value(self, obj: Any):
"""Extract the actual value from an ObjectSegment."""
return obj.value if isinstance(obj, ObjectSegment) else obj
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None:
"""
Get a nested attribute from a dictionary-like object.
Args:
obj: The dictionary-like object to search.
attr: The key to look up.
Returns:
Segment | None:
The corresponding Segment built from the attribute value if the key exists,
otherwise None.
"""
if not isinstance(obj, dict) or attr not in obj:
return None
return variable_factory.build_segment(obj.get(attr))
def remove(self, selector: Sequence[str], /):
"""
Remove variables from the variable pool based on the given selector.
Args:
selector (Sequence[str]): A sequence of strings representing the selector.
Returns:
None
"""
if not selector:
return
if len(selector) == 1:
self.variable_dictionary[selector[0]] = {}
return
key, hash_key = self._selector_to_keys(selector)
self.variable_dictionary[key].pop(hash_key, None)
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
segments: list[Segment] = []
for part in filter(lambda x: x, parts):
if "." in part and (variable := self.get(part.split("."))):
segments.append(variable)
else:
segments.append(variable_factory.build_segment(part))
return SegmentGroup(value=segments)
def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
segment = self.get(selector)
if isinstance(segment, FileSegment):
return segment
return None
def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
"""Return a copy of all variables stored under the given node prefix."""
nodes = self.variable_dictionary.get(prefix)
if not nodes:
return {}
result: dict[str, object] = {}
for key, variable in nodes.items():
value = variable.value
result[key] = deepcopy(value)
return result
def _add_system_variables(self, system_variable: SystemVariable):
sys_var_mapping = system_variable.to_dict()
for key, value in sys_var_mapping.items():
if value is None:
continue
selector = (SYSTEM_VARIABLE_NODE_ID, key)
# If the system variable already exists, do not add it again.
# This ensures that we can keep the id of the system variables intact.
if self._has(selector):
continue
self.add(selector, value)
@classmethod
def empty(cls) -> "VariablePool":
"""Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty())
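
End-to-end usage of the pool, as a sketch (the node id and values are invented, and SegmentGroup is assumed to expose the usual concatenated `text` property):

    from core.workflow.runtime import VariablePool

    pool = VariablePool.empty()
    pool.add(("node_1", "profile"), {"name": "Ada", "lang": "en"})

    nested = pool.get(["node_1", "profile", "name"])  # nested ObjectSegment lookup
    assert nested is not None and nested.value == "Ada"

    group = pool.convert_template("Hello {{#node_1.profile.name#}}!")
    assert group.text == "Hello Ada!"

    pool.remove(["node_1", "profile"])
    assert pool.get(["node_1", "profile"]) is None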