This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,377 @@
"""Simplified unit tests for DraftVarLoader focusing on core functionality."""
import json
from unittest.mock import Mock, patch
import pytest
from sqlalchemy import Engine
from core.variables.segments import ObjectSegment, StringSegment
from core.variables.types import SegmentType
from models.model import UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from services.workflow_draft_variable_service import DraftVarLoader
class TestDraftVarLoaderSimple:
"""Simplified unit tests for DraftVarLoader core methods."""
@pytest.fixture
def mock_engine(self) -> Engine:
return Mock(spec=Engine)
@pytest.fixture
def draft_var_loader(self, mock_engine):
"""Create DraftVarLoader instance for testing."""
return DraftVarLoader(
engine=mock_engine, app_id="test-app-id", tenant_id="test-tenant-id", fallback_variables=[]
)
def test_load_offloaded_variable_string_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with string type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test.txt"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.STRING
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_variable"
draft_var.description = "test description"
draft_var.get_selector.return_value = ["test-node-id", "test_variable"]
draft_var.variable_file = variable_file
test_content = "This is the full string content"
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_content.encode()
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_variable"
mock_variable.value = StringSegment(value=test_content)
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_variable")
assert variable.id == "draft-var-id"
assert variable.name == "test_variable"
assert variable.description == "test description"
assert variable.value == test_content
# Verify storage was called correctly
mock_storage.load.assert_called_once_with("storage/key/test.txt")
def test_load_offloaded_variable_object_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with object type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test.json"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.OBJECT
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_object"
draft_var.description = "test description"
draft_var.get_selector.return_value = ["test-node-id", "test_object"]
draft_var.variable_file = variable_file
test_object = {"key1": "value1", "key2": 42}
test_json_content = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
mock_segment = ObjectSegment(value=test_object)
mock_build_segment.return_value = mock_segment
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_object"
mock_variable.value = mock_segment
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_object")
assert variable.id == "draft-var-id"
assert variable.name == "test_object"
assert variable.description == "test description"
assert variable.value == test_object
# Verify method calls
mock_storage.load.assert_called_once_with("storage/key/test.json")
mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object)
def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader):
"""Test that assertion error is raised when variable_file is None."""
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.variable_file = None
with pytest.raises(AssertionError):
draft_var_loader._load_offloaded_variable(draft_var)
def test_load_offloaded_variable_missing_upload_file_unit(self, draft_var_loader):
"""Test that assertion error is raised when upload_file is None."""
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.upload_file = None
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.variable_file = variable_file
with pytest.raises(AssertionError):
draft_var_loader._load_offloaded_variable(draft_var)
def test_load_variables_empty_selectors_unit(self, draft_var_loader):
"""Test load_variables returns empty list for empty selectors."""
result = draft_var_loader.load_variables([])
assert result == []
def test_selector_to_tuple_unit(self, draft_var_loader):
"""Test _selector_to_tuple method."""
selector = ["node_id", "var_name", "extra_field"]
result = draft_var_loader._selector_to_tuple(selector)
assert result == ("node_id", "var_name")
def test_load_offloaded_variable_number_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with number type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test_number.json"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.NUMBER
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_number"
draft_var.description = "test number description"
draft_var.get_selector.return_value = ["test-node-id", "test_number"]
draft_var.variable_file = variable_file
test_number = 123.45
test_json_content = json.dumps(test_number)
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
from core.variables.segments import FloatSegment
mock_segment = FloatSegment(value=test_number)
mock_build_segment.return_value = mock_segment
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_number"
mock_variable.value = mock_segment
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_number")
assert variable.id == "draft-var-id"
assert variable.name == "test_number"
assert variable.description == "test number description"
# Verify method calls
mock_storage.load.assert_called_once_with("storage/key/test_number.json")
mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number)
def test_load_offloaded_variable_array_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with array type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test_array.json"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.ARRAY_ANY
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_array"
draft_var.description = "test array description"
draft_var.get_selector.return_value = ["test-node-id", "test_array"]
draft_var.variable_file = variable_file
test_array = ["item1", "item2", "item3"]
test_json_content = json.dumps(test_array)
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
from core.variables.segments import ArrayAnySegment
mock_segment = ArrayAnySegment(value=test_array)
mock_build_segment.return_value = mock_segment
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_array"
mock_variable.value = mock_segment
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_array")
assert variable.id == "draft-var-id"
assert variable.name == "test_array"
assert variable.description == "test array description"
# Verify method calls
mock_storage.load.assert_called_once_with("storage/key/test_array.json")
mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array)
def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader):
"""Test load_variables method with mix of regular and offloaded variables."""
selectors = [["node1", "regular_var"], ["node2", "offloaded_var"]]
# Mock regular variable
regular_draft_var = Mock(spec=WorkflowDraftVariable)
regular_draft_var.is_truncated.return_value = False
regular_draft_var.node_id = "node1"
regular_draft_var.name = "regular_var"
regular_draft_var.get_value.return_value = StringSegment(value="regular_value")
regular_draft_var.get_selector.return_value = ["node1", "regular_var"]
regular_draft_var.id = "regular-var-id"
regular_draft_var.description = "regular description"
# Mock offloaded variable
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/offloaded.txt"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.STRING
variable_file.upload_file = upload_file
offloaded_draft_var = Mock(spec=WorkflowDraftVariable)
offloaded_draft_var.is_truncated.return_value = True
offloaded_draft_var.node_id = "node2"
offloaded_draft_var.name = "offloaded_var"
offloaded_draft_var.get_selector.return_value = ["node2", "offloaded_var"]
offloaded_draft_var.variable_file = variable_file
offloaded_draft_var.id = "offloaded-var-id"
offloaded_draft_var.description = "offloaded description"
draft_vars = [regular_draft_var, offloaded_draft_var]
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
mock_session = Mock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_service = Mock()
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
with patch(
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
):
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
# Mock regular variable creation
regular_variable = Mock()
regular_variable.selector = ["node1", "regular_var"]
# Mock offloaded variable creation
offloaded_variable = Mock()
offloaded_variable.selector = ["node2", "offloaded_var"]
mock_segment_to_variable.return_value = regular_variable
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = b"offloaded_content"
with patch.object(draft_var_loader, "_load_offloaded_variable") as mock_load_offloaded:
mock_load_offloaded.return_value = (("node2", "offloaded_var"), offloaded_variable)
with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor_cls:
mock_executor = Mock()
mock_executor_cls.return_value.__enter__.return_value = mock_executor
mock_executor.map.return_value = [(("node2", "offloaded_var"), offloaded_variable)]
# Execute the method
result = draft_var_loader.load_variables(selectors)
# Verify results
assert len(result) == 2
# Verify service method was called
mock_service.get_draft_variables_by_selectors.assert_called_once_with(
draft_var_loader._app_id, selectors
)
# Verify offloaded variable loading was called
mock_load_offloaded.assert_called_once_with(offloaded_draft_var)
def test_load_variables_all_offloaded_variables_unit(self, draft_var_loader):
"""Test load_variables method with only offloaded variables."""
selectors = [["node1", "offloaded_var1"], ["node2", "offloaded_var2"]]
# Mock first offloaded variable
offloaded_var1 = Mock(spec=WorkflowDraftVariable)
offloaded_var1.is_truncated.return_value = True
offloaded_var1.node_id = "node1"
offloaded_var1.name = "offloaded_var1"
# Mock second offloaded variable
offloaded_var2 = Mock(spec=WorkflowDraftVariable)
offloaded_var2.is_truncated.return_value = True
offloaded_var2.node_id = "node2"
offloaded_var2.name = "offloaded_var2"
draft_vars = [offloaded_var1, offloaded_var2]
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
mock_session = Mock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_service = Mock()
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
with patch(
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
):
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
with patch("services.workflow_draft_variable_service.ThreadPoolExecutor") as mock_executor_cls:
mock_executor = Mock()
mock_executor_cls.return_value.__enter__.return_value = mock_executor
mock_executor.map.return_value = [
(("node1", "offloaded_var1"), Mock()),
(("node2", "offloaded_var2"), Mock()),
]
# Execute the method
result = draft_var_loader.load_variables(selectors)
# Verify results - since we have only offloaded variables, should have 2 results
assert len(result) == 2
# Verify ThreadPoolExecutor was used
mock_executor_cls.assert_called_once_with(max_workers=10)
mock_executor.map.assert_called_once()

View File

@@ -0,0 +1,432 @@
# test for api/services/workflow/workflow_converter.py
import json
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity,
DatasetEntity,
DatasetRetrieveConfigEntity,
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import AppMode
from services.workflow.workflow_converter import WorkflowConverter
@pytest.fixture
def default_variables():
value = [
VariableEntity(
variable="text_input",
label="text-input",
type=VariableEntityType.TEXT_INPUT,
),
VariableEntity(
variable="paragraph",
label="paragraph",
type=VariableEntityType.PARAGRAPH,
),
VariableEntity(
variable="select",
label="select",
type=VariableEntityType.SELECT,
),
]
return value
def test__convert_to_start_node(default_variables):
# act
result = WorkflowConverter()._convert_to_start_node(default_variables)
# assert
assert isinstance(result["data"]["variables"][0]["type"], str)
assert result["data"]["variables"][0]["type"] == "text-input"
assert result["data"]["variables"][0]["variable"] == "text_input"
assert result["data"]["variables"][1]["variable"] == "paragraph"
assert result["data"]["variables"][2]["variable"] == "select"
def test__convert_to_http_request_node_for_chatbot(default_variables):
"""
Test convert to http request nodes for chatbot
:return:
"""
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.CHAT
api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(
tenant_id="tenant_id",
name="api-1",
api_key="encrypted_api_key",
api_endpoint="https://dify.ai",
)
mock_api_based_extension.id = api_based_extension_id
workflow_converter = WorkflowConverter()
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
encrypter.decrypt_token = MagicMock(return_value="api_key")
external_data_variables = [
ExternalDataVariableEntity(
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
)
]
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
)
assert len(nodes) == 2
assert nodes[0]["data"]["type"] == "http-request"
http_request_node = nodes[0]
assert http_request_node["data"]["method"] == "post"
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
assert http_request_node["data"]["authorization"]["type"] == "api-key"
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
assert http_request_node["data"]["body"]["type"] == "json"
body_data = http_request_node["data"]["body"]["data"]
assert body_data
body_data_json = json.loads(body_data)
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
body_params = body_data_json["params"]
assert body_params["app_id"] == app_model.id
assert body_params["tool_variable"] == external_data_variables[0].variable
assert len(body_params["inputs"]) == 3
assert body_params["query"] == "{{#sys.query#}}" # for chatbot
code_node = nodes[1]
assert code_node["data"]["type"] == "code"
def test__convert_to_http_request_node_for_workflow_app(default_variables):
"""
Test convert to http request nodes for workflow app
:return:
"""
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.WORKFLOW
api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(
tenant_id="tenant_id",
name="api-1",
api_key="encrypted_api_key",
api_endpoint="https://dify.ai",
)
mock_api_based_extension.id = api_based_extension_id
workflow_converter = WorkflowConverter()
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
encrypter.decrypt_token = MagicMock(return_value="api_key")
external_data_variables = [
ExternalDataVariableEntity(
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
)
]
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
)
assert len(nodes) == 2
assert nodes[0]["data"]["type"] == "http-request"
http_request_node = nodes[0]
assert http_request_node["data"]["method"] == "post"
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
assert http_request_node["data"]["authorization"]["type"] == "api-key"
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
assert http_request_node["data"]["body"]["type"] == "json"
body_data = http_request_node["data"]["body"]["data"]
assert body_data
body_data_json = json.loads(body_data)
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
body_params = body_data_json["params"]
assert body_params["app_id"] == app_model.id
assert body_params["tool_variable"] == external_data_variables[0].variable
assert len(body_params["inputs"]) == 3
assert body_params["query"] == ""
code_node = nodes[1]
assert code_node["data"]["type"] == "code"
def test__convert_to_knowledge_retrieval_node_for_chatbot():
new_app_mode = AppMode.ADVANCED_CHAT
dataset_config = DatasetEntity(
dataset_ids=["dataset_id_1", "dataset_id_2"],
retrieve_config=DatasetRetrieveConfigEntity(
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=5,
score_threshold=0.8,
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
)
assert node is not None
assert node["data"]["type"] == "knowledge-retrieval"
assert node["data"]["query_variable_selector"] == ["sys", "query"]
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
assert node["data"]["multiple_retrieval_config"] == {
"top_k": dataset_config.retrieve_config.top_k,
"score_threshold": dataset_config.retrieve_config.score_threshold,
"reranking_model": dataset_config.retrieve_config.reranking_model,
}
def test__convert_to_knowledge_retrieval_node_for_workflow_app():
new_app_mode = AppMode.WORKFLOW
dataset_config = DatasetEntity(
dataset_ids=["dataset_id_1", "dataset_id_2"],
retrieve_config=DatasetRetrieveConfigEntity(
query_variable="query",
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=5,
score_threshold=0.8,
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
)
assert node is not None
assert node["data"]["type"] == "knowledge-retrieval"
assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable]
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
assert node["data"]["multiple_retrieval_config"] == {
"top_k": dataset_config.retrieve_config.top_k,
"score_threshold": dataset_config.retrieve_config.score_threshold,
"reranking_model": dataset_config.retrieve_config.reranking_model,
}
def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-4"
model_mode = LLMMode.CHAT
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
template = prompt_template.simple_prompt_template
assert template is not None
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n"
assert llm_node["data"]["context"]["enabled"] is False
def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-3.5-turbo-instruct"
model_mode = LLMMode.COMPLETION
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
template = prompt_template.simple_prompt_template
assert template is not None
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"]["text"] == template + "\n"
assert llm_node["data"]["context"]["enabled"] is False
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-4"
model_mode = LLMMode.CHAT
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
messages=[
AdvancedChatMessageEntity(
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
role=PromptMessageRole.SYSTEM,
),
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
]
),
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
assert isinstance(llm_node["data"]["prompt_template"], list)
assert prompt_template.advanced_chat_prompt_template is not None
assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages)
template = prompt_template.advanced_chat_prompt_template.messages[0].text
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"][0]["text"] == template
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-3.5-turbo-instruct"
model_mode = LLMMode.COMPLETION
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ",
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
),
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
assert isinstance(llm_node["data"]["prompt_template"], dict)
assert prompt_template.advanced_completion_prompt_template is not None
template = prompt_template.advanced_completion_prompt_template.prompt
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"]["text"] == template

View File

@@ -0,0 +1,127 @@
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import Session
from models.model import App
from models.workflow import Workflow
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
@pytest.fixture
def workflow_setup():
mock_session_maker = MagicMock()
workflow_service = WorkflowService(mock_session_maker)
session = MagicMock(spec=Session)
tenant_id = "test-tenant-id"
workflow_id = "test-workflow-id"
# Mock workflow
workflow = MagicMock(spec=Workflow)
workflow.id = workflow_id
workflow.tenant_id = tenant_id
workflow.version = "1.0" # Not a draft
workflow.tool_published = False # Not published as a tool by default
# Mock app
app = MagicMock(spec=App)
app.id = "test-app-id"
app.name = "Test App"
app.workflow_id = None # Not used by an app by default
return {
"workflow_service": workflow_service,
"session": session,
"tenant_id": tenant_id,
"workflow_id": workflow_id,
"workflow": workflow,
"app": app,
}
def test_delete_workflow_success(workflow_setup):
# Setup mocks
# Mock the tool provider query to return None (not published as a tool)
workflow_setup["session"].query.return_value.where.return_value.first.return_value = None
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None]
) # Return workflow first, then None for app
# Call the method
result = workflow_setup["workflow_service"].delete_workflow(
session=workflow_setup["session"],
workflow_id=workflow_setup["workflow_id"],
tenant_id=workflow_setup["tenant_id"],
)
# Verify
assert result is True
workflow_setup["session"].delete.assert_called_once_with(workflow_setup["workflow"])
def test_delete_workflow_draft_error(workflow_setup):
# Setup mocks
workflow_setup["workflow"].version = "draft"
workflow_setup["session"].scalar = MagicMock(return_value=workflow_setup["workflow"])
# Call the method and verify exception
with pytest.raises(DraftWorkflowDeletionError):
workflow_setup["workflow_service"].delete_workflow(
session=workflow_setup["session"],
workflow_id=workflow_setup["workflow_id"],
tenant_id=workflow_setup["tenant_id"],
)
# Verify
workflow_setup["session"].delete.assert_not_called()
def test_delete_workflow_in_use_by_app_error(workflow_setup):
# Setup mocks
workflow_setup["app"].workflow_id = workflow_setup["workflow_id"]
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], workflow_setup["app"]]
) # Return workflow first, then app
# Call the method and verify exception
with pytest.raises(WorkflowInUseError) as excinfo:
workflow_setup["workflow_service"].delete_workflow(
session=workflow_setup["session"],
workflow_id=workflow_setup["workflow_id"],
tenant_id=workflow_setup["tenant_id"],
)
# Verify error message contains app name
assert "Cannot delete workflow that is currently in use by app" in str(excinfo.value)
# Verify
workflow_setup["session"].delete.assert_not_called()
def test_delete_workflow_published_as_tool_error(workflow_setup):
# Setup mocks
from models.tools import WorkflowToolProvider
# Mock the tool provider query
mock_tool_provider = MagicMock(spec=WorkflowToolProvider)
workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None]
) # Return workflow first, then None for app
# Call the method and verify exception
with pytest.raises(WorkflowInUseError) as excinfo:
workflow_setup["workflow_service"].delete_workflow(
session=workflow_setup["session"],
workflow_id=workflow_setup["workflow_id"],
tenant_id=workflow_setup["tenant_id"],
)
# Verify error message
assert "Cannot delete workflow that is published as a tool" in str(excinfo.value)
# Verify
workflow_setup["session"].delete.assert_not_called()

View File

@@ -0,0 +1,477 @@
import dataclasses
import secrets
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session
from core.variables.segments import StringSegment
from core.variables.types import SegmentType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeType
from libs.uuid_utils import uuidv7
from models.account import Account
from models.enums import DraftVariableType
from models.workflow import (
Workflow,
WorkflowDraftVariable,
WorkflowDraftVariableFile,
WorkflowNodeExecutionModel,
is_system_variable_editable,
)
from services.workflow_draft_variable_service import (
DraftVariableSaver,
VariableResetError,
WorkflowDraftVariableService,
)
@pytest.fixture
def mock_engine() -> Engine:
return Mock(spec=Engine)
@pytest.fixture
def mock_session(mock_engine) -> Session:
mock_session = Mock(spec=Session)
mock_session.get_bind.return_value = mock_engine
return mock_session
class TestDraftVariableSaver:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test__should_variable_be_visible(self):
mock_session = MagicMock(spec=Session)
mock_user = Account(name="test", email="test@example.com")
mock_user.id = str(uuid.uuid4())
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
app_id=test_app_id,
node_id="test_node_id",
node_type=NodeType.START,
node_execution_id="test_execution_id",
user=mock_user,
)
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
assert saver._should_variable_be_visible("123", NodeType.START, "output") == True
def test__normalize_variable_for_start_node(self):
@dataclasses.dataclass(frozen=True)
class TestCase:
name: str
input_node_id: str
input_name: str
expected_node_id: str
expected_name: str
_NODE_ID = "1747228642872"
cases = [
TestCase(
name="name with `sys.` prefix should return the system node_id",
input_node_id=_NODE_ID,
input_name="sys.workflow_id",
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
expected_name="workflow_id",
),
TestCase(
name="name without `sys.` prefix should return the original input node_id",
input_node_id=_NODE_ID,
input_name="start_input",
expected_node_id=_NODE_ID,
expected_name="start_input",
),
TestCase(
name="dummy_variable should return the original input node_id",
input_node_id=_NODE_ID,
input_name="__dummy__",
expected_node_id=_NODE_ID,
expected_name="__dummy__",
),
]
mock_session = MagicMock(spec=Session)
mock_user = MagicMock()
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
app_id=test_app_id,
node_id=_NODE_ID,
node_type=NodeType.START,
node_execution_id="test_execution_id",
user=mock_user,
)
for idx, c in enumerate(cases, 1):
fail_msg = f"Test case {c.name} failed, index={idx}"
node_id, name = saver._normalize_variable_for_start_node(c.input_name)
assert node_id == c.expected_node_id, fail_msg
assert name == c.expected_name, fail_msg
@pytest.fixture
def mock_session(self):
"""Mock SQLAlchemy session."""
from sqlalchemy import Engine
mock_session = MagicMock(spec=Session)
mock_engine = MagicMock(spec=Engine)
mock_session.get_bind.return_value = mock_engine
return mock_session
@pytest.fixture
def draft_saver(self, mock_session):
"""Create DraftVariableSaver instance with user context."""
# Create a mock user
mock_user = MagicMock(spec=Account)
mock_user.id = "test-user-id"
mock_user.tenant_id = "test-tenant-id"
return DraftVariableSaver(
session=mock_session,
app_id="test-app-id",
node_id="test-node-id",
node_type=NodeType.LLM,
node_execution_id="test-execution-id",
user=mock_user,
)
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
with patch(
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
) as _mock_try_offload:
_mock_try_offload.return_value = None
mock_segment = StringSegment(value="small value")
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
# Should not have large variable metadata
assert draft_var.file_id is None
_mock_try_offload.return_value = None
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
with patch(
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
) as _mock_try_offload:
mock_segment = StringSegment(value="small value")
mock_draft_var_file = WorkflowDraftVariableFile(
id=str(uuidv7()),
size=1024,
length=10,
value_type=SegmentType.ARRAY_STRING,
upload_file_id=str(uuid.uuid4()),
)
_mock_try_offload.return_value = mock_segment, mock_draft_var_file
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
# Should not have large variable metadata
assert draft_var.file_id == mock_draft_var_file.id
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
"""Test complete save workflow."""
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
draft_saver.save(outputs=outputs)
# Should batch upsert draft variables
mock_batch_upsert.assert_called_once()
draft_vars = mock_batch_upsert.call_args[0][1]
assert len(draft_vars) == 2
class TestWorkflowDraftVariableService:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def _create_test_workflow(self, app_id: str) -> Workflow:
"""Create a real Workflow instance for testing"""
return Workflow.new(
tenant_id="test_tenant_id",
app_id=app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features="{}",
created_by="test_user_id",
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
def test_reset_conversation_variable(self, mock_session):
"""Test resetting a conversation variable"""
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real conversation variable
test_value = StringSegment(value="test_value")
variable = WorkflowDraftVariable.new_conversation_variable(
app_id=test_app_id, name="test_var", value=test_value, description="Test conversation variable"
)
# Mock the _reset_conv_var method
expected_result = WorkflowDraftVariable.new_conversation_variable(
app_id=test_app_id,
name="test_var",
value=StringSegment(value="reset_value"),
)
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
result = service.reset_variable(workflow, variable)
mock_reset_conv.assert_called_once_with(workflow, variable)
assert result == expected_result
def test_reset_node_variable_with_no_execution_id(self, mock_session):
"""Test resetting a node variable with no execution ID - should delete variable"""
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real node variable with no execution ID
test_value = StringSegment(value="test_value")
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id,
node_id="test_node_id",
name="test_var",
value=test_value,
node_execution_id="exec-id", # Set initially
)
# Manually set to None to simulate the test condition
variable.node_execution_id = None
result = service._reset_node_var_or_sys_var(workflow, variable)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=variable)
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_missing_execution_record(
self,
mock_engine,
mock_session,
monkeypatch,
):
"""Test resetting a node variable when execution record doesn't exist"""
mock_repo_session = Mock(spec=Session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
mock_session_maker.return_value.__exit__.return_value = None
monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
service = WorkflowDraftVariableService(mock_session)
# Mock the repository to return None (no execution record found)
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = None
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real node variable with execution ID
test_value = StringSegment(value="test_value")
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Variable is editable by default from factory method
result = service._reset_node_var_or_sys_var(workflow, variable)
mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=variable)
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_valid_execution_record(
self,
mock_session,
monkeypatch,
):
"""Test resetting a node variable with valid execution record - should restore from execution"""
mock_repo_session = Mock(spec=Session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
mock_session_maker.return_value.__exit__.return_value = None
mock_session_maker = monkeypatch.setattr(
"services.workflow_draft_variable_service.sessionmaker", mock_session_maker
)
service = WorkflowDraftVariableService(mock_session)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.load_full_outputs.return_value = {"test_var": "output_value"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real node variable with execution ID
test_value = StringSegment(value="original_value")
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Variable is editable by default from factory method
# Mock workflow methods
mock_node_config = {"type": "test_node"}
with (
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
):
result = service._reset_node_var_or_sys_var(workflow, variable)
# Verify last_edited_at was reset
assert variable.last_edited_at is None
# Verify session.flush was called
mock_session.flush.assert_called()
# Should return the updated variable
assert result == variable
def test_reset_non_editable_system_variable_raises_error(self, mock_session):
"""Test that resetting a non-editable system variable raises an error"""
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create a non-editable system variable (workflow_id is not editable)
test_value = StringSegment(value="test_workflow_id")
variable = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id,
name="workflow_id", # This is not in _EDITABLE_SYSTEM_VARIABLE
value=test_value,
node_execution_id="exec-id",
editable=False, # Non-editable system variable
)
with pytest.raises(VariableResetError) as exc_info:
service.reset_variable(workflow, variable)
assert "cannot reset system variable" in str(exc_info.value)
assert f"variable_id={variable.id}" in str(exc_info.value)
def test_reset_editable_system_variable_succeeds(self, mock_session):
"""Test that resetting an editable system variable succeeds"""
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create an editable system variable (files is editable)
test_value = StringSegment(value="[]")
variable = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id,
name="files", # This is in _EDITABLE_SYSTEM_VARIABLE
value=test_value,
node_execution_id="exec-id",
editable=True, # Editable system variable
)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.load_full_outputs.return_value = {"sys.files": "[]"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)
# Should succeed and return the variable
assert result == variable
assert variable.last_edited_at is None
mock_session.flush.assert_called()
def test_reset_query_system_variable_succeeds(self, mock_session):
"""Test that resetting query system variable (another editable one) succeeds"""
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create an editable system variable (query is editable)
test_value = StringSegment(value="original query")
variable = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id,
name="query", # This is in _EDITABLE_SYSTEM_VARIABLE
value=test_value,
node_execution_id="exec-id",
editable=True, # Editable system variable
)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.load_full_outputs.return_value = {"sys.query": "reset query"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)
# Should succeed and return the variable
assert result == variable
assert variable.last_edited_at is None
mock_session.flush.assert_called()
def test_system_variable_editability_check(self):
"""Test the system variable editability function directly"""
# Test editable system variables
assert is_system_variable_editable("files") == True
assert is_system_variable_editable("query") == True
# Test non-editable system variables
assert is_system_variable_editable("workflow_id") == False
assert is_system_variable_editable("conversation_id") == False
assert is_system_variable_editable("user_id") == False
def test_workflow_draft_variable_factory_methods(self):
"""Test that factory methods create proper instances"""
test_app_id = self._get_test_app_id()
test_value = StringSegment(value="test_value")
# Test conversation variable factory
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=test_app_id, name="conv_var", value=test_value, description="Test conversation variable"
)
assert conv_var.get_variable_type() == DraftVariableType.CONVERSATION
assert conv_var.editable == True
assert conv_var.node_execution_id is None
# Test system variable factory
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id, name="workflow_id", value=test_value, node_execution_id="exec-id", editable=False
)
assert sys_var.get_variable_type() == DraftVariableType.SYS
assert sys_var.editable == False
assert sys_var.node_execution_id == "exec-id"
# Test node variable factory
node_var = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id,
node_id="node-id",
name="node_var",
value=test_value,
node_execution_id="exec-id",
visible=True,
editable=True,
)
assert node_var.get_variable_type() == DraftVariableType.NODE
assert node_var.visible == True
assert node_var.editable == True
assert node_var.node_execution_id == "exec-id"

View File

@@ -0,0 +1,288 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
@pytest.fixture
def repository(self):
mock_session_maker = MagicMock()
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
@pytest.fixture
def mock_execution(self):
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.id = str(uuid4())
execution.tenant_id = "tenant-123"
execution.app_id = "app-456"
execution.workflow_id = "workflow-789"
execution.workflow_run_id = "run-101"
execution.node_id = "node-202"
execution.index = 1
execution.created_at = "2023-01-01T00:00:00Z"
return execution
def test_get_node_last_execution_found(self, repository, mock_execution):
"""Test getting the last execution for a node when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.scalar.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
def test_get_node_last_execution_not_found(self, repository):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_get_executions_by_workflow_run(self, repository, mock_execution):
"""Test getting all executions for a workflow run."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
executions = [mock_execution]
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == executions
mock_session.execute.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.execute.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
def test_get_executions_by_workflow_run_empty(self, repository):
"""Test getting executions for a workflow run when none exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalars.return_value.all.return_value = []
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == []
mock_session.execute.assert_called_once()
def test_get_execution_by_id_found(self, repository, mock_execution):
"""Test getting execution by ID when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_execution_by_id(mock_execution.id)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
def test_get_execution_by_id_not_found(self, repository):
"""Test getting execution by ID when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_execution_by_id("non-existent-id")
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_repository_implements_protocol(self, repository):
"""Test that the repository implements the required protocol methods."""
# Verify all protocol methods are implemented
assert hasattr(repository, "get_node_last_execution")
assert hasattr(repository, "get_executions_by_workflow_run")
assert hasattr(repository, "get_execution_by_id")
# Verify methods are callable
assert callable(repository.get_node_last_execution)
assert callable(repository.get_executions_by_workflow_run)
assert callable(repository.get_execution_by_id)
assert callable(repository.delete_expired_executions)
assert callable(repository.delete_executions_by_app)
assert callable(repository.get_expired_executions_batch)
assert callable(repository.delete_executions_by_ids)
def test_delete_expired_executions(self, repository):
"""Test deleting expired executions."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
before_date = datetime(2023, 1, 1)
# Act
result = repository.delete_expired_executions(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_delete_executions_by_app(self, repository):
"""Test deleting executions by app."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"]
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
# Act
result = repository.delete_executions_by_app(
tenant_id="tenant-123",
app_id="app-456",
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_get_expired_executions_batch(self, repository):
"""Test getting expired executions batch for backup."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Create mock execution objects
mock_execution1 = MagicMock()
mock_execution1.id = "exec-1"
mock_execution2 = MagicMock()
mock_execution2.id = "exec-2"
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
before_date = datetime(2023, 1, 1)
# Act
result = repository.get_expired_executions_batch(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert len(result) == 2
assert result[0].id == "exec-1"
assert result[1].id == "exec-2"
mock_session.execute.assert_called_once()
def test_delete_executions_by_ids(self, repository):
"""Test deleting executions by IDs."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the delete query result
mock_result = MagicMock()
mock_result.rowcount = 3
mock_session.execute.return_value = mock_result
execution_ids = ["id1", "id2", "id3"]
# Act
result = repository.delete_executions_by_ids(execution_ids)
# Assert
assert result == 3
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_delete_executions_by_ids_empty_list(self, repository):
"""Test deleting executions with empty ID list."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Act
result = repository.delete_executions_by_ids([])
# Assert
assert result == 0
mock_session.query.assert_not_called()
mock_session.commit.assert_not_called()

View File

@@ -0,0 +1,163 @@
from unittest.mock import MagicMock
import pytest
from models.model import App
from models.workflow import Workflow
from services.workflow_service import WorkflowService
class TestWorkflowService:
@pytest.fixture
def workflow_service(self):
mock_session_maker = MagicMock()
return WorkflowService(mock_session_maker)
@pytest.fixture
def mock_app(self):
app = MagicMock(spec=App)
app.id = "app-id-1"
app.workflow_id = "workflow-id-1"
app.tenant_id = "tenant-id-1"
return app
@pytest.fixture
def mock_workflows(self):
workflows = []
for i in range(5):
workflow = MagicMock(spec=Workflow)
workflow.id = f"workflow-id-{i}"
workflow.app_id = "app-id-1"
workflow.created_at = f"2023-01-0{5 - i}" # Descending date order
workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2"
workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else ""
workflows.append(workflow)
return workflows
def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
mock_app.workflow_id = None
mock_session = MagicMock()
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
)
assert workflows == []
assert has_more is False
mock_session.scalars.assert_not_called()
def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
mock_scalar_result.all.return_value = mock_workflows[:3]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
)
assert workflows == mock_workflows[:3]
assert has_more is False
mock_session.scalars.assert_called_once()
def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Return 4 items when limit is 3, which should indicate has_more=True
mock_scalar_result.all.return_value = mock_workflows[:4]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
)
# Should return only the first 3 items
assert len(workflows) == 3
assert workflows == mock_workflows[:3]
assert has_more is True
# Test page 2
mock_scalar_result.all.return_value = mock_workflows[3:]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None
)
assert len(workflows) == 2
assert has_more is False
def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Filter workflows for user-id-1
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"]
mock_scalar_result.all.return_value = filtered_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1"
)
assert workflows == filtered_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that the select contains a user filter clause
args = mock_session.scalars.call_args[0][0]
assert "created_by" in str(args)
def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Filter workflows that have a marked_name
named_workflows = [w for w in mock_workflows if w.marked_name]
mock_scalar_result.all.return_value = named_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True
)
assert workflows == named_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that the select contains a named_only filter clause
args = mock_session.scalars.call_args[0][0]
assert "marked_name !=" in str(args)
def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Combined filter: user-id-1 and has marked_name
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name]
mock_scalar_result.all.return_value = filtered_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True
)
assert workflows == filtered_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that both filters are applied
args = mock_session.scalars.call_args[0][0]
assert "created_by" in str(args)
assert "marked_name !=" in str(args)
def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
mock_scalar_result.all.return_value = []
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
)
assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()