dify
This commit is contained in:
@@ -0,0 +1,377 @@
|
||||
"""Simplified unit tests for DraftVarLoader focusing on core functionality."""
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.variables.segments import ObjectSegment, StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from models.model import UploadFile
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from services.workflow_draft_variable_service import DraftVarLoader
|
||||
|
||||
|
||||
class TestDraftVarLoaderSimple:
|
||||
"""Simplified unit tests for DraftVarLoader core methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine(self) -> Engine:
|
||||
return Mock(spec=Engine)
|
||||
|
||||
@pytest.fixture
|
||||
def draft_var_loader(self, mock_engine):
|
||||
"""Create DraftVarLoader instance for testing."""
|
||||
return DraftVarLoader(
|
||||
engine=mock_engine, app_id="test-app-id", tenant_id="test-tenant-id", fallback_variables=[]
|
||||
)
|
||||
|
||||
def test_load_offloaded_variable_string_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with string type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test.txt"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.STRING
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_variable"
|
||||
draft_var.description = "test description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_variable"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_content = "This is the full string content"
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_content.encode()
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_variable"
|
||||
mock_variable.value = StringSegment(value=test_content)
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_variable")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_variable"
|
||||
assert variable.description == "test description"
|
||||
assert variable.value == test_content
|
||||
|
||||
# Verify storage was called correctly
|
||||
mock_storage.load.assert_called_once_with("storage/key/test.txt")
|
||||
|
||||
def test_load_offloaded_variable_object_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with object type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test.json"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.OBJECT
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_object"
|
||||
draft_var.description = "test description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_object"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_object = {"key1": "value1", "key2": 42}
|
||||
test_json_content = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
mock_segment = ObjectSegment(value=test_object)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_object"
|
||||
mock_variable.value = mock_segment
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_object")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_object"
|
||||
assert variable.description == "test description"
|
||||
assert variable.value == test_object
|
||||
|
||||
# Verify method calls
|
||||
mock_storage.load.assert_called_once_with("storage/key/test.json")
|
||||
mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object)
|
||||
|
||||
def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader):
|
||||
"""Test that assertion error is raised when variable_file is None."""
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.variable_file = None
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
def test_load_offloaded_variable_missing_upload_file_unit(self, draft_var_loader):
|
||||
"""Test that assertion error is raised when upload_file is None."""
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.upload_file = None
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
def test_load_variables_empty_selectors_unit(self, draft_var_loader):
|
||||
"""Test load_variables returns empty list for empty selectors."""
|
||||
result = draft_var_loader.load_variables([])
|
||||
assert result == []
|
||||
|
||||
def test_selector_to_tuple_unit(self, draft_var_loader):
|
||||
"""Test _selector_to_tuple method."""
|
||||
selector = ["node_id", "var_name", "extra_field"]
|
||||
result = draft_var_loader._selector_to_tuple(selector)
|
||||
assert result == ("node_id", "var_name")
|
||||
|
||||
def test_load_offloaded_variable_number_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with number type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test_number.json"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.NUMBER
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_number"
|
||||
draft_var.description = "test number description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_number"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_number = 123.45
|
||||
test_json_content = json.dumps(test_number)
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
from core.variables.segments import FloatSegment
|
||||
|
||||
mock_segment = FloatSegment(value=test_number)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_number"
|
||||
mock_variable.value = mock_segment
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_number")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_number"
|
||||
assert variable.description == "test number description"
|
||||
|
||||
# Verify method calls
|
||||
mock_storage.load.assert_called_once_with("storage/key/test_number.json")
|
||||
mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number)
|
||||
|
||||
def test_load_offloaded_variable_array_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with array type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test_array.json"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.ARRAY_ANY
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_array"
|
||||
draft_var.description = "test array description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_array"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_array = ["item1", "item2", "item3"]
|
||||
test_json_content = json.dumps(test_array)
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
|
||||
mock_segment = ArrayAnySegment(value=test_array)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_array"
|
||||
mock_variable.value = mock_segment
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_array")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_array"
|
||||
assert variable.description == "test array description"
|
||||
|
||||
# Verify method calls
|
||||
mock_storage.load.assert_called_once_with("storage/key/test_array.json")
|
||||
mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array)
|
||||
|
||||
def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader):
|
||||
"""Test load_variables method with mix of regular and offloaded variables."""
|
||||
selectors = [["node1", "regular_var"], ["node2", "offloaded_var"]]
|
||||
|
||||
# Mock regular variable
|
||||
regular_draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
regular_draft_var.is_truncated.return_value = False
|
||||
regular_draft_var.node_id = "node1"
|
||||
regular_draft_var.name = "regular_var"
|
||||
regular_draft_var.get_value.return_value = StringSegment(value="regular_value")
|
||||
regular_draft_var.get_selector.return_value = ["node1", "regular_var"]
|
||||
regular_draft_var.id = "regular-var-id"
|
||||
regular_draft_var.description = "regular description"
|
||||
|
||||
# Mock offloaded variable
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/offloaded.txt"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.STRING
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
offloaded_draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
offloaded_draft_var.is_truncated.return_value = True
|
||||
offloaded_draft_var.node_id = "node2"
|
||||
offloaded_draft_var.name = "offloaded_var"
|
||||
offloaded_draft_var.get_selector.return_value = ["node2", "offloaded_var"]
|
||||
offloaded_draft_var.variable_file = variable_file
|
||||
offloaded_draft_var.id = "offloaded-var-id"
|
||||
offloaded_draft_var.description = "offloaded description"
|
||||
|
||||
draft_vars = [regular_draft_var, offloaded_draft_var]
|
||||
|
||||
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
|
||||
mock_session = Mock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
|
||||
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
|
||||
):
|
||||
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
# Mock regular variable creation
|
||||
regular_variable = Mock()
|
||||
regular_variable.selector = ["node1", "regular_var"]
|
||||
|
||||
# Mock offloaded variable creation
|
||||
offloaded_variable = Mock()
|
||||
offloaded_variable.selector = ["node2", "offloaded_var"]
|
||||
|
||||
mock_segment_to_variable.return_value = regular_variable
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = b"offloaded_content"
|
||||
|
||||
with patch.object(draft_var_loader, "_load_offloaded_variable") as mock_load_offloaded:
|
||||
mock_load_offloaded.return_value = (("node2", "offloaded_var"), offloaded_variable)
|
||||
|
||||
with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor_cls:
|
||||
mock_executor = Mock()
|
||||
mock_executor_cls.return_value.__enter__.return_value = mock_executor
|
||||
mock_executor.map.return_value = [(("node2", "offloaded_var"), offloaded_variable)]
|
||||
|
||||
# Execute the method
|
||||
result = draft_var_loader.load_variables(selectors)
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify service method was called
|
||||
mock_service.get_draft_variables_by_selectors.assert_called_once_with(
|
||||
draft_var_loader._app_id, selectors
|
||||
)
|
||||
|
||||
# Verify offloaded variable loading was called
|
||||
mock_load_offloaded.assert_called_once_with(offloaded_draft_var)
|
||||
|
||||
def test_load_variables_all_offloaded_variables_unit(self, draft_var_loader):
|
||||
"""Test load_variables method with only offloaded variables."""
|
||||
selectors = [["node1", "offloaded_var1"], ["node2", "offloaded_var2"]]
|
||||
|
||||
# Mock first offloaded variable
|
||||
offloaded_var1 = Mock(spec=WorkflowDraftVariable)
|
||||
offloaded_var1.is_truncated.return_value = True
|
||||
offloaded_var1.node_id = "node1"
|
||||
offloaded_var1.name = "offloaded_var1"
|
||||
|
||||
# Mock second offloaded variable
|
||||
offloaded_var2 = Mock(spec=WorkflowDraftVariable)
|
||||
offloaded_var2.is_truncated.return_value = True
|
||||
offloaded_var2.node_id = "node2"
|
||||
offloaded_var2.name = "offloaded_var2"
|
||||
|
||||
draft_vars = [offloaded_var1, offloaded_var2]
|
||||
|
||||
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
|
||||
mock_session = Mock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
|
||||
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
|
||||
):
|
||||
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
|
||||
with patch("services.workflow_draft_variable_service.ThreadPoolExecutor") as mock_executor_cls:
|
||||
mock_executor = Mock()
|
||||
mock_executor_cls.return_value.__enter__.return_value = mock_executor
|
||||
mock_executor.map.return_value = [
|
||||
(("node1", "offloaded_var1"), Mock()),
|
||||
(("node2", "offloaded_var2"), Mock()),
|
||||
]
|
||||
|
||||
# Execute the method
|
||||
result = draft_var_loader.load_variables(selectors)
|
||||
|
||||
# Verify results - since we have only offloaded variables, should have 2 results
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify ThreadPoolExecutor was used
|
||||
mock_executor_cls.assert_called_once_with(max_workers=10)
|
||||
mock_executor.map.assert_called_once()
|
||||
@@ -0,0 +1,432 @@
|
||||
# test for api/services/workflow/workflow_converter.py
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
AdvancedCompletionPromptTemplateEntity,
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
ExternalDataVariableEntity,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
VariableEntityType,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from models.model import AppMode
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_variables():
|
||||
value = [
|
||||
VariableEntity(
|
||||
variable="text_input",
|
||||
label="text-input",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
),
|
||||
VariableEntity(
|
||||
variable="paragraph",
|
||||
label="paragraph",
|
||||
type=VariableEntityType.PARAGRAPH,
|
||||
),
|
||||
VariableEntity(
|
||||
variable="select",
|
||||
label="select",
|
||||
type=VariableEntityType.SELECT,
|
||||
),
|
||||
]
|
||||
return value
|
||||
|
||||
|
||||
def test__convert_to_start_node(default_variables):
|
||||
# act
|
||||
result = WorkflowConverter()._convert_to_start_node(default_variables)
|
||||
|
||||
# assert
|
||||
assert isinstance(result["data"]["variables"][0]["type"], str)
|
||||
assert result["data"]["variables"][0]["type"] == "text-input"
|
||||
assert result["data"]["variables"][0]["variable"] == "text_input"
|
||||
assert result["data"]["variables"][1]["variable"] == "paragraph"
|
||||
assert result["data"]["variables"][2]["variable"] == "select"
|
||||
|
||||
|
||||
def test__convert_to_http_request_node_for_chatbot(default_variables):
|
||||
"""
|
||||
Test convert to http request nodes for chatbot
|
||||
:return:
|
||||
"""
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.CHAT
|
||||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
tenant_id="tenant_id",
|
||||
name="api-1",
|
||||
api_key="encrypted_api_key",
|
||||
api_endpoint="https://dify.ai",
|
||||
)
|
||||
|
||||
mock_api_based_extension.id = api_based_extension_id
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
|
||||
|
||||
encrypter.decrypt_token = MagicMock(return_value="api_key")
|
||||
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(
|
||||
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
|
||||
)
|
||||
]
|
||||
|
||||
nodes, _ = workflow_converter._convert_to_http_request_node(
|
||||
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
|
||||
)
|
||||
|
||||
assert len(nodes) == 2
|
||||
assert nodes[0]["data"]["type"] == "http-request"
|
||||
|
||||
http_request_node = nodes[0]
|
||||
|
||||
assert http_request_node["data"]["method"] == "post"
|
||||
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
|
||||
assert http_request_node["data"]["authorization"]["type"] == "api-key"
|
||||
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
|
||||
assert http_request_node["data"]["body"]["type"] == "json"
|
||||
|
||||
body_data = http_request_node["data"]["body"]["data"]
|
||||
|
||||
assert body_data
|
||||
|
||||
body_data_json = json.loads(body_data)
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
|
||||
|
||||
body_params = body_data_json["params"]
|
||||
assert body_params["app_id"] == app_model.id
|
||||
assert body_params["tool_variable"] == external_data_variables[0].variable
|
||||
assert len(body_params["inputs"]) == 3
|
||||
assert body_params["query"] == "{{#sys.query#}}" # for chatbot
|
||||
|
||||
code_node = nodes[1]
|
||||
assert code_node["data"]["type"] == "code"
|
||||
|
||||
|
||||
def test__convert_to_http_request_node_for_workflow_app(default_variables):
|
||||
"""
|
||||
Test convert to http request nodes for workflow app
|
||||
:return:
|
||||
"""
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
tenant_id="tenant_id",
|
||||
name="api-1",
|
||||
api_key="encrypted_api_key",
|
||||
api_endpoint="https://dify.ai",
|
||||
)
|
||||
mock_api_based_extension.id = api_based_extension_id
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
|
||||
|
||||
encrypter.decrypt_token = MagicMock(return_value="api_key")
|
||||
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(
|
||||
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
|
||||
)
|
||||
]
|
||||
|
||||
nodes, _ = workflow_converter._convert_to_http_request_node(
|
||||
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
|
||||
)
|
||||
|
||||
assert len(nodes) == 2
|
||||
assert nodes[0]["data"]["type"] == "http-request"
|
||||
|
||||
http_request_node = nodes[0]
|
||||
|
||||
assert http_request_node["data"]["method"] == "post"
|
||||
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
|
||||
assert http_request_node["data"]["authorization"]["type"] == "api-key"
|
||||
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
|
||||
assert http_request_node["data"]["body"]["type"] == "json"
|
||||
|
||||
body_data = http_request_node["data"]["body"]["data"]
|
||||
|
||||
assert body_data
|
||||
|
||||
body_data_json = json.loads(body_data)
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
|
||||
|
||||
body_params = body_data_json["params"]
|
||||
assert body_params["app_id"] == app_model.id
|
||||
assert body_params["tool_variable"] == external_data_variables[0].variable
|
||||
assert len(body_params["inputs"]) == 3
|
||||
assert body_params["query"] == ""
|
||||
|
||||
code_node = nodes[1]
|
||||
assert code_node["data"]["type"] == "code"
|
||||
|
||||
|
||||
def test__convert_to_knowledge_retrieval_node_for_chatbot():
|
||||
new_app_mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
dataset_config = DatasetEntity(
|
||||
dataset_ids=["dataset_id_1", "dataset_id_2"],
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=5,
|
||||
score_threshold=0.8,
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
|
||||
|
||||
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
|
||||
)
|
||||
assert node is not None
|
||||
|
||||
assert node["data"]["type"] == "knowledge-retrieval"
|
||||
assert node["data"]["query_variable_selector"] == ["sys", "query"]
|
||||
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
|
||||
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
|
||||
assert node["data"]["multiple_retrieval_config"] == {
|
||||
"top_k": dataset_config.retrieve_config.top_k,
|
||||
"score_threshold": dataset_config.retrieve_config.score_threshold,
|
||||
"reranking_model": dataset_config.retrieve_config.reranking_model,
|
||||
}
|
||||
|
||||
|
||||
def test__convert_to_knowledge_retrieval_node_for_workflow_app():
|
||||
new_app_mode = AppMode.WORKFLOW
|
||||
|
||||
dataset_config = DatasetEntity(
|
||||
dataset_ids=["dataset_id_1", "dataset_id_2"],
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable="query",
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=5,
|
||||
score_threshold=0.8,
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
|
||||
|
||||
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
|
||||
)
|
||||
assert node is not None
|
||||
|
||||
assert node["data"]["type"] == "knowledge-retrieval"
|
||||
assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable]
|
||||
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
|
||||
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
|
||||
assert node["data"]["multiple_retrieval_config"] == {
|
||||
"top_k": dataset_config.retrieve_config.top_k,
|
||||
"score_threshold": dataset_config.retrieve_config.score_threshold,
|
||||
"reranking_model": dataset_config.retrieve_config.reranking_model,
|
||||
}
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
|
||||
new_app_mode = AppMode.ADVANCED_CHAT
|
||||
model = "gpt-4"
|
||||
model_mode = LLMMode.CHAT
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
model_config_mock.stop = []
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
template = prompt_template.simple_prompt_template
|
||||
assert template is not None
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n"
|
||||
assert llm_node["data"]["context"]["enabled"] is False
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables):
|
||||
new_app_mode = AppMode.ADVANCED_CHAT
|
||||
model = "gpt-3.5-turbo-instruct"
|
||||
model_mode = LLMMode.COMPLETION
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
model_config_mock.stop = []
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
template = prompt_template.simple_prompt_template
|
||||
assert template is not None
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"]["text"] == template + "\n"
|
||||
assert llm_node["data"]["context"]["enabled"] is False
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables):
|
||||
new_app_mode = AppMode.ADVANCED_CHAT
|
||||
model = "gpt-4"
|
||||
model_mode = LLMMode.CHAT
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
model_config_mock.stop = []
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
|
||||
messages=[
|
||||
AdvancedChatMessageEntity(
|
||||
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
),
|
||||
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
|
||||
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
assert isinstance(llm_node["data"]["prompt_template"], list)
|
||||
assert prompt_template.advanced_chat_prompt_template is not None
|
||||
assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages)
|
||||
template = prompt_template.advanced_chat_prompt_template.messages[0].text
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"][0]["text"] == template
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables):
|
||||
new_app_mode = AppMode.ADVANCED_CHAT
|
||||
model = "gpt-3.5-turbo-instruct"
|
||||
model_mode = LLMMode.COMPLETION
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
model_config_mock.stop = []
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
|
||||
prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ",
|
||||
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
|
||||
),
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
assert isinstance(llm_node["data"]["prompt_template"], dict)
|
||||
assert prompt_template.advanced_completion_prompt_template is not None
|
||||
template = prompt_template.advanced_completion_prompt_template.prompt
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"]["text"] == template
|
||||
@@ -0,0 +1,127 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_setup():
|
||||
mock_session_maker = MagicMock()
|
||||
workflow_service = WorkflowService(mock_session_maker)
|
||||
session = MagicMock(spec=Session)
|
||||
tenant_id = "test-tenant-id"
|
||||
workflow_id = "test-workflow-id"
|
||||
|
||||
# Mock workflow
|
||||
workflow = MagicMock(spec=Workflow)
|
||||
workflow.id = workflow_id
|
||||
workflow.tenant_id = tenant_id
|
||||
workflow.version = "1.0" # Not a draft
|
||||
workflow.tool_published = False # Not published as a tool by default
|
||||
|
||||
# Mock app
|
||||
app = MagicMock(spec=App)
|
||||
app.id = "test-app-id"
|
||||
app.name = "Test App"
|
||||
app.workflow_id = None # Not used by an app by default
|
||||
|
||||
return {
|
||||
"workflow_service": workflow_service,
|
||||
"session": session,
|
||||
"tenant_id": tenant_id,
|
||||
"workflow_id": workflow_id,
|
||||
"workflow": workflow,
|
||||
"app": app,
|
||||
}
|
||||
|
||||
|
||||
def test_delete_workflow_success(workflow_setup):
|
||||
# Setup mocks
|
||||
|
||||
# Mock the tool provider query to return None (not published as a tool)
|
||||
workflow_setup["session"].query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
workflow_setup["session"].scalar = MagicMock(
|
||||
side_effect=[workflow_setup["workflow"], None]
|
||||
) # Return workflow first, then None for app
|
||||
|
||||
# Call the method
|
||||
result = workflow_setup["workflow_service"].delete_workflow(
|
||||
session=workflow_setup["session"],
|
||||
workflow_id=workflow_setup["workflow_id"],
|
||||
tenant_id=workflow_setup["tenant_id"],
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
workflow_setup["session"].delete.assert_called_once_with(workflow_setup["workflow"])
|
||||
|
||||
|
||||
def test_delete_workflow_draft_error(workflow_setup):
|
||||
# Setup mocks
|
||||
workflow_setup["workflow"].version = "draft"
|
||||
workflow_setup["session"].scalar = MagicMock(return_value=workflow_setup["workflow"])
|
||||
|
||||
# Call the method and verify exception
|
||||
with pytest.raises(DraftWorkflowDeletionError):
|
||||
workflow_setup["workflow_service"].delete_workflow(
|
||||
session=workflow_setup["session"],
|
||||
workflow_id=workflow_setup["workflow_id"],
|
||||
tenant_id=workflow_setup["tenant_id"],
|
||||
)
|
||||
|
||||
# Verify
|
||||
workflow_setup["session"].delete.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_workflow_in_use_by_app_error(workflow_setup):
|
||||
# Setup mocks
|
||||
workflow_setup["app"].workflow_id = workflow_setup["workflow_id"]
|
||||
workflow_setup["session"].scalar = MagicMock(
|
||||
side_effect=[workflow_setup["workflow"], workflow_setup["app"]]
|
||||
) # Return workflow first, then app
|
||||
|
||||
# Call the method and verify exception
|
||||
with pytest.raises(WorkflowInUseError) as excinfo:
|
||||
workflow_setup["workflow_service"].delete_workflow(
|
||||
session=workflow_setup["session"],
|
||||
workflow_id=workflow_setup["workflow_id"],
|
||||
tenant_id=workflow_setup["tenant_id"],
|
||||
)
|
||||
|
||||
# Verify error message contains app name
|
||||
assert "Cannot delete workflow that is currently in use by app" in str(excinfo.value)
|
||||
|
||||
# Verify
|
||||
workflow_setup["session"].delete.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_workflow_published_as_tool_error(workflow_setup):
|
||||
# Setup mocks
|
||||
from models.tools import WorkflowToolProvider
|
||||
|
||||
# Mock the tool provider query
|
||||
mock_tool_provider = MagicMock(spec=WorkflowToolProvider)
|
||||
workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider
|
||||
|
||||
workflow_setup["session"].scalar = MagicMock(
|
||||
side_effect=[workflow_setup["workflow"], None]
|
||||
) # Return workflow first, then None for app
|
||||
|
||||
# Call the method and verify exception
|
||||
with pytest.raises(WorkflowInUseError) as excinfo:
|
||||
workflow_setup["workflow_service"].delete_workflow(
|
||||
session=workflow_setup["session"],
|
||||
workflow_id=workflow_setup["workflow_id"],
|
||||
tenant_id=workflow_setup["tenant_id"],
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Cannot delete workflow that is published as a tool" in str(excinfo.value)
|
||||
|
||||
# Verify
|
||||
workflow_setup["session"].delete.assert_not_called()
|
||||
@@ -0,0 +1,477 @@
|
||||
import dataclasses
|
||||
import secrets
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.segments import StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeType
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.account import Account
|
||||
from models.enums import DraftVariableType
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDraftVariable,
|
||||
WorkflowDraftVariableFile,
|
||||
WorkflowNodeExecutionModel,
|
||||
is_system_variable_editable,
|
||||
)
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
VariableResetError,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine() -> Engine:
|
||||
return Mock(spec=Engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(mock_engine) -> Session:
|
||||
mock_session = Mock(spec=Session)
|
||||
mock_session.get_bind.return_value = mock_engine
|
||||
return mock_session
|
||||
|
||||
|
||||
class TestDraftVariableSaver:
|
||||
def _get_test_app_id(self):
|
||||
suffix = secrets.token_hex(6)
|
||||
return f"test_app_id_{suffix}"
|
||||
|
||||
def test__should_variable_be_visible(self):
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = Account(name="test", email="test@example.com")
|
||||
mock_user.id = str(uuid.uuid4())
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id=test_app_id,
|
||||
node_id="test_node_id",
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="test_execution_id",
|
||||
user=mock_user,
|
||||
)
|
||||
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
|
||||
assert saver._should_variable_be_visible("123", NodeType.START, "output") == True
|
||||
|
||||
def test__normalize_variable_for_start_node(self):
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TestCase:
|
||||
name: str
|
||||
input_node_id: str
|
||||
input_name: str
|
||||
expected_node_id: str
|
||||
expected_name: str
|
||||
|
||||
_NODE_ID = "1747228642872"
|
||||
cases = [
|
||||
TestCase(
|
||||
name="name with `sys.` prefix should return the system node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="sys.workflow_id",
|
||||
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||
expected_name="workflow_id",
|
||||
),
|
||||
TestCase(
|
||||
name="name without `sys.` prefix should return the original input node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="start_input",
|
||||
expected_node_id=_NODE_ID,
|
||||
expected_name="start_input",
|
||||
),
|
||||
TestCase(
|
||||
name="dummy_variable should return the original input node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="__dummy__",
|
||||
expected_node_id=_NODE_ID,
|
||||
expected_name="__dummy__",
|
||||
),
|
||||
]
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = MagicMock()
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id=test_app_id,
|
||||
node_id=_NODE_ID,
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="test_execution_id",
|
||||
user=mock_user,
|
||||
)
|
||||
for idx, c in enumerate(cases, 1):
|
||||
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||
node_id, name = saver._normalize_variable_for_start_node(c.input_name)
|
||||
assert node_id == c.expected_node_id, fail_msg
|
||||
assert name == c.expected_name, fail_msg
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Mock SQLAlchemy session."""
|
||||
from sqlalchemy import Engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_engine = MagicMock(spec=Engine)
|
||||
mock_session.get_bind.return_value = mock_engine
|
||||
return mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def draft_saver(self, mock_session):
|
||||
"""Create DraftVariableSaver instance with user context."""
|
||||
# Create a mock user
|
||||
mock_user = MagicMock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.tenant_id = "test-tenant-id"
|
||||
|
||||
return DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id="test-app-id",
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
node_execution_id="test-execution-id",
|
||||
user=mock_user,
|
||||
)
|
||||
|
||||
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
) as _mock_try_offload:
|
||||
_mock_try_offload.return_value = None
|
||||
mock_segment = StringSegment(value="small value")
|
||||
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
|
||||
|
||||
# Should not have large variable metadata
|
||||
assert draft_var.file_id is None
|
||||
_mock_try_offload.return_value = None
|
||||
|
||||
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
) as _mock_try_offload:
|
||||
mock_segment = StringSegment(value="small value")
|
||||
mock_draft_var_file = WorkflowDraftVariableFile(
|
||||
id=str(uuidv7()),
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.ARRAY_STRING,
|
||||
upload_file_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
_mock_try_offload.return_value = mock_segment, mock_draft_var_file
|
||||
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
|
||||
|
||||
# Should not have large variable metadata
|
||||
assert draft_var.file_id == mock_draft_var_file.id
|
||||
|
||||
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
|
||||
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
|
||||
"""Test complete save workflow."""
|
||||
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
|
||||
|
||||
draft_saver.save(outputs=outputs)
|
||||
|
||||
# Should batch upsert draft variables
|
||||
mock_batch_upsert.assert_called_once()
|
||||
draft_vars = mock_batch_upsert.call_args[0][1]
|
||||
assert len(draft_vars) == 2
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableService:
|
||||
def _get_test_app_id(self):
|
||||
suffix = secrets.token_hex(6)
|
||||
return f"test_app_id_{suffix}"
|
||||
|
||||
def _create_test_workflow(self, app_id: str) -> Workflow:
|
||||
"""Create a real Workflow instance for testing"""
|
||||
return Workflow.new(
|
||||
tenant_id="test_tenant_id",
|
||||
app_id=app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features="{}",
|
||||
created_by="test_user_id",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
|
||||
def test_reset_conversation_variable(self, mock_session):
|
||||
"""Test resetting a conversation variable"""
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real conversation variable
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id, name="test_var", value=test_value, description="Test conversation variable"
|
||||
)
|
||||
|
||||
# Mock the _reset_conv_var method
|
||||
expected_result = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id,
|
||||
name="test_var",
|
||||
value=StringSegment(value="reset_value"),
|
||||
)
|
||||
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
|
||||
result = service.reset_variable(workflow, variable)
|
||||
|
||||
mock_reset_conv.assert_called_once_with(workflow, variable)
|
||||
assert result == expected_result
|
||||
|
||||
def test_reset_node_variable_with_no_execution_id(self, mock_session):
|
||||
"""Test resetting a node variable with no execution ID - should delete variable"""
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real node variable with no execution ID
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id,
|
||||
node_id="test_node_id",
|
||||
name="test_var",
|
||||
value=test_value,
|
||||
node_execution_id="exec-id", # Set initially
|
||||
)
|
||||
# Manually set to None to simulate the test condition
|
||||
variable.node_execution_id = None
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should delete the variable and return None
|
||||
mock_session.delete.assert_called_once_with(instance=variable)
|
||||
mock_session.flush.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
def test_reset_node_variable_with_missing_execution_record(
|
||||
self,
|
||||
mock_engine,
|
||||
mock_session,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Test resetting a node variable when execution record doesn't exist"""
|
||||
mock_repo_session = Mock(spec=Session)
|
||||
|
||||
mock_session_maker = MagicMock()
|
||||
# Mock the context manager protocol for sessionmaker
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
# Mock the repository to return None (no execution record found)
|
||||
service._api_node_execution_repo = Mock()
|
||||
service._api_node_execution_repo.get_execution_by_id.return_value = None
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real node variable with execution ID
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
|
||||
)
|
||||
# Variable is editable by default from factory method
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
|
||||
# Should delete the variable and return None
|
||||
mock_session.delete.assert_called_once_with(instance=variable)
|
||||
mock_session.flush.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
def test_reset_node_variable_with_valid_execution_record(
|
||||
self,
|
||||
mock_session,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Test resetting a node variable with valid execution record - should restore from execution"""
|
||||
mock_repo_session = Mock(spec=Session)
|
||||
|
||||
mock_session_maker = MagicMock()
|
||||
# Mock the context manager protocol for sessionmaker
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
mock_session_maker = monkeypatch.setattr(
|
||||
"services.workflow_draft_variable_service.sessionmaker", mock_session_maker
|
||||
)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.load_full_outputs.return_value = {"test_var": "output_value"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real node variable with execution ID
|
||||
test_value = StringSegment(value="original_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
|
||||
)
|
||||
# Variable is editable by default from factory method
|
||||
|
||||
# Mock workflow methods
|
||||
mock_node_config = {"type": "test_node"}
|
||||
with (
|
||||
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
|
||||
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
|
||||
):
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Verify last_edited_at was reset
|
||||
assert variable.last_edited_at is None
|
||||
# Verify session.flush was called
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
# Should return the updated variable
|
||||
assert result == variable
|
||||
|
||||
def test_reset_non_editable_system_variable_raises_error(self, mock_session):
|
||||
"""Test that resetting a non-editable system variable raises an error"""
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create a non-editable system variable (workflow_id is not editable)
|
||||
test_value = StringSegment(value="test_workflow_id")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="workflow_id", # This is not in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=False, # Non-editable system variable
|
||||
)
|
||||
|
||||
with pytest.raises(VariableResetError) as exc_info:
|
||||
service.reset_variable(workflow, variable)
|
||||
assert "cannot reset system variable" in str(exc_info.value)
|
||||
assert f"variable_id={variable.id}" in str(exc_info.value)
|
||||
|
||||
def test_reset_editable_system_variable_succeeds(self, mock_session):
|
||||
"""Test that resetting an editable system variable succeeds"""
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create an editable system variable (files is editable)
|
||||
test_value = StringSegment(value="[]")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="files", # This is in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=True, # Editable system variable
|
||||
)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.load_full_outputs.return_value = {"sys.files": "[]"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should succeed and return the variable
|
||||
assert result == variable
|
||||
assert variable.last_edited_at is None
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
def test_reset_query_system_variable_succeeds(self, mock_session):
|
||||
"""Test that resetting query system variable (another editable one) succeeds"""
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create an editable system variable (query is editable)
|
||||
test_value = StringSegment(value="original query")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="query", # This is in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=True, # Editable system variable
|
||||
)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.load_full_outputs.return_value = {"sys.query": "reset query"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should succeed and return the variable
|
||||
assert result == variable
|
||||
assert variable.last_edited_at is None
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
def test_system_variable_editability_check(self):
|
||||
"""Test the system variable editability function directly"""
|
||||
# Test editable system variables
|
||||
assert is_system_variable_editable("files") == True
|
||||
assert is_system_variable_editable("query") == True
|
||||
|
||||
# Test non-editable system variables
|
||||
assert is_system_variable_editable("workflow_id") == False
|
||||
assert is_system_variable_editable("conversation_id") == False
|
||||
assert is_system_variable_editable("user_id") == False
|
||||
|
||||
def test_workflow_draft_variable_factory_methods(self):
|
||||
"""Test that factory methods create proper instances"""
|
||||
test_app_id = self._get_test_app_id()
|
||||
test_value = StringSegment(value="test_value")
|
||||
|
||||
# Test conversation variable factory
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id, name="conv_var", value=test_value, description="Test conversation variable"
|
||||
)
|
||||
assert conv_var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
assert conv_var.editable == True
|
||||
assert conv_var.node_execution_id is None
|
||||
|
||||
# Test system variable factory
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id, name="workflow_id", value=test_value, node_execution_id="exec-id", editable=False
|
||||
)
|
||||
assert sys_var.get_variable_type() == DraftVariableType.SYS
|
||||
assert sys_var.editable == False
|
||||
assert sys_var.node_execution_id == "exec-id"
|
||||
|
||||
# Test node variable factory
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id,
|
||||
node_id="node-id",
|
||||
name="node_var",
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
visible=True,
|
||||
editable=True,
|
||||
)
|
||||
assert node_var.get_variable_type() == DraftVariableType.NODE
|
||||
assert node_var.visible == True
|
||||
assert node_var.editable == True
|
||||
assert node_var.node_execution_id == "exec-id"
|
||||
@@ -0,0 +1,288 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
|
||||
|
||||
class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
@pytest.fixture
|
||||
def repository(self):
|
||||
mock_session_maker = MagicMock()
|
||||
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_execution(self):
|
||||
execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
execution.id = str(uuid4())
|
||||
execution.tenant_id = "tenant-123"
|
||||
execution.app_id = "app-456"
|
||||
execution.workflow_id = "workflow-789"
|
||||
execution.workflow_run_id = "run-101"
|
||||
execution.node_id = "node-202"
|
||||
execution.index = 1
|
||||
execution.created_at = "2023-01-01T00:00:00Z"
|
||||
return execution
|
||||
|
||||
def test_get_node_last_execution_found(self, repository, mock_execution):
|
||||
"""Test getting the last execution for a node when it exists."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = mock_execution
|
||||
|
||||
# Act
|
||||
result = repository.get_node_last_execution(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_id="workflow-789",
|
||||
node_id="node-202",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_execution
|
||||
mock_session.scalar.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.scalar.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
def test_get_node_last_execution_not_found(self, repository):
|
||||
"""Test getting the last execution for a node when it doesn't exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = repository.get_node_last_execution(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_id="workflow-789",
|
||||
node_id="node-202",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_executions_by_workflow_run(self, repository, mock_execution):
|
||||
"""Test getting all executions for a workflow run."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
executions = [mock_execution]
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == executions
|
||||
mock_session.execute.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.execute.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
def test_get_executions_by_workflow_run_empty(self, repository):
|
||||
"""Test getting executions for a workflow run when none exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_get_execution_by_id_found(self, repository, mock_execution):
|
||||
"""Test getting execution by ID when it exists."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = mock_execution
|
||||
|
||||
# Act
|
||||
result = repository.get_execution_by_id(mock_execution.id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_execution
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_execution_by_id_not_found(self, repository):
|
||||
"""Test getting execution by ID when it doesn't exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = repository.get_execution_by_id("non-existent-id")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_repository_implements_protocol(self, repository):
|
||||
"""Test that the repository implements the required protocol methods."""
|
||||
# Verify all protocol methods are implemented
|
||||
assert hasattr(repository, "get_node_last_execution")
|
||||
assert hasattr(repository, "get_executions_by_workflow_run")
|
||||
assert hasattr(repository, "get_execution_by_id")
|
||||
|
||||
# Verify methods are callable
|
||||
assert callable(repository.get_node_last_execution)
|
||||
assert callable(repository.get_executions_by_workflow_run)
|
||||
assert callable(repository.get_execution_by_id)
|
||||
assert callable(repository.delete_expired_executions)
|
||||
assert callable(repository.delete_executions_by_app)
|
||||
assert callable(repository.get_expired_executions_batch)
|
||||
assert callable(repository.delete_executions_by_ids)
|
||||
|
||||
def test_delete_expired_executions(self, repository):
|
||||
"""Test deleting expired executions."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the select query to return some IDs first time, then empty to stop loop
|
||||
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
|
||||
|
||||
# Mock execute method to handle both select and delete statements
|
||||
def mock_execute(stmt):
|
||||
mock_result = MagicMock()
|
||||
# For select statements, return execution IDs
|
||||
if hasattr(stmt, "limit"): # This is our select statement
|
||||
mock_result.scalars.return_value.all.return_value = execution_ids
|
||||
else: # This is our delete statement
|
||||
mock_result.rowcount = 2
|
||||
return mock_result
|
||||
|
||||
mock_session.execute.side_effect = mock_execute
|
||||
|
||||
before_date = datetime(2023, 1, 1)
|
||||
|
||||
# Act
|
||||
result = repository.delete_expired_executions(
|
||||
tenant_id="tenant-123",
|
||||
before_date=before_date,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 2
|
||||
assert mock_session.execute.call_count == 2 # One select call, one delete call
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_app(self, repository):
|
||||
"""Test deleting executions by app."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the select query to return some IDs first time, then empty to stop loop
|
||||
execution_ids = ["id1", "id2"]
|
||||
|
||||
# Mock execute method to handle both select and delete statements
|
||||
def mock_execute(stmt):
|
||||
mock_result = MagicMock()
|
||||
# For select statements, return execution IDs
|
||||
if hasattr(stmt, "limit"): # This is our select statement
|
||||
mock_result.scalars.return_value.all.return_value = execution_ids
|
||||
else: # This is our delete statement
|
||||
mock_result.rowcount = 2
|
||||
return mock_result
|
||||
|
||||
mock_session.execute.side_effect = mock_execute
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_app(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 2
|
||||
assert mock_session.execute.call_count == 2 # One select call, one delete call
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_get_expired_executions_batch(self, repository):
|
||||
"""Test getting expired executions batch for backup."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Create mock execution objects
|
||||
mock_execution1 = MagicMock()
|
||||
mock_execution1.id = "exec-1"
|
||||
mock_execution2 = MagicMock()
|
||||
mock_execution2.id = "exec-2"
|
||||
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
|
||||
|
||||
before_date = datetime(2023, 1, 1)
|
||||
|
||||
# Act
|
||||
result = repository.get_expired_executions_batch(
|
||||
tenant_id="tenant-123",
|
||||
before_date=before_date,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].id == "exec-1"
|
||||
assert result[1].id == "exec-2"
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_ids(self, repository):
|
||||
"""Test deleting executions by IDs."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the delete query result
|
||||
mock_result = MagicMock()
|
||||
mock_result.rowcount = 3
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
execution_ids = ["id1", "id2", "id3"]
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_ids(execution_ids)
|
||||
|
||||
# Assert
|
||||
assert result == 3
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_ids_empty_list(self, repository):
|
||||
"""Test deleting executions with empty ID list."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_ids([])
|
||||
|
||||
# Assert
|
||||
assert result == 0
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
@@ -0,0 +1,163 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class TestWorkflowService:
|
||||
@pytest.fixture
|
||||
def workflow_service(self):
|
||||
mock_session_maker = MagicMock()
|
||||
return WorkflowService(mock_session_maker)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
app = MagicMock(spec=App)
|
||||
app.id = "app-id-1"
|
||||
app.workflow_id = "workflow-id-1"
|
||||
app.tenant_id = "tenant-id-1"
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflows(self):
|
||||
workflows = []
|
||||
for i in range(5):
|
||||
workflow = MagicMock(spec=Workflow)
|
||||
workflow.id = f"workflow-id-{i}"
|
||||
workflow.app_id = "app-id-1"
|
||||
workflow.created_at = f"2023-01-0{5 - i}" # Descending date order
|
||||
workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2"
|
||||
workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else ""
|
||||
workflows.append(workflow)
|
||||
return workflows
|
||||
|
||||
def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
|
||||
mock_app.workflow_id = None
|
||||
mock_session = MagicMock()
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
|
||||
)
|
||||
|
||||
assert workflows == []
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_not_called()
|
||||
|
||||
def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
mock_scalar_result.all.return_value = mock_workflows[:3]
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
|
||||
)
|
||||
|
||||
assert workflows == mock_workflows[:3]
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Return 4 items when limit is 3, which should indicate has_more=True
|
||||
mock_scalar_result.all.return_value = mock_workflows[:4]
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
|
||||
)
|
||||
|
||||
# Should return only the first 3 items
|
||||
assert len(workflows) == 3
|
||||
assert workflows == mock_workflows[:3]
|
||||
assert has_more is True
|
||||
|
||||
# Test page 2
|
||||
mock_scalar_result.all.return_value = mock_workflows[3:]
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None
|
||||
)
|
||||
|
||||
assert len(workflows) == 2
|
||||
assert has_more is False
|
||||
|
||||
def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Filter workflows for user-id-1
|
||||
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"]
|
||||
mock_scalar_result.all.return_value = filtered_workflows
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1"
|
||||
)
|
||||
|
||||
assert workflows == filtered_workflows
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
# Verify that the select contains a user filter clause
|
||||
args = mock_session.scalars.call_args[0][0]
|
||||
assert "created_by" in str(args)
|
||||
|
||||
def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Filter workflows that have a marked_name
|
||||
named_workflows = [w for w in mock_workflows if w.marked_name]
|
||||
mock_scalar_result.all.return_value = named_workflows
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True
|
||||
)
|
||||
|
||||
assert workflows == named_workflows
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
# Verify that the select contains a named_only filter clause
|
||||
args = mock_session.scalars.call_args[0][0]
|
||||
assert "marked_name !=" in str(args)
|
||||
|
||||
def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Combined filter: user-id-1 and has marked_name
|
||||
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name]
|
||||
mock_scalar_result.all.return_value = filtered_workflows
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True
|
||||
)
|
||||
|
||||
assert workflows == filtered_workflows
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
# Verify that both filters are applied
|
||||
args = mock_session.scalars.call_args[0][0]
|
||||
assert "created_by" in str(args)
|
||||
assert "marked_name !=" in str(args)
|
||||
|
||||
def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
mock_scalar_result.all.return_value = []
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
|
||||
)
|
||||
|
||||
assert workflows == []
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
Reference in New Issue
Block a user