This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,70 @@
import json
from collections.abc import Generator
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta
def mock_llm_response(text: str) -> Generator[LLMResultChunk, None, None]:
    """Simulate a streamed LLM reply by yielding one LLMResultChunk per character of *text*."""
    # Iterate the string directly instead of indexing via range(len(...)).
    for char in text:
        yield LLMResultChunk(
            model="model",
            prompt_messages=[],
            delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=char, tool_calls=[])),
        )
def test_cot_output_parser():
    """Feed ReAct-style LLM output through CotAgentOutputParser and check actions/text."""
    cases = [
        {
            "input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```',
            "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
            "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
        },
        # code block with json
        {
            "input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {'
            '}\n```"}```',
            "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
            "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
        },
        # code block with JSON
        {
            "input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {'
            '}\n```"}```',
            "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
            "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
        },
        # list
        {
            "input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```',
            "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
            "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
        },
        # no code block
        {
            "input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}',
            "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
            "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
        },
        # no code block and json
        {"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"},
    ]

    parser = CotAgentOutputParser()
    usage_dict = {}
    for case in cases:
        # Present the input as a simulated per-character token stream.
        stream: Generator[LLMResultChunk, None, None] = mock_llm_response(case["input"])
        pieces = []
        for chunk in parser.handle_react_stream_output(stream, usage_dict):
            if isinstance(chunk, AgentScratchpadUnit.Action):
                action_dict = chunk.to_dict()
                if case["action"]:
                    assert action_dict == case["action"]
                pieces.append(json.dumps(action_dict))
            elif isinstance(chunk, str):
                pieces.append(chunk)
        if case["output"]:
            assert "".join(pieces) == case["output"]

View File

@@ -0,0 +1,65 @@
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file.models import FileTransferMethod, FileUploadConfig, ImageConfig
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
def test_convert_with_vision():
    """convert() with is_vision=True should carry the image detail setting through."""
    file_upload_settings = {
        "enabled": True,
        "number_limits": 5,
        "allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
        "image": {"detail": "high"},
    }
    result = FileUploadConfigManager.convert({"file_upload": file_upload_settings}, is_vision=True)
    assert result == FileUploadConfig(
        image_config=ImageConfig(
            number_limits=5,
            transfer_methods=[FileTransferMethod.REMOTE_URL],
            detail=ImagePromptMessageContent.DETAIL.HIGH,
        ),
        allowed_file_upload_methods=[FileTransferMethod.REMOTE_URL],
        number_limits=5,
    )
def test_convert_without_vision():
    """convert() with is_vision=False should still build an image config, without detail."""
    file_upload_settings = {
        "enabled": True,
        "number_limits": 5,
        "allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
    }
    result = FileUploadConfigManager.convert({"file_upload": file_upload_settings}, is_vision=False)
    assert result == FileUploadConfig(
        image_config=ImageConfig(number_limits=5, transfer_methods=[FileTransferMethod.REMOTE_URL]),
        allowed_file_upload_methods=[FileTransferMethod.REMOTE_URL],
        number_limits=5,
    )
def test_validate_and_set_defaults():
    """An empty config should gain a default file_upload section."""
    result, related_keys = FileUploadConfigManager.validate_and_set_defaults({})
    assert "file_upload" in result
    assert related_keys == ["file_upload"]
def test_validate_and_set_defaults_with_existing_config():
    """A pre-populated file_upload section should survive validation intact."""
    file_upload_settings = {
        "enabled": True,
        "number_limits": 5,
        "allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
    }
    result, related_keys = FileUploadConfigManager.validate_and_set_defaults({"file_upload": file_upload_settings})
    assert "file_upload" in result
    assert related_keys == ["file_upload"]
    section = result["file_upload"]
    assert section["enabled"] is True
    assert section["number_limits"] == 5
    assert section["allowed_file_upload_methods"] == [FileTransferMethod.REMOTE_URL]

View File

@@ -0,0 +1,443 @@
"""Test conversation variable handling in AdvancedChatAppRunner."""
from unittest.mock import MagicMock, patch
from uuid import uuid4
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.variables import SegmentType
from factories import variable_factory
from models import ConversationVariable, Workflow
class TestAdvancedChatAppRunnerConversationVariables:
    """Test that AdvancedChatAppRunner correctly handles conversation variables.

    The three tests share ~160 lines of identical mock scaffolding (workflow
    variables, mock workflow, runner wiring, mock DB session), which is
    extracted into private helpers below so each test only states what differs:
    which variables already exist in the database.
    """

    @staticmethod
    def _build_workflow_vars(specs):
        """Build string conversation variables from (id, name, default value) triples."""
        return [
            variable_factory.build_conversation_variable_from_mapping(
                {
                    "id": var_id,
                    "name": name,
                    "value_type": SegmentType.STRING,
                    "value": value,
                }
            )
            for var_id, name, value in specs
        ]

    @staticmethod
    def _build_mock_workflow(workflow_vars, app_id, workflow_id):
        """Create a mock Workflow exposing the given conversation variables."""
        mock_workflow = MagicMock(spec=Workflow)
        mock_workflow.conversation_variables = workflow_vars
        mock_workflow.tenant_id = str(uuid4())
        mock_workflow.app_id = app_id
        mock_workflow.id = workflow_id
        mock_workflow.type = "chat"
        mock_workflow.graph_dict = {}
        mock_workflow.environment_variables = []
        return mock_workflow

    @staticmethod
    def _build_runner(mock_workflow, app_id, conversation_id, workflow_id):
        """Wire an AdvancedChatAppRunner to fully mocked collaborators."""
        # Mock conversation and message
        mock_conversation = MagicMock()
        mock_conversation.app_id = app_id
        mock_conversation.id = conversation_id
        mock_message = MagicMock()
        mock_message.id = str(uuid4())

        # Mock app config
        mock_app_config = MagicMock()
        mock_app_config.app_id = app_id
        mock_app_config.workflow_id = workflow_id
        mock_app_config.tenant_id = str(uuid4())

        # Mock app generate entity
        mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
        mock_app_generate_entity.app_config = mock_app_config
        mock_app_generate_entity.inputs = {}
        mock_app_generate_entity.query = "test query"
        mock_app_generate_entity.files = []
        mock_app_generate_entity.user_id = str(uuid4())
        mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
        mock_app_generate_entity.workflow_run_id = str(uuid4())
        mock_app_generate_entity.task_id = str(uuid4())
        mock_app_generate_entity.call_depth = 0
        mock_app_generate_entity.single_iteration_run = None
        mock_app_generate_entity.single_loop_run = None
        mock_app_generate_entity.trace_manager = None

        return AdvancedChatAppRunner(
            application_generate_entity=mock_app_generate_entity,
            queue_manager=MagicMock(),
            conversation=mock_conversation,
            message=mock_message,
            dialogue_count=1,
            variable_loader=MagicMock(),
            workflow=mock_workflow,
            system_user_id=str(uuid4()),
            app=MagicMock(),
            workflow_execution_repository=MagicMock(),
            workflow_node_execution_repository=MagicMock(),
        )

    @staticmethod
    def _build_session(existing_vars):
        """Create a mock Session whose conversation-variable query returns *existing_vars*."""
        mock_session = MagicMock(spec=Session)
        mock_scalars_result = MagicMock()
        mock_scalars_result.all.return_value = list(existing_vars)
        mock_session.scalars.return_value = mock_scalars_result
        return mock_session

    @staticmethod
    def _capture_add_all(mock_session):
        """Record every item passed to session.add_all; returns the accumulator list."""
        added_items = []
        mock_session.add_all.side_effect = lambda items: added_items.extend(items)
        return added_items

    def test_missing_conversation_variables_are_added(self):
        """Test that new conversation variables added to workflow are created for existing conversations."""
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_id = str(uuid4())

        workflow_vars = self._build_workflow_vars(
            [("var1", "existing_var", "default1"), ("var2", "new_var", "default2")]
        )
        mock_workflow = self._build_mock_workflow(workflow_vars, app_id, workflow_id)

        # Only var1 already exists in the database; var2 is missing.
        existing_db_var = MagicMock(spec=ConversationVariable)
        existing_db_var.id = "var1"
        existing_db_var.app_id = app_id
        existing_db_var.conversation_id = conversation_id
        existing_db_var.to_variable = MagicMock(return_value=workflow_vars[0])

        runner = self._build_runner(mock_workflow, app_id, conversation_id, workflow_id)
        mock_session = self._build_session([existing_db_var])
        added_items = self._capture_add_all(mock_session)

        with (
            patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
            patch("core.app.apps.advanced_chat.app_runner.select"),
            patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
            patch.object(runner, "_init_graph") as mock_init_graph,
            patch.object(runner, "handle_input_moderation", return_value=False),
            patch.object(runner, "handle_annotation_reply", return_value=False),
            patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
            patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
            patch("core.app.apps.advanced_chat.app_runner.redis_client"),
            patch("core.app.apps.advanced_chat.app_runner.RedisChannel"),
        ):
            mock_session_class.return_value.__enter__.return_value = mock_session
            mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock()  # App exists
            mock_db.engine = MagicMock()
            mock_graph_runtime_state_class.return_value = MagicMock()
            mock_init_graph.return_value = MagicMock()
            mock_workflow_entry = MagicMock()
            mock_workflow_entry.run.return_value = iter([])  # Empty generator
            mock_workflow_entry_class.return_value = mock_workflow_entry

            runner.run()

            # Only the missing variable (var2) should have been persisted.
            assert len(added_items) == 1, "Should have added exactly one missing variable"
            assert hasattr(added_items[0], "id"), "Added item should be a ConversationVariable"
            assert mock_session.add_all.called, "Session add_all should have been called"
            assert mock_session.commit.called, "Session commit should have been called"

    def test_no_variables_creates_all(self):
        """Test that all conversation variables are created when none exist in DB."""
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_id = str(uuid4())

        workflow_vars = self._build_workflow_vars([("var1", "var1", "default1"), ("var2", "var2", "default2")])
        mock_workflow = self._build_mock_workflow(workflow_vars, app_id, workflow_id)
        runner = self._build_runner(mock_workflow, app_id, conversation_id, workflow_id)
        mock_session = self._build_session([])  # no existing variables in DB
        added_items = self._capture_add_all(mock_session)

        with (
            patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
            patch("core.app.apps.advanced_chat.app_runner.select"),
            patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
            patch.object(runner, "_init_graph") as mock_init_graph,
            patch.object(runner, "handle_input_moderation", return_value=False),
            patch.object(runner, "handle_annotation_reply", return_value=False),
            patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
            patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
            patch("core.app.apps.advanced_chat.app_runner.ConversationVariable") as mock_conv_var_class,
            patch("core.app.apps.advanced_chat.app_runner.redis_client"),
            patch("core.app.apps.advanced_chat.app_runner.RedisChannel"),
        ):
            mock_session_class.return_value.__enter__.return_value = mock_session
            mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock()  # App exists
            mock_db.engine = MagicMock()

            # ConversationVariable.from_variable returns one mock per workflow variable.
            mock_conv_vars = []
            for var in workflow_vars:
                mock_cv = MagicMock()
                mock_cv.id = var.id
                mock_cv.to_variable.return_value = var
                mock_conv_vars.append(mock_cv)
            mock_conv_var_class.from_variable.side_effect = mock_conv_vars

            mock_graph_runtime_state_class.return_value = MagicMock()
            mock_init_graph.return_value = MagicMock()
            mock_workflow_entry = MagicMock()
            mock_workflow_entry.run.return_value = iter([])  # Empty generator
            mock_workflow_entry_class.return_value = mock_workflow_entry

            runner.run()

            assert len(added_items) == 2, "Should have added both variables"
            assert mock_session.add_all.called, "Session add_all should have been called"
            assert mock_session.commit.called, "Session commit should have been called"

    def test_all_variables_exist_no_changes(self):
        """Test that no changes are made when all variables already exist in DB."""
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_id = str(uuid4())

        workflow_vars = self._build_workflow_vars([("var1", "var1", "default1"), ("var2", "var2", "default2")])
        mock_workflow = self._build_mock_workflow(workflow_vars, app_id, workflow_id)

        # Both variables already exist in the database.
        existing_db_vars = []
        for var in workflow_vars:
            db_var = MagicMock(spec=ConversationVariable)
            db_var.id = var.id
            db_var.app_id = app_id
            db_var.conversation_id = conversation_id
            db_var.to_variable = MagicMock(return_value=var)
            existing_db_vars.append(db_var)

        runner = self._build_runner(mock_workflow, app_id, conversation_id, workflow_id)
        mock_session = self._build_session(existing_db_vars)

        with (
            patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
            patch("core.app.apps.advanced_chat.app_runner.select"),
            patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
            patch.object(runner, "_init_graph") as mock_init_graph,
            patch.object(runner, "handle_input_moderation", return_value=False),
            patch.object(runner, "handle_annotation_reply", return_value=False),
            patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
            patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
            patch("core.app.apps.advanced_chat.app_runner.redis_client"),
            patch("core.app.apps.advanced_chat.app_runner.RedisChannel"),
        ):
            mock_session_class.return_value.__enter__.return_value = mock_session
            mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock()  # App exists
            mock_db.engine = MagicMock()
            mock_graph_runtime_state_class.return_value = MagicMock()
            mock_init_graph.return_value = MagicMock()
            mock_workflow_entry = MagicMock()
            mock_workflow_entry.run.return_value = iter([])  # Empty generator
            mock_workflow_entry_class.return_value = mock_workflow_entry

            runner.run()

            # Nothing new to persist.
            assert not mock_session.add_all.called, "Session add_all should not have been called"
            assert mock_session.commit.called, "Session commit should still be called"

View File

@@ -0,0 +1,63 @@
from types import SimpleNamespace
import pytest
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.workflow.runtime import GraphRuntimeState
from core.workflow.runtime.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
def _make_state(workflow_run_id: str | None) -> GraphRuntimeState:
    """Build a GraphRuntimeState whose system variables carry the given workflow execution id."""
    system_vars = SystemVariable(workflow_execution_id=workflow_run_id)
    pool = VariablePool(system_variables=system_vars)
    return GraphRuntimeState(variable_pool=pool, start_at=0.0)
class _StubPipeline(GraphRuntimeStateSupport):
    """Minimal GraphRuntimeStateSupport host exposing a cached state and a queue-held state."""

    def __init__(self, *, cached_state: GraphRuntimeState | None, queue_state: GraphRuntimeState | None):
        queue_manager = SimpleNamespace(graph_runtime_state=queue_state)
        self._graph_runtime_state = cached_state
        self._base_task_pipeline = SimpleNamespace(queue_manager=queue_manager)
def test_ensure_graph_runtime_initialized_caches_explicit_state():
    """An explicitly supplied state is both returned and cached on the pipeline."""
    state = _make_state("run-explicit")
    stub = _StubPipeline(cached_state=None, queue_state=None)
    assert stub._ensure_graph_runtime_initialized(state) is state
    assert stub._graph_runtime_state is state
def test_resolve_graph_runtime_state_reads_from_queue_when_cache_empty():
    """With no cached state, resolution falls back to the queue manager's state."""
    state_in_queue = _make_state("run-queue")
    stub = _StubPipeline(cached_state=None, queue_state=state_in_queue)
    assert stub._resolve_graph_runtime_state() is state_in_queue
    assert stub._graph_runtime_state is state_in_queue
def test_resolve_graph_runtime_state_raises_when_no_state_available():
    """Resolution fails loudly when neither the cache nor the queue holds a state."""
    stub = _StubPipeline(cached_state=None, queue_state=None)
    with pytest.raises(ValueError):
        stub._resolve_graph_runtime_state()
def test_extract_workflow_run_id_returns_value():
    """The workflow run id stored in the state's system variables is returned."""
    state = _make_state("run-identifier")
    stub = _StubPipeline(cached_state=state, queue_state=None)
    assert stub._extract_workflow_run_id(state) == "run-identifier"
def test_extract_workflow_run_id_raises_when_missing():
    """A state without a workflow execution id raises ValueError."""
    state = _make_state(None)
    stub = _StubPipeline(cached_state=state, queue_state=None)
    with pytest.raises(ValueError):
        stub._extract_workflow_run_id(state)

View File

@@ -0,0 +1,259 @@
from collections.abc import Mapping, Sequence
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.variables.segments import ArrayFileSegment, FileSegment
class TestWorkflowResponseConverterFetchFilesFromVariableValue:
    """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method"""

    def create_test_file(self, file_id: str = "test_file_1") -> File:
        """Create a test File object"""
        return File(
            id=file_id,
            tenant_id="test_tenant",
            type=FileType.DOCUMENT,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="related_123",
            filename=f"{file_id}.txt",
            extension=".txt",
            mime_type="text/plain",
            size=1024,
            storage_key="storage_key_123",
        )

    def create_file_dict(self, file_id: str = "test_file_dict"):
        """Create a file dictionary with correct dify_model_identity"""
        return {
            "dify_model_identity": FILE_MODEL_IDENTITY,
            "id": file_id,
            "tenant_id": "test_tenant",
            "type": "document",
            "transfer_method": "local_file",
            "related_id": "related_456",
            "filename": f"{file_id}.txt",
            "extension": ".txt",
            "mime_type": "text/plain",
            "size": 2048,
            "url": "http://example.com/file.txt",
        }

    def test_fetch_files_from_variable_value_with_none(self):
        """Test with None input"""
        # NOTE(review): the annotated signature does not include None, but the
        # implementation appears to tolerate it and return an empty list.
        result = WorkflowResponseConverter._fetch_files_from_variable_value(None)
        assert result == []

    def test_fetch_files_from_variable_value_with_empty_dict(self):
        """Test with empty dictionary"""
        result = WorkflowResponseConverter._fetch_files_from_variable_value({})
        assert result == []

    def test_fetch_files_from_variable_value_with_empty_list(self):
        """Test with empty list"""
        result = WorkflowResponseConverter._fetch_files_from_variable_value([])
        assert result == []

    def test_fetch_files_from_variable_value_with_file_segment(self):
        """Test with valid FileSegment"""
        test_file = self.create_test_file("segment_file")
        file_segment = FileSegment(value=test_file)
        result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
        assert len(result) == 1
        assert isinstance(result[0], dict)
        assert result[0]["id"] == "segment_file"
        assert result[0]["dify_model_identity"] == FILE_MODEL_IDENTITY

    def test_fetch_files_from_variable_value_with_array_file_segment_single(self):
        """Test with ArrayFileSegment containing single file"""
        test_file = self.create_test_file("array_file_1")
        array_segment = ArrayFileSegment(value=[test_file])
        result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
        assert len(result) == 1
        assert isinstance(result[0], dict)
        assert result[0]["id"] == "array_file_1"

    def test_fetch_files_from_variable_value_with_array_file_segment_multiple(self):
        """Test with ArrayFileSegment containing multiple files"""
        test_file_1 = self.create_test_file("array_file_1")
        test_file_2 = self.create_test_file("array_file_2")
        array_segment = ArrayFileSegment(value=[test_file_1, test_file_2])
        result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
        assert len(result) == 2
        assert result[0]["id"] == "array_file_1"
        assert result[1]["id"] == "array_file_2"

    def test_fetch_files_from_variable_value_with_array_file_segment_empty(self):
        """Test with ArrayFileSegment containing empty array"""
        array_segment = ArrayFileSegment(value=[])
        result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
        assert result == []

    def test_fetch_files_from_variable_value_with_list_of_file_dicts(self):
        """Test with list containing file dictionaries"""
        file_dict_1 = self.create_file_dict("list_file_1")
        file_dict_2 = self.create_file_dict("list_file_2")
        test_list = [file_dict_1, file_dict_2]
        result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
        assert len(result) == 2
        assert result[0]["id"] == "list_file_1"
        assert result[1]["id"] == "list_file_2"

    def test_fetch_files_from_variable_value_with_list_of_file_objects(self):
        """Test with list containing File objects"""
        file_obj_1 = self.create_test_file("list_obj_1")
        file_obj_2 = self.create_test_file("list_obj_2")
        test_list = [file_obj_1, file_obj_2]
        result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
        assert len(result) == 2
        assert result[0]["id"] == "list_obj_1"
        assert result[1]["id"] == "list_obj_2"

    def test_fetch_files_from_variable_value_with_list_mixed_valid_invalid(self):
        """Test with list containing mix of valid files and invalid items"""
        file_dict = self.create_file_dict("mixed_file")
        invalid_dict = {"not_a_file": "value"}
        # Non-file items (invalid dicts, strings, numbers) should be skipped.
        test_list = [file_dict, invalid_dict, "string_item", 123]
        result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
        assert len(result) == 1
        assert result[0]["id"] == "mixed_file"

    def test_fetch_files_from_variable_value_with_list_nested_structures(self):
        """Test with list containing nested structures"""
        file_dict = self.create_file_dict("nested_file")
        nested_list = [file_dict, ["inner_list"]]
        test_list = [nested_list, {"nested": "dict"}]
        result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
        # Should not process nested structures in list items
        assert result == []

    def test_fetch_files_from_variable_value_with_dict_incorrect_identity(self):
        """Test with dictionary having incorrect dify_model_identity"""
        invalid_dict = {"dify_model_identity": "wrong_identity", "id": "invalid_file", "filename": "test.txt"}
        result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
        assert result == []

    def test_fetch_files_from_variable_value_with_dict_missing_identity(self):
        """Test with dictionary missing dify_model_identity"""
        invalid_dict = {"id": "no_identity_file", "filename": "test.txt"}
        result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
        assert result == []

    def test_fetch_files_from_variable_value_with_dict_file_object(self):
        """Test with dictionary containing File object"""
        file_obj = self.create_test_file("dict_obj_file")
        test_dict = {"file_key": file_obj}
        result = WorkflowResponseConverter._fetch_files_from_variable_value(test_dict)
        # Should not extract File objects from dict values
        assert result == []

    def test_fetch_files_from_variable_value_with_mixed_data_types(self):
        """Test with various mixed data types"""
        mixed_data = {"string": "text", "number": 42, "boolean": True, "null": None, "dify_model_identity": "wrong"}
        result = WorkflowResponseConverter._fetch_files_from_variable_value(mixed_data)
        assert result == []

    def test_fetch_files_from_variable_value_with_invalid_objects(self):
        """Test with invalid objects that are not supported types"""
        # Test with an invalid dict that doesn't match expected patterns
        invalid_dict = {"custom_key": "custom_value"}
        result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
        assert result == []

    def test_fetch_files_from_variable_value_with_string_input(self):
        """Test with string input (unsupported type)"""
        # Since method expects Union[dict, list, Segment], test with empty list instead
        result = WorkflowResponseConverter._fetch_files_from_variable_value([])
        assert result == []

    def test_fetch_files_from_variable_value_with_number_input(self):
        """Test with number input (unsupported type)"""
        # Test with list containing numbers (should be ignored)
        result = WorkflowResponseConverter._fetch_files_from_variable_value([42, "string", None])
        assert result == []

    def test_fetch_files_from_variable_value_return_type_is_sequence(self):
        """Test that return type is Sequence[Mapping[str, Any]]"""
        file_dict = self.create_file_dict("type_test_file")
        result = WorkflowResponseConverter._fetch_files_from_variable_value(file_dict)
        assert isinstance(result, Sequence)
        assert len(result) == 1
        assert isinstance(result[0], Mapping)
        assert all(isinstance(key, str) for key in result[0])

    def test_fetch_files_from_variable_value_preserves_file_properties(self):
        """Test that all file properties are preserved in the result"""
        original_file = self.create_test_file("property_test")
        file_segment = FileSegment(value=original_file)
        result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
        assert len(result) == 1
        file_dict = result[0]
        assert file_dict["id"] == "property_test"
        assert file_dict["tenant_id"] == "test_tenant"
        assert file_dict["type"] == "document"
        assert file_dict["transfer_method"] == "local_file"
        assert file_dict["filename"] == "property_test.txt"
        assert file_dict["extension"] == ".txt"
        assert file_dict["mime_type"] == "text/plain"
        assert file_dict["size"] == 1024

    def test_fetch_files_from_variable_value_with_complex_nested_scenario(self):
        """Test complex scenario with nested valid and invalid data"""
        file_dict = self.create_file_dict("complex_file")
        file_obj = self.create_test_file("complex_obj")
        # Complex nested structure
        complex_data = [
            file_dict,  # Valid file dict
            file_obj,  # Valid file object
            {  # Invalid dict
                "not_file": "data",
                "nested": {"deep": "value"},
            },
            [  # Nested list (should be ignored)
                self.create_file_dict("nested_file")
            ],
            "string",  # Invalid string
            None,  # None value
            42,  # Invalid number
        ]
        result = WorkflowResponseConverter._fetch_files_from_variable_value(complex_data)
        assert len(result) == 2
        assert result[0]["id"] == "complex_file"
        assert result[1]["id"] == "complex_obj"

View File

@@ -0,0 +1,810 @@
"""
Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality.
"""
import uuid
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
from unittest.mock import Mock
import pytest
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueEvent,
QueueIterationStartEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.enums import NodeType
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models import Account
from models.model import AppMode
class TestWorkflowResponseConverter:
    """Test truncation in WorkflowResponseConverter.

    Exercises how node finish/retry stream responses expose ``process_data``:
    the converter feeds each mapping through its ``_truncator`` and surfaces
    the outcome via ``process_data`` / ``process_data_truncated`` on the
    response payload. The truncator is stubbed per-test so expectations are
    independent of the real truncation thresholds.
    """

    def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity:
        """Create a mock WorkflowAppGenerateEntity."""
        mock_entity = Mock(spec=WorkflowAppGenerateEntity)
        mock_app_config = Mock()
        mock_app_config.tenant_id = "test-tenant-id"
        # WEB_APP satisfies the constructor; tests below replace the
        # truncator directly, so the invoke source rarely matters here.
        mock_entity.invoke_from = InvokeFrom.WEB_APP
        mock_entity.app_config = mock_app_config
        mock_entity.inputs = {}
        return mock_entity

    def create_workflow_response_converter(self) -> WorkflowResponseConverter:
        """Create a WorkflowResponseConverter for testing."""
        mock_entity = self.create_mock_generate_entity()
        mock_user = Mock(spec=Account)
        mock_user.id = "test-user-id"
        mock_user.name = "Test User"
        mock_user.email = "test@example.com"
        system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
        return WorkflowResponseConverter(
            application_generate_entity=mock_entity,
            user=mock_user,
            system_variables=system_variables,
        )

    def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
        """Create a QueueNodeStartedEvent for testing.

        A fresh UUID is generated when ``node_execution_id`` is not supplied.
        """
        return QueueNodeStartedEvent(
            node_execution_id=node_execution_id or str(uuid.uuid4()),
            node_id="test-node-id",
            node_title="Test Node",
            node_type=NodeType.CODE,
            start_at=naive_utc_now(),
            in_iteration_id=None,
            in_loop_id=None,
            provider_type="built-in",
            provider_id="code",
        )

    def create_node_succeeded_event(
        self,
        *,
        node_execution_id: str,
        process_data: Mapping[str, Any] | None = None,
    ) -> QueueNodeSucceededEvent:
        """Create a QueueNodeSucceededEvent for testing.

        ``process_data`` defaults to an empty mapping when omitted.
        """
        return QueueNodeSucceededEvent(
            node_id="test-node-id",
            node_type=NodeType.CODE,
            node_execution_id=node_execution_id,
            start_at=naive_utc_now(),
            in_iteration_id=None,
            in_loop_id=None,
            inputs={},
            process_data=process_data or {},
            outputs={},
            execution_metadata={},
        )

    def create_node_retry_event(
        self,
        *,
        node_execution_id: str,
        process_data: Mapping[str, Any] | None = None,
    ) -> QueueNodeRetryEvent:
        """Create a QueueNodeRetryEvent for testing."""
        return QueueNodeRetryEvent(
            inputs={"data": "inputs"},
            outputs={"data": "outputs"},
            process_data=process_data or {},
            error="oops",
            retry_index=1,
            node_id="test-node-id",
            node_type=NodeType.CODE,
            node_title="test code",
            provider_type="built-in",
            provider_id="code",
            node_execution_id=node_execution_id,
            start_at=naive_utc_now(),
            in_iteration_id=None,
            in_loop_id=None,
        )

    def test_workflow_node_finish_response_uses_truncated_process_data(self):
        """Test that node finish response uses get_response_process_data()."""
        converter = self.create_workflow_response_converter()
        original_data = {"large_field": "x" * 10000, "metadata": "info"}
        truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
        # Workflow and node start events must precede the finish event.
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )
        event = self.create_node_succeeded_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        # Stub truncator: truncate only our process_data, pass everything
        # else through; second tuple element signals "was truncated".
        def fake_truncate(mapping):
            if mapping == dict(original_data):
                return truncated_data, True
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]
        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
        )
        # Response should use truncated data, not original
        assert response is not None
        assert response.data.process_data == truncated_data
        assert response.data.process_data != original_data
        assert response.data.process_data_truncated is True

    def test_workflow_node_finish_response_without_truncation(self):
        """Test node finish response when no truncation is applied."""
        converter = self.create_workflow_response_converter()
        original_data = {"small": "data"}
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )
        event = self.create_node_succeeded_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        # Stub truncator: always pass data through unchanged.
        def fake_truncate(mapping):
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]
        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
        )
        # Response should use original data
        assert response is not None
        assert response.data.process_data == original_data
        assert response.data.process_data_truncated is False

    def test_workflow_node_finish_response_with_none_process_data(self):
        """Test node finish response when process_data is None."""
        converter = self.create_workflow_response_converter()
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )
        event = self.create_node_succeeded_event(
            node_execution_id=start_event.node_execution_id,
            process_data=None,
        )

        def fake_truncate(mapping):
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]
        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
        )
        # Response should normalize missing process_data to an empty mapping
        assert response is not None
        assert response.data.process_data == {}
        assert response.data.process_data_truncated is False

    def test_workflow_node_retry_response_uses_truncated_process_data(self):
        """Test that node retry response uses get_response_process_data()."""
        converter = self.create_workflow_response_converter()
        original_data = {"large_field": "x" * 10000, "metadata": "info"}
        truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )
        event = self.create_node_retry_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        def fake_truncate(mapping):
            if mapping == dict(original_data):
                return truncated_data, True
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]
        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test-task-id",
        )
        # Response should use truncated data, not original
        assert response is not None
        assert response.data.process_data == truncated_data
        assert response.data.process_data != original_data
        assert response.data.process_data_truncated is True

    def test_workflow_node_retry_response_without_truncation(self):
        """Test node retry response when no truncation is applied."""
        converter = self.create_workflow_response_converter()
        original_data = {"small": "data"}
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )
        event = self.create_node_retry_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        def fake_truncate(mapping):
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]
        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test-task-id",
        )
        assert response is not None
        assert response.data.process_data == original_data
        assert response.data.process_data_truncated is False

    def test_iteration_and_loop_nodes_return_none(self):
        """Test that iteration and loop nodes return None (no streaming events)."""
        converter = self.create_workflow_response_converter()
        iteration_event = QueueNodeSucceededEvent(
            node_id="iteration-node",
            node_type=NodeType.ITERATION,
            node_execution_id=str(uuid.uuid4()),
            start_at=naive_utc_now(),
            in_iteration_id=None,
            in_loop_id=None,
            inputs={},
            process_data={},
            outputs={},
            execution_metadata={},
        )
        response = converter.workflow_node_finish_to_stream_response(
            event=iteration_event,
            task_id="test-task-id",
        )
        assert response is None
        # Same event reused with only the node_type swapped to LOOP.
        loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
        response = converter.workflow_node_finish_to_stream_response(
            event=loop_event,
            task_id="test-task-id",
        )
        assert response is None

    def test_finish_without_start_raises(self):
        """Ensure finish responses require a prior workflow start."""
        converter = self.create_workflow_response_converter()
        event = self.create_node_succeeded_event(
            node_execution_id=str(uuid.uuid4()),
            process_data={},
        )
        # No workflow_start_to_stream_response() call -> no run id -> error.
        with pytest.raises(ValueError):
            converter.workflow_node_finish_to_stream_response(
                event=event,
                task_id="test-task-id",
            )
@dataclass
class TestCase:
    """Test case data for table-driven tests.

    The name starts with ``Test`` but this is plain data, not a test class:
    ``__test__ = False`` stops pytest from attempting to collect it (a
    dataclass has an ``__init__``, which would otherwise trigger a
    ``PytestCollectionWarning: cannot collect test class``).
    """

    # Tell pytest not to collect this class.
    __test__ = False

    name: str
    invoke_from: InvokeFrom
    expected_truncation_enabled: bool
    description: str
class TestWorkflowResponseConverterServiceApiTruncation:
    """Test class for Service API truncation functionality in WorkflowResponseConverter.

    Unlike TestWorkflowResponseConverter above, these tests use the real
    truncator selected by the converter and assert on the observable
    behavior: SERVICE_API responses carry full payloads with all
    ``*_truncated`` flags False, every other invoke source truncates.
    """

    def create_test_app_generate_entity(self, invoke_from: InvokeFrom) -> WorkflowAppGenerateEntity:
        """Create a test WorkflowAppGenerateEntity with specified invoke_from."""
        # Create a minimal WorkflowUIBasedAppConfig for testing
        app_config = WorkflowUIBasedAppConfig(
            tenant_id="test_tenant",
            app_id="test_app",
            app_mode=AppMode.WORKFLOW,
            workflow_id="test_workflow_id",
        )
        entity = WorkflowAppGenerateEntity(
            task_id="test_task_id",
            app_id="test_app_id",
            app_config=app_config,
            tenant_id="test_tenant",
            app_mode=AppMode.WORKFLOW,
            invoke_from=invoke_from,
            inputs={"test_input": "test_value"},
            user_id="test_user_id",
            stream=True,
            files=[],
            workflow_execution_id="test_workflow_exec_id",
        )
        return entity

    def create_test_user(self) -> Account:
        """Create a test user account."""
        account = Account(
            name="Test User",
            email="test@example.com",
        )
        # Manually set the ID for testing purposes
        account.id = "test_user_id"
        return account

    def create_test_system_variables(self) -> SystemVariable:
        """Create test system variables."""
        return SystemVariable()

    def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter:
        """Create WorkflowResponseConverter with specified invoke_from."""
        entity = self.create_test_app_generate_entity(invoke_from)
        user = self.create_test_user()
        system_variables = self.create_test_system_variables()
        converter = WorkflowResponseConverter(
            application_generate_entity=entity,
            user=user,
            system_variables=system_variables,
        )
        # ensure `workflow_run_id` is set.
        converter.workflow_start_to_stream_response(
            task_id="test-task-id",
            workflow_run_id="test-workflow-run-id",
            workflow_id="test-workflow-id",
        )
        return converter

    @pytest.mark.parametrize(
        "test_case",
        [
            TestCase(
                name="service_api_truncation_disabled",
                invoke_from=InvokeFrom.SERVICE_API,
                expected_truncation_enabled=False,
                description="Service API calls should have truncation disabled",
            ),
            TestCase(
                name="web_app_truncation_enabled",
                invoke_from=InvokeFrom.WEB_APP,
                expected_truncation_enabled=True,
                description="Web app calls should have truncation enabled",
            ),
            TestCase(
                name="debugger_truncation_enabled",
                invoke_from=InvokeFrom.DEBUGGER,
                expected_truncation_enabled=True,
                description="Debugger calls should have truncation enabled",
            ),
            TestCase(
                name="explore_truncation_enabled",
                invoke_from=InvokeFrom.EXPLORE,
                expected_truncation_enabled=True,
                description="Explore calls should have truncation enabled",
            ),
            TestCase(
                name="published_truncation_enabled",
                invoke_from=InvokeFrom.PUBLISHED,
                expected_truncation_enabled=True,
                description="Published app calls should have truncation enabled",
            ),
        ],
        ids=lambda x: x.name,
    )
    def test_truncator_selection_based_on_invoke_from(self, test_case: TestCase):
        """Test that the correct truncator is selected based on invoke_from."""
        converter = self.create_test_converter(test_case.invoke_from)
        # Test truncation behavior instead of checking private attribute
        # Create a test event with large data
        large_value = {"key": ["x"] * 2000}  # Large data that would be truncated
        event = QueueNodeSucceededEvent(
            node_execution_id="test_node_exec_id",
            node_id="test_node",
            node_type=NodeType.LLM,
            start_at=naive_utc_now(),
            inputs=large_value,
            process_data=large_value,
            outputs=large_value,
            error=None,
            execution_metadata=None,
            in_iteration_id=None,
            in_loop_id=None,
        )
        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test_task",
        )
        # Verify response is not None
        assert response is not None
        # Verify truncation behavior matches expectations
        if test_case.expected_truncation_enabled:
            # Truncation should be enabled for non-service-api calls
            assert response.data.inputs_truncated
            assert response.data.process_data_truncated
            assert response.data.outputs_truncated
        else:
            # SERVICE_API should not truncate
            assert not response.data.inputs_truncated
            assert not response.data.process_data_truncated
            assert not response.data.outputs_truncated

    def test_service_api_truncator_no_op_mapping(self):
        """Test that Service API truncator doesn't truncate variable mappings."""
        converter = self.create_test_converter(InvokeFrom.SERVICE_API)
        # Create a test event with large data
        large_value: dict[str, Any] = {
            "large_string": "x" * 10000,  # Large string
            "large_list": list(range(2000)),  # Large array
            "nested_data": {"deep_nested": {"very_deep": {"value": "x" * 5000}}},
        }
        event = QueueNodeSucceededEvent(
            node_execution_id="test_node_exec_id",
            node_id="test_node",
            node_type=NodeType.LLM,
            start_at=naive_utc_now(),
            inputs=large_value,
            process_data=large_value,
            outputs=large_value,
            error=None,
            execution_metadata=None,
            in_iteration_id=None,
            in_loop_id=None,
        )
        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test_task",
        )
        # Verify response is not None
        # NOTE(review): unlike the sibling tests there is no explicit
        # `assert response is not None` here; the attribute access below
        # would raise AttributeError on None anyway — consider adding it.
        data = response.data
        assert data.inputs == large_value
        assert data.process_data == large_value
        assert data.outputs == large_value
        # Service API should not truncate
        assert data.inputs_truncated is False
        assert data.process_data_truncated is False
        assert data.outputs_truncated is False

    def test_web_app_truncator_works_normally(self):
        """Test that web app truncator still works normally."""
        converter = self.create_test_converter(InvokeFrom.WEB_APP)
        # Create a test event with large data
        large_value = {
            "large_string": "x" * 10000,  # Large string
            "large_list": list(range(2000)),  # Large array
        }
        event = QueueNodeSucceededEvent(
            node_execution_id="test_node_exec_id",
            node_id="test_node",
            node_type=NodeType.LLM,
            start_at=naive_utc_now(),
            inputs=large_value,
            process_data=large_value,
            outputs=large_value,
            error=None,
            execution_metadata=None,
            in_iteration_id=None,
            in_loop_id=None,
        )
        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test_task",
        )
        # Verify response is not None
        assert response is not None
        # Web app should truncate
        data = response.data
        assert data.inputs != large_value
        assert data.process_data != large_value
        assert data.outputs != large_value
        # The exact behavior depends on VariableTruncator implementation
        # Just verify that truncation flags are present
        assert data.inputs_truncated is True
        assert data.process_data_truncated is True
        assert data.outputs_truncated is True

    @staticmethod
    def _create_event_by_type(
        type_: QueueEvent, inputs: Mapping[str, Any], process_data: Mapping[str, Any], outputs: Mapping[str, Any]
    ) -> QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent:
        """Build a node-finish queue event of the requested kind.

        Supports NODE_SUCCEEDED, NODE_FAILED and NODE_EXCEPTION; any other
        value raises. The three event classes share field names, differing
        only in the ``error`` payload.
        """
        if type_ == QueueEvent.NODE_SUCCEEDED:
            return QueueNodeSucceededEvent(
                node_execution_id="test_node_exec_id",
                node_id="test_node",
                node_type=NodeType.LLM,
                start_at=naive_utc_now(),
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
                error=None,
                execution_metadata=None,
                in_iteration_id=None,
                in_loop_id=None,
            )
        elif type_ == QueueEvent.NODE_FAILED:
            return QueueNodeFailedEvent(
                node_execution_id="test_node_exec_id",
                node_id="test_node",
                node_type=NodeType.LLM,
                start_at=naive_utc_now(),
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
                error="oops",
                execution_metadata=None,
                in_iteration_id=None,
                in_loop_id=None,
            )
        elif type_ == QueueEvent.NODE_EXCEPTION:
            return QueueNodeExceptionEvent(
                node_execution_id="test_node_exec_id",
                node_id="test_node",
                node_type=NodeType.LLM,
                start_at=naive_utc_now(),
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
                error="oops",
                execution_metadata=None,
                in_iteration_id=None,
                in_loop_id=None,
            )
        else:
            raise Exception("unknown type.")

    @pytest.mark.parametrize(
        "event_type",
        [
            QueueEvent.NODE_SUCCEEDED,
            QueueEvent.NODE_FAILED,
            QueueEvent.NODE_EXCEPTION,
        ],
    )
    def test_service_api_node_finish_event_no_truncation(self, event_type: QueueEvent):
        """Test that Service API doesn't truncate node finish events."""
        converter = self.create_test_converter(InvokeFrom.SERVICE_API)
        # Create test event with large data
        large_inputs = {"input1": "x" * 5000, "input2": list(range(2000))}
        large_process_data = {"process1": "y" * 5000, "process2": {"nested": ["z"] * 2000}}
        large_outputs = {"output1": "result" * 1000, "output2": list(range(2000))}
        event = TestWorkflowResponseConverterServiceApiTruncation._create_event_by_type(
            event_type, large_inputs, large_process_data, large_outputs
        )
        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test_task",
        )
        # Verify response is not None
        assert response is not None
        # Verify response contains full data (not truncated)
        assert response.data.inputs == large_inputs
        assert response.data.process_data == large_process_data
        assert response.data.outputs == large_outputs
        assert not response.data.inputs_truncated
        assert not response.data.process_data_truncated
        assert not response.data.outputs_truncated

    def test_service_api_node_retry_event_no_truncation(self):
        """Test that Service API doesn't truncate node retry events."""
        converter = self.create_test_converter(InvokeFrom.SERVICE_API)
        # Create test event with large data
        large_inputs = {"retry_input": "x" * 5000}
        large_process_data = {"retry_process": "y" * 5000}
        large_outputs = {"retry_output": "z" * 5000}
        # First, we need to store a snapshot by simulating a start event
        start_event = QueueNodeStartedEvent(
            node_execution_id="test_node_exec_id",
            node_id="test_node",
            node_type=NodeType.LLM,
            node_title="Test Node",
            node_run_index=1,
            start_at=naive_utc_now(),
            in_iteration_id=None,
            in_loop_id=None,
            agent_strategy=None,
            provider_type="plugin",
            provider_id="test/test_plugin",
        )
        converter.workflow_node_start_to_stream_response(event=start_event, task_id="test_task")
        # Now create retry event
        event = QueueNodeRetryEvent(
            node_execution_id="test_node_exec_id",
            node_id="test_node",
            node_type=NodeType.LLM,
            node_title="Test Node",
            node_run_index=1,
            start_at=naive_utc_now(),
            inputs=large_inputs,
            process_data=large_process_data,
            outputs=large_outputs,
            error="Retry error",
            execution_metadata=None,
            in_iteration_id=None,
            in_loop_id=None,
            retry_index=1,
            provider_type="plugin",
            provider_id="test/test_plugin",
        )
        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test_task",
        )
        # Verify response is not None
        assert response is not None
        # Verify response contains full data (not truncated)
        assert response.data.inputs == large_inputs
        assert response.data.process_data == large_process_data
        assert response.data.outputs == large_outputs
        assert not response.data.inputs_truncated
        assert not response.data.process_data_truncated
        assert not response.data.outputs_truncated

    def test_service_api_iteration_events_no_truncation(self):
        """Test that Service API doesn't truncate iteration events."""
        converter = self.create_test_converter(InvokeFrom.SERVICE_API)
        # Test iteration start event
        large_value = {"iteration_input": ["x"] * 2000}
        start_event = QueueIterationStartEvent(
            node_execution_id="test_iter_exec_id",
            node_id="test_iteration",
            node_type=NodeType.ITERATION,
            node_title="Test Iteration",
            node_run_index=0,
            start_at=naive_utc_now(),
            inputs=large_value,
            metadata={},
        )
        response = converter.workflow_iteration_start_to_stream_response(
            task_id="test_task",
            workflow_execution_id="test_workflow_exec_id",
            event=start_event,
        )
        assert response is not None
        assert response.data.inputs == large_value
        assert not response.data.inputs_truncated

    def test_service_api_loop_events_no_truncation(self):
        """Test that Service API doesn't truncate loop events."""
        converter = self.create_test_converter(InvokeFrom.SERVICE_API)
        # Test loop start event
        large_inputs = {"loop_input": ["x"] * 2000}
        start_event = QueueLoopStartEvent(
            node_execution_id="test_loop_exec_id",
            node_id="test_loop",
            node_type=NodeType.LOOP,
            node_title="Test Loop",
            start_at=naive_utc_now(),
            inputs=large_inputs,
            metadata={},
            node_run_index=0,
        )
        response = converter.workflow_loop_start_to_stream_response(
            task_id="test_task",
            workflow_execution_id="test_workflow_exec_id",
            event=start_event,
        )
        assert response is not None
        assert response.data.inputs == large_inputs
        assert not response.data.inputs_truncated

    def test_web_app_node_finish_event_truncation_works(self):
        """Test that web app still truncates node finish events."""
        converter = self.create_test_converter(InvokeFrom.WEB_APP)
        # Create test event with large data that should be truncated
        large_inputs = {"input1": ["x"] * 2000}
        large_process_data = {"process1": ["y"] * 2000}
        large_outputs = {"output1": ["z"] * 2000}
        event = QueueNodeSucceededEvent(
            node_execution_id="test_node_exec_id",
            node_id="test_node",
            node_type=NodeType.LLM,
            start_at=naive_utc_now(),
            inputs=large_inputs,
            process_data=large_process_data,
            outputs=large_outputs,
            error=None,
            execution_metadata=None,
            in_iteration_id=None,
            in_loop_id=None,
        )
        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test_task",
        )
        # Verify response is not None
        assert response is not None
        # Verify response contains truncated data
        # The exact behavior depends on VariableTruncator implementation
        # Just verify truncation flags are set correctly (may or may not be truncated depending on size)
        # At minimum, the truncation mechanism should work
        assert isinstance(response.data.inputs, dict)
        assert response.data.inputs_truncated
        assert isinstance(response.data.process_data, dict)
        assert response.data.process_data_truncated
        assert isinstance(response.data.outputs, dict)
        assert response.data.outputs_truncated

View File

@@ -0,0 +1,267 @@
import pytest
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.base_app_generator import BaseAppGenerator
def test_validate_inputs_with_zero():
    """Zero must be accepted as a valid value for a required NUMBER variable."""
    generator = BaseAppGenerator()
    number_var = VariableEntity(
        variable="test_var",
        label="test_var",
        type=VariableEntityType.NUMBER,
        required=True,
    )

    # Numeric zero passes validation unchanged (falsy values are not
    # confused with "missing").
    assert generator._validate_inputs(variable_entity=number_var, value=0) == 0
    # The string form "0" is coerced to the number 0.
    assert generator._validate_inputs(variable_entity=number_var, value="0") == 0
def test_validate_input_with_none_for_required_variable():
    """A required variable of every type must reject None with a clear error."""
    generator = BaseAppGenerator()
    for entity_type in VariableEntityType:
        required_var = VariableEntity(
            variable="test_var",
            label="test_var",
            type=entity_type,
            required=True,
        )
        # None is never acceptable for a required variable, whatever its type.
        with pytest.raises(ValueError) as exc_info:
            generator._validate_inputs(variable_entity=required_var, value=None)
        assert str(exc_info.value) == "test_var is required in input form"
def test_validate_inputs_with_default_value():
    """Test that default values are used when input is None for optional variables.

    Walks through every variable type that supports a ``default``:
    text/paragraph/select pass the default through verbatim, NUMBER coerces
    string defaults to int/float, CHECKBOX preserves booleans, FILE and
    FILE_LIST pass dict / list-of-dict defaults through, and an explicit
    ``default=None`` yields None. Also verifies that an actual input value
    always takes precedence over the default.
    """
    base_app_generator = BaseAppGenerator()

    # --- string-like defaults -------------------------------------------
    # Test with string default value for TEXT_INPUT
    var_string = VariableEntity(
        variable="test_var",
        label="test_var",
        type=VariableEntityType.TEXT_INPUT,
        required=False,
        default="default_string",
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_string,
        value=None,
    )
    assert result == "default_string"
    # Test with string default value for PARAGRAPH
    var_paragraph = VariableEntity(
        variable="test_paragraph",
        label="test_paragraph",
        type=VariableEntityType.PARAGRAPH,
        required=False,
        default="default paragraph text",
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_paragraph,
        value=None,
    )
    assert result == "default paragraph text"
    # Test with SELECT default value
    var_select = VariableEntity(
        variable="test_select",
        label="test_select",
        type=VariableEntityType.SELECT,
        required=False,
        default="option1",
        options=["option1", "option2", "option3"],
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_select,
        value=None,
    )
    assert result == "option1"

    # --- numeric defaults (incl. string coercion) -----------------------
    # Test with number default value (int)
    var_number_int = VariableEntity(
        variable="test_number_int",
        label="test_number_int",
        type=VariableEntityType.NUMBER,
        required=False,
        default=42,
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_number_int,
        value=None,
    )
    assert result == 42
    # Test with number default value (float)
    var_number_float = VariableEntity(
        variable="test_number_float",
        label="test_number_float",
        type=VariableEntityType.NUMBER,
        required=False,
        default=3.14,
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_number_float,
        value=None,
    )
    assert result == 3.14
    # Test with number default value as string (frontend sends as string)
    var_number_string = VariableEntity(
        variable="test_number_string",
        label="test_number_string",
        type=VariableEntityType.NUMBER,
        required=False,
        default="123",
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_number_string,
        value=None,
    )
    assert result == 123
    assert isinstance(result, int)
    # Test with float number default value as string
    var_number_float_string = VariableEntity(
        variable="test_number_float_string",
        label="test_number_float_string",
        type=VariableEntityType.NUMBER,
        required=False,
        default="45.67",
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_number_float_string,
        value=None,
    )
    assert result == 45.67
    assert isinstance(result, float)

    # --- boolean defaults ------------------------------------------------
    # Test with CHECKBOX default value (bool)
    var_checkbox_true = VariableEntity(
        variable="test_checkbox_true",
        label="test_checkbox_true",
        type=VariableEntityType.CHECKBOX,
        required=False,
        default=True,
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_checkbox_true,
        value=None,
    )
    assert result is True
    var_checkbox_false = VariableEntity(
        variable="test_checkbox_false",
        label="test_checkbox_false",
        type=VariableEntityType.CHECKBOX,
        required=False,
        default=False,
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_checkbox_false,
        value=None,
    )
    assert result is False

    # --- explicit None default ------------------------------------------
    # Test with None as explicit default value
    var_none_default = VariableEntity(
        variable="test_none",
        label="test_none",
        type=VariableEntityType.TEXT_INPUT,
        required=False,
        default=None,
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_none_default,
        value=None,
    )
    assert result is None

    # --- provided values beat defaults ----------------------------------
    # Test that actual input value takes precedence over default
    result = base_app_generator._validate_inputs(
        variable_entity=var_string,
        value="actual_value",
    )
    assert result == "actual_value"
    # Test that actual number input takes precedence over default
    result = base_app_generator._validate_inputs(
        variable_entity=var_number_int,
        value=999,
    )
    assert result == 999

    # --- file defaults ---------------------------------------------------
    # Test with FILE default value (dict format from frontend)
    var_file = VariableEntity(
        variable="test_file",
        label="test_file",
        type=VariableEntityType.FILE,
        required=False,
        default={"id": "file123", "name": "default.pdf"},
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_file,
        value=None,
    )
    assert result == {"id": "file123", "name": "default.pdf"}
    # Test with FILE_LIST default value (list of dicts)
    var_file_list = VariableEntity(
        variable="test_file_list",
        label="test_file_list",
        type=VariableEntityType.FILE_LIST,
        required=False,
        default=[{"id": "file1", "name": "doc1.pdf"}, {"id": "file2", "name": "doc2.pdf"}],
    )
    result = base_app_generator._validate_inputs(
        variable_entity=var_file_list,
        value=None,
    )
    assert result == [{"id": "file1", "name": "doc1.pdf"}, {"id": "file2", "name": "doc2.pdf"}]

View File

@@ -0,0 +1,19 @@
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
def test_should_prepare_user_inputs_defaults_to_true():
    """Without the skip flag, user inputs are prepared (validated)."""
    request_args = {"inputs": {}}
    assert WorkflowAppGenerator()._should_prepare_user_inputs(request_args)
def test_should_prepare_user_inputs_skips_when_flag_truthy():
    """A truthy skip flag disables user-input preparation."""
    request_args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: True}
    assert not WorkflowAppGenerator()._should_prepare_user_inputs(request_args)
def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
    """An explicit False skip flag keeps input preparation enabled."""
    request_args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}
    assert WorkflowAppGenerator()._should_prepare_user_inputs(request_args)

View File

@@ -0,0 +1,124 @@
import time
from unittest.mock import MagicMock, patch
import pytest
from core.app.features.rate_limiting.rate_limit import RateLimit
@pytest.fixture
def mock_redis():
    """Mock Redis client with realistic behavior for rate limiting tests.

    Simulates the subset of Redis used by RateLimit — string keys with TTL
    (setex/get/exists/expire) and hashes (hset/hgetall/hdel/hlen) — backed
    by plain dicts. The backing dicts are exposed on the mock as
    ``_mock_data`` / ``_mock_hashes`` / ``_mock_expiry`` for verification.
    """
    mock_client = MagicMock()
    # Redis data storage for simulation
    mock_data = {}
    mock_hashes = {}
    mock_expiry = {}

    def mock_setex(key, ttl, value):
        # ttl may be a timedelta (has total_seconds) or a raw number of seconds.
        mock_data[key] = str(value)
        mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl
        return True

    def mock_get(key):
        # Honors expiry: an expired key behaves as missing (though it is
        # never physically removed from mock_data).
        if key in mock_data and (key not in mock_expiry or time.time() < mock_expiry[key]):
            return mock_data[key].encode("utf-8")
        return None

    def mock_exists(key):
        # NOTE(review): unlike mock_get, this ignores mock_expiry, so an
        # expired key still "exists" — confirm tests rely on this before
        # changing it.
        return key in mock_data or key in mock_hashes

    def mock_expire(key, ttl):
        if key in mock_data or key in mock_hashes:
            mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl
        return True

    def mock_hset(key, field, value):
        # Values are stored encoded, mirroring redis-py's bytes responses.
        if key not in mock_hashes:
            mock_hashes[key] = {}
        mock_hashes[key][field] = str(value).encode("utf-8")
        return True

    def mock_hgetall(key):
        return mock_hashes.get(key, {})

    def mock_hdel(key, *fields):
        # Returns the number of fields actually removed, like real HDEL.
        if key in mock_hashes:
            count = 0
            for field in fields:
                if field in mock_hashes[key]:
                    del mock_hashes[key][field]
                    count += 1
            return count
        return 0

    def mock_hlen(key):
        return len(mock_hashes.get(key, {}))

    # Configure mock methods
    mock_client.setex = mock_setex
    mock_client.get = mock_get
    mock_client.exists = mock_exists
    mock_client.expire = mock_expire
    mock_client.hset = mock_hset
    mock_client.hgetall = mock_hgetall
    mock_client.hdel = mock_hdel
    mock_client.hlen = mock_hlen
    # Store references for test verification
    mock_client._mock_data = mock_data
    mock_client._mock_hashes = mock_hashes
    mock_client._mock_expiry = mock_expiry
    return mock_client
@pytest.fixture
def mock_time():
    """Mock time.time() for deterministic tests.

    Yields the patch mock with an extra ``increment`` helper that advances
    the fake clock by ``seconds`` and returns the new value.
    """
    mock_time_val = 1000.0

    def increment_time(seconds=1):
        nonlocal mock_time_val
        mock_time_val += seconds
        return mock_time_val

    # BUG FIX: the original used `return_value=mock_time_val`, which froze
    # the patched clock at 1000.0 forever — increment_time() mutated the
    # closure variable, but the mock kept returning the float captured at
    # patch time. A side_effect re-reads the closure on every call, so
    # increments made via mock.increment() are actually observed by
    # time.time().
    with patch("time.time", side_effect=lambda: mock_time_val) as mock:
        mock.increment = increment_time
        yield mock
@pytest.fixture
def sample_generator():
    """Factory fixture producing generators for RateLimitGenerator tests."""

    def _create_generator(items=None, raise_error=False):
        # Falsy `items` (None or empty) falls back to the default three items.
        source = items or ["item1", "item2", "item3"]
        for element in source:
            # When requested, fail mid-stream on the second default item.
            if raise_error and element == "item2":
                raise ValueError("Test error")
            yield element

    return _create_generator
@pytest.fixture
def sample_mapping():
    """Sample two-entry mapping for testing RateLimitGenerator."""
    mapping = dict(key1="value1", key2="value2")
    return mapping
@pytest.fixture(autouse=True)
def reset_rate_limit_instances():
    """Clear the RateLimit singleton registry before and after each test."""
    RateLimit._instance_dict.clear()
    yield
    RateLimit._instance_dict.clear()
@pytest.fixture
def redis_patch():
    """Patch the module-level redis_client used by RateLimit."""
    with patch("core.app.features.rate_limiting.rate_limit.redis_client") as mock:
        yield mock

View File

@@ -0,0 +1,569 @@
import threading
import time
from datetime import timedelta
from unittest.mock import patch
import pytest
from core.app.features.rate_limiting.rate_limit import RateLimit
from core.errors.error import AppInvokeQuotaExceededError
class TestRateLimit:
    """Core rate limiting functionality tests."""

    def test_should_return_same_instance_for_same_client_id(self, redis_patch):
        """Same client ID must yield the same singleton instance."""
        redis_patch.configure_mock(**{"exists.return_value": False, "setex.return_value": True})
        first = RateLimit("client1", 5)
        second = RateLimit("client1", 10)  # same client, different limit
        assert first is second
        # Current implementation: every constructor call overwrites
        # max_active_requests, so the most recent value (10) wins.
        assert first.max_active_requests == 10

    def test_should_create_different_instances_for_different_client_ids(self, redis_patch):
        """Distinct client IDs must produce distinct limiter instances."""
        redis_patch.configure_mock(**{"exists.return_value": False, "setex.return_value": True})
        limiter_a = RateLimit("client1", 5)
        limiter_b = RateLimit("client2", 10)
        assert limiter_a is not limiter_b
        assert limiter_a.client_id == "client1"
        assert limiter_b.client_id == "client2"

    def test_should_initialize_with_valid_parameters(self, redis_patch):
        """Normal construction stores parameters and writes Redis once."""
        redis_patch.configure_mock(**{"exists.return_value": False, "setex.return_value": True})
        limiter = RateLimit("test_client", 5)
        assert limiter.client_id == "test_client"
        assert limiter.max_active_requests == 5
        assert hasattr(limiter, "initialized")
        redis_patch.setex.assert_called_once()

    def test_should_skip_initialization_if_disabled(self):
        """A zero limit disables the limiter and skips initialization."""
        limiter = RateLimit("test_client", 0)
        assert limiter.disabled()
        assert not hasattr(limiter, "initialized")

    def test_should_skip_reinitialization_of_existing_instance(self, redis_patch):
        """Re-constructing an existing instance must not hit Redis again."""
        redis_patch.configure_mock(**{"exists.return_value": False, "setex.return_value": True})
        RateLimit("client1", 5)
        redis_patch.reset_mock()
        RateLimit("client1", 10)
        redis_patch.setex.assert_not_called()

    def test_should_be_disabled_when_max_requests_is_zero_or_negative(self):
        """Zero and negative limits both disable rate limiting."""
        assert RateLimit("client1", 0).disabled()
        assert RateLimit("client2", -5).disabled()

    def test_should_set_redis_keys_on_first_flush(self, redis_patch):
        """Initial flush writes the max-active-requests key with a 1-day TTL."""
        redis_patch.configure_mock(**{"exists.return_value": False, "setex.return_value": True})
        RateLimit("test_client", 5)
        redis_patch.setex.assert_called_with(
            "dify:rate_limit:test_client:max_active_requests", timedelta(days=1), 5
        )

    def test_should_sync_max_requests_from_redis_on_subsequent_flush(self, redis_patch):
        """When the Redis key already exists, flush adopts its value."""
        redis_patch.configure_mock(
            **{"exists.return_value": True, "get.return_value": b"10", "expire.return_value": True}
        )
        limiter = RateLimit("test_client", 5)
        limiter.flush_cache()
        assert limiter.max_active_requests == 10

    @patch("time.time")
    def test_should_clean_timeout_requests_from_active_list(self, mock_time, redis_patch):
        """Flush removes requests past the timeout and keeps fresh ones."""
        now = 1000.0
        mock_time.return_value = now
        # One entry 700s old (timed out), one 100s old (still active).
        active_hash = {
            b"req1": str(now - 700).encode(),
            b"req2": str(now - 100).encode(),
        }
        redis_patch.configure_mock(
            **{
                "exists.return_value": True,
                "get.return_value": b"5",
                "expire.return_value": True,
                "hgetall.return_value": active_hash,
                "hdel.return_value": 1,
            }
        )
        limiter = RateLimit("test_client", 5)
        redis_patch.reset_mock()  # ignore calls made during construction
        limiter.flush_cache()
        redis_patch.hdel.assert_called_once()
        hdel_args = redis_patch.hdel.call_args[0]
        assert hdel_args[0] == "dify:rate_limit:test_client:active_requests"
        assert b"req1" in hdel_args  # expired entry removed
        assert b"req2" not in hdel_args  # live entry retained
class TestRateLimitEnterExit:
    """Rate limiting enter/exit logic tests."""

    def test_should_allow_request_within_limit(self, redis_patch):
        """Requests under the limit are admitted and tracked in Redis."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 2,
                "hset.return_value": True,
            }
        )
        limiter = RateLimit("test_client", 5)
        request_id = limiter.enter()
        assert request_id != RateLimit._UNLIMITED_REQUEST_ID
        redis_patch.hset.assert_called_once()

    def test_should_generate_request_id_if_not_provided(self, redis_patch):
        """enter() mints a UUID when no request ID is supplied."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 0,
                "hset.return_value": True,
            }
        )
        limiter = RateLimit("test_client", 5)
        assert len(limiter.enter()) == 36  # canonical UUID string length

    def test_should_use_provided_request_id(self, redis_patch):
        """enter() honours a caller-supplied request ID."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 0,
                "hset.return_value": True,
            }
        )
        limiter = RateLimit("test_client", 5)
        assert limiter.enter("custom_request_123") == "custom_request_123"

    def test_should_remove_request_on_exit(self, redis_patch):
        """exit() deletes the request from the active-requests hash."""
        redis_patch.configure_mock(**{"hdel.return_value": 1})
        limiter = RateLimit("test_client", 5)
        limiter.exit("test_request_id")
        redis_patch.hdel.assert_called_once_with(
            "dify:rate_limit:test_client:active_requests", "test_request_id"
        )

    def test_should_raise_quota_exceeded_when_at_limit(self, redis_patch):
        """enter() raises once the active count reaches the limit."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 5,  # already at the limit
            }
        )
        limiter = RateLimit("test_client", 5)
        with pytest.raises(AppInvokeQuotaExceededError) as exc_info:
            limiter.enter()
        message = str(exc_info.value)
        assert "Too many requests" in message
        assert "test_client" in message

    def test_should_allow_request_after_previous_exit(self, redis_patch):
        """A slot freed by exit() can be reused by a later enter()."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 4,  # back under the limit
                "hset.return_value": True,
                "hdel.return_value": 1,
            }
        )
        limiter = RateLimit("test_client", 5)
        first_id = limiter.enter()
        limiter.exit(first_id)
        assert limiter.enter() is not None

    @patch("time.time")
    def test_should_flush_cache_when_interval_exceeded(self, mock_time, redis_patch):
        """enter() re-flushes the cache once the flush interval has elapsed."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 0,
            }
        )
        mock_time.return_value = 1000.0
        limiter = RateLimit("test_client", 5)
        mock_time.return_value = 1400.0  # 400 seconds later, past the interval
        redis_patch.reset_mock()
        limiter.enter()
        redis_patch.setex.assert_called()  # cache flush re-wrote the key

    def test_should_return_unlimited_id_when_disabled(self):
        """A disabled limiter always hands out the unlimited request ID."""
        limiter = RateLimit("test_client", 0)
        assert limiter.enter() == RateLimit._UNLIMITED_REQUEST_ID

    def test_should_ignore_exit_for_unlimited_requests(self, redis_patch):
        """exit() on the unlimited ID must not touch Redis."""
        limiter = RateLimit("test_client", 0)
        limiter.exit(RateLimit._UNLIMITED_REQUEST_ID)
        redis_patch.hdel.assert_not_called()
class TestRateLimitGenerator:
    """Rate limit generator wrapper tests."""

    def test_should_wrap_generator_and_iterate_normally(self, redis_patch, sample_generator):
        """The wrapper yields every item, then releases the request slot."""
        redis_patch.configure_mock(
            **{"exists.return_value": False, "setex.return_value": True, "hdel.return_value": 1}
        )
        limiter = RateLimit("test_client", 5)
        wrapped = limiter.generate(sample_generator(), "test_request")
        assert list(wrapped) == ["item1", "item2", "item3"]
        redis_patch.hdel.assert_called_once_with(
            "dify:rate_limit:test_client:active_requests", "test_request"
        )

    def test_should_handle_mapping_input_directly(self, sample_mapping):
        """A mapping input is returned untouched rather than wrapped."""
        limiter = RateLimit("test_client", 0)  # disabled
        assert limiter.generate(sample_mapping, "test_request") is sample_mapping

    def test_should_cleanup_on_exception_during_iteration(self, redis_patch, sample_generator):
        """An exception mid-iteration still releases the request slot."""
        redis_patch.configure_mock(
            **{"exists.return_value": False, "setex.return_value": True, "hdel.return_value": 1}
        )
        limiter = RateLimit("test_client", 5)
        wrapped = limiter.generate(sample_generator(raise_error=True), "test_request")
        with pytest.raises(ValueError):
            list(wrapped)
        redis_patch.hdel.assert_called_once_with(
            "dify:rate_limit:test_client:active_requests", "test_request"
        )

    def test_should_cleanup_on_explicit_close(self, redis_patch, sample_generator):
        """Explicitly closing the wrapper releases the request slot."""
        redis_patch.configure_mock(
            **{"exists.return_value": False, "setex.return_value": True, "hdel.return_value": 1}
        )
        limiter = RateLimit("test_client", 5)
        wrapped = limiter.generate(sample_generator(), "test_request")
        wrapped.close()
        redis_patch.hdel.assert_called_once()

    def test_should_handle_generator_without_close_method(self, redis_patch):
        """Closing the wrapper works even if the inner iterator lacks close()."""
        redis_patch.configure_mock(
            **{"exists.return_value": False, "setex.return_value": True, "hdel.return_value": 1}
        )

        class PlainIterator:
            """Iterator deliberately missing a close() method."""

            def __init__(self):
                self._items = ["test"]
                self._cursor = 0

            def __iter__(self):
                return self

            def __next__(self):
                if self._cursor >= len(self._items):
                    raise StopIteration
                item = self._items[self._cursor]
                self._cursor += 1
                return item

        limiter = RateLimit("test_client", 5)
        wrapped = limiter.generate(PlainIterator(), "test_request")
        wrapped.close()  # must not raise despite the missing close()
        redis_patch.hdel.assert_called_once()

    def test_should_prevent_iteration_after_close(self, redis_patch, sample_generator):
        """A closed wrapper raises StopIteration on further iteration."""
        redis_patch.configure_mock(
            **{"exists.return_value": False, "setex.return_value": True, "hdel.return_value": 1}
        )
        limiter = RateLimit("test_client", 5)
        wrapped = limiter.generate(sample_generator(), "test_request")
        wrapped.close()
        with pytest.raises(StopIteration):
            next(wrapped)
class TestRateLimitConcurrency:
    """Concurrent access safety tests."""

    def test_should_handle_concurrent_instance_creation(self, redis_patch):
        """Singleton creation must be thread-safe."""
        redis_patch.configure_mock(**{"exists.return_value": False, "setex.return_value": True})
        instances = []
        errors = []

        def build():
            try:
                instances.append(RateLimit("concurrent_client", 5))
            except Exception as exc:
                errors.append(exc)

        workers = [threading.Thread(target=build) for _ in range(10)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        assert len(errors) == 0
        assert len({id(instance) for instance in instances}) == 1  # one shared singleton

    def test_should_handle_concurrent_enter_requests(self, redis_patch):
        """Concurrent enter() calls split into admitted and rejected sets."""
        # Simulate Redis hlen/hset with a shared counter so admissions count up.
        request_count = 0

        def fake_hlen(key):
            nonlocal request_count
            return request_count

        def fake_hset(key, field, value):
            nonlocal request_count
            request_count += 1
            return True

        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.side_effect": fake_hlen,
                "hset.side_effect": fake_hset,
            }
        )
        limiter = RateLimit("concurrent_client", 3)
        admitted = []
        rejected = []

        def attempt():
            try:
                admitted.append(limiter.enter())
            except AppInvokeQuotaExceededError as exc:
                rejected.append(exc)

        workers = [threading.Thread(target=attempt) for _ in range(5)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        # Limit of 3 vs 5 attempts: some succeed, some must be rejected.
        assert len(admitted) + len(rejected) == 5
        assert len(rejected) > 0

    @patch("time.time")
    def test_should_maintain_accurate_count_under_load(self, mock_time, redis_patch):
        """Every admitted request must also be released under concurrent load."""
        mock_time.return_value = 1000.0
        # Use the thread-safe fake Redis for a more realistic simulation.
        redis_patch.configure_mock(**self._create_mock_redis())
        limiter = RateLimit("load_test_client", 10)
        in_flight = []

        def enter_and_exit():
            try:
                request_id = limiter.enter()
                in_flight.append(request_id)
                time.sleep(0.01)  # simulate some work
                limiter.exit(request_id)
                in_flight.remove(request_id)
            except AppInvokeQuotaExceededError:
                pass  # expected when the limiter is saturated

        workers = [threading.Thread(target=enter_and_exit) for _ in range(20)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        # Everything that was admitted must also have been cleaned up.
        assert len(in_flight) == 0

    def _create_mock_redis(self):
        """Build configure_mock kwargs backed by a thread-safe fake hash store."""
        lock = threading.Lock()
        hashes = {}

        def fake_hlen(key):
            with lock:
                return len(hashes.get(key, {}))

        def fake_hset(key, field, value):
            with lock:
                hashes.setdefault(key, {})[field] = str(value).encode("utf-8")
                return True

        def fake_hdel(key, *fields):
            with lock:
                bucket = hashes.get(key)
                if bucket is None:
                    return 0
                removed = 0
                for field in fields:
                    if field in bucket:
                        del bucket[field]
                        removed += 1
                return removed

        return {
            "exists.return_value": False,
            "setex.return_value": True,
            "hlen.side_effect": fake_hlen,
            "hset.side_effect": fake_hset,
            "hdel.side_effect": fake_hdel,
        }

View File

@@ -0,0 +1,410 @@
import json
from time import time
from unittest.mock import Mock
import pytest
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import (
PauseStatePersistenceLayer,
WorkflowResumptionContext,
_AdvancedChatAppGenerateEntityWrapper,
_WorkflowGenerateEntityWrapper,
)
from core.variables.segments import Segment
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
class TestDataFactory:
    """Factory helpers for constructing graph events used in tests."""

    @staticmethod
    def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
        # A fixed scheduling-pause reason is sufficient for these tests.
        return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})

    @staticmethod
    def create_graph_run_started_event() -> GraphRunStartedEvent:
        return GraphRunStartedEvent()

    @staticmethod
    def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
        return GraphRunSucceededEvent(outputs=outputs or {})

    @staticmethod
    def create_graph_run_failed_event(
        error: str = "Test error",
        exceptions_count: int = 1,
    ) -> GraphRunFailedEvent:
        return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
class MockSystemVariableReadOnlyView:
"""Minimal read-only system variable view for testing."""
def __init__(self, workflow_execution_id: str | None = None) -> None:
self._workflow_execution_id = workflow_execution_id
@property
def workflow_execution_id(self) -> str | None:
return self._workflow_execution_id
class MockReadOnlyVariablePool:
    """Mock implementation of ReadOnlyVariablePool for testing.

    Variables are keyed by (node_id, variable_key) tuples.
    """

    def __init__(self, variables: dict[tuple[str, str], object] | None = None):
        self._variables = variables or {}

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Return a Segment-shaped mock wrapping the stored value, or None."""
        stored = self._variables.get((node_id, variable_key))
        if stored is None:
            return None
        segment = Mock(spec=Segment)
        segment.value = stored
        return segment

    def get_all_by_node(self, node_id: str) -> dict[str, object]:
        """All variables belonging to a single node, keyed by variable name."""
        return {key: value for (owner, key), value in self._variables.items() if owner == node_id}

    def get_by_prefix(self, prefix: str) -> dict[str, object]:
        """All variables whose node id starts with prefix, keyed 'node.key'."""
        return {
            f"{owner}.{key}": value
            for (owner, key), value in self._variables.items()
            if owner.startswith(prefix)
        }
class MockReadOnlyGraphRuntimeState:
    """Mock implementation of ReadOnlyGraphRuntimeState for testing."""

    def __init__(
        self,
        start_at: float | None = None,
        total_tokens: int = 0,
        node_run_steps: int = 0,
        ready_queue_size: int = 0,
        exceptions_count: int = 0,
        outputs: dict[str, object] | None = None,
        variables: dict[tuple[str, str], object] | None = None,
        workflow_execution_id: str | None = None,
    ):
        # Default start_at to "now" so tests need not supply a timestamp.
        self._start_at = start_at or time()
        self._total_tokens = total_tokens
        self._node_run_steps = node_run_steps
        self._ready_queue_size = ready_queue_size
        self._exceptions_count = exceptions_count
        self._outputs = outputs or {}
        self._variable_pool = MockReadOnlyVariablePool(variables)
        self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)

    @property
    def system_variable(self) -> MockSystemVariableReadOnlyView:
        return self._system_variable

    @property
    def variable_pool(self) -> ReadOnlyVariablePool:
        return self._variable_pool

    @property
    def start_at(self) -> float:
        return self._start_at

    @property
    def total_tokens(self) -> int:
        return self._total_tokens

    @property
    def node_run_steps(self) -> int:
        return self._node_run_steps

    @property
    def ready_queue_size(self) -> int:
        return self._ready_queue_size

    @property
    def exceptions_count(self) -> int:
        return self._exceptions_count

    @property
    def outputs(self) -> dict[str, object]:
        # Return a copy so callers cannot mutate internal state.
        return self._outputs.copy()

    @property
    def llm_usage(self):
        # Fixed token figures are sufficient for the layer under test.
        usage = Mock()
        usage.prompt_tokens = 10
        usage.completion_tokens = 20
        usage.total_tokens = 30
        return usage

    def get_output(self, key: str, default: object = None) -> object:
        return self._outputs.get(key, default)

    def dumps(self) -> str:
        """Serialize the mock state to JSON, mirroring the real interface."""
        payload = {
            "start_at": self._start_at,
            "total_tokens": self._total_tokens,
            "node_run_steps": self._node_run_steps,
            "ready_queue_size": self._ready_queue_size,
            "exceptions_count": self._exceptions_count,
            "outputs": self._outputs,
            "variables": {
                f"{node}.{key}": value for (node, key), value in self._variable_pool._variables.items()
            },
            "workflow_execution_id": self._system_variable.workflow_execution_id,
        }
        return json.dumps(payload)
class MockCommandChannel:
    """Mock implementation of CommandChannel for testing."""

    def __init__(self):
        self._commands: list[GraphEngineCommand] = []

    def fetch_commands(self) -> list[GraphEngineCommand]:
        # Return a shallow copy so callers cannot mutate internal state.
        return list(self._commands)

    def send_command(self, command: GraphEngineCommand) -> None:
        self._commands.append(command)
class TestPauseStatePersistenceLayer:
    """Unit tests for PauseStatePersistenceLayer."""

    @staticmethod
    def _create_generate_entity(workflow_execution_id: str = "run-123") -> WorkflowAppGenerateEntity:
        """Build a minimal workflow generate entity for layer construction."""
        app_config = WorkflowUIBasedAppConfig(
            tenant_id="tenant-123",
            app_id="app-123",
            app_mode=AppMode.WORKFLOW,
            workflow_id="workflow-123",
        )
        return WorkflowAppGenerateEntity(
            task_id="task-123",
            app_config=app_config,
            inputs={},
            files=[],
            user_id="user-123",
            stream=False,
            invoke_from=InvokeFrom.DEBUGGER,
            workflow_execution_id=workflow_execution_id,
        )

    def test_init_with_dependency_injection(self):
        """Constructor stores its dependencies and defers runtime wiring."""
        session_factory = Mock(name="session_factory")
        layer = PauseStatePersistenceLayer(
            session_factory=session_factory,
            state_owner_user_id="user-123",
            generate_entity=self._create_generate_entity(),
        )
        assert layer._session_maker is session_factory
        assert layer._state_owner_user_id == "user-123"
        # Runtime attributes only appear after initialize() is called.
        assert not hasattr(layer, "graph_runtime_state")
        assert not hasattr(layer, "command_channel")

    def test_initialize_sets_dependencies(self):
        """initialize() wires in the runtime state and command channel."""
        layer = PauseStatePersistenceLayer(
            session_factory=Mock(name="session_factory"),
            state_owner_user_id="owner",
            generate_entity=self._create_generate_entity(),
        )
        runtime_state = MockReadOnlyGraphRuntimeState()
        channel = MockCommandChannel()
        layer.initialize(runtime_state, channel)
        assert layer.graph_runtime_state is runtime_state
        assert layer.command_channel is channel

    def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
        """A paused event persists a resumption context via the repository."""
        session_factory = Mock(name="session_factory")
        generate_entity = self._create_generate_entity(workflow_execution_id="run-123")
        layer = PauseStatePersistenceLayer(
            session_factory=session_factory,
            state_owner_user_id="owner-123",
            generate_entity=generate_entity,
        )
        mock_repo = Mock()
        mock_factory = Mock(return_value=mock_repo)
        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
        runtime_state = MockReadOnlyGraphRuntimeState(
            outputs={"result": "test_output"},
            total_tokens=100,
            workflow_execution_id="run-123",
        )
        layer.initialize(runtime_state, MockCommandChannel())
        event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
        expected_state = runtime_state.dumps()
        layer.on_event(event)
        mock_factory.assert_called_once_with(session_factory)
        mock_repo.create_workflow_pause.assert_called_once_with(
            workflow_run_id="run-123",
            state_owner_user_id="owner-123",
            state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
        )
        # The persisted state must round-trip back into an equivalent context.
        serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
        context = WorkflowResumptionContext.loads(serialized_state)
        assert context.serialized_graph_runtime_state == expected_state
        assert context.get_generate_entity().model_dump() == generate_entity.model_dump()

    def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
        """Started/succeeded/failed events never touch the repository."""
        layer = PauseStatePersistenceLayer(
            session_factory=Mock(name="session_factory"),
            state_owner_user_id="owner-123",
            generate_entity=self._create_generate_entity(),
        )
        mock_repo = Mock()
        mock_factory = Mock(return_value=mock_repo)
        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
        layer.initialize(MockReadOnlyGraphRuntimeState(), MockCommandChannel())
        for event in (
            TestDataFactory.create_graph_run_started_event(),
            TestDataFactory.create_graph_run_succeeded_event(),
            TestDataFactory.create_graph_run_failed_event(),
        ):
            layer.on_event(event)
        mock_factory.assert_not_called()
        mock_repo.create_workflow_pause.assert_not_called()

    def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
        """on_event before initialize() fails loudly with AttributeError."""
        layer = PauseStatePersistenceLayer(
            session_factory=Mock(name="session_factory"),
            state_owner_user_id="owner-123",
            generate_entity=self._create_generate_entity(),
        )
        with pytest.raises(AttributeError):
            layer.on_event(TestDataFactory.create_graph_run_paused_event())

    def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
        """A paused event without a workflow execution ID trips an assertion."""
        layer = PauseStatePersistenceLayer(
            session_factory=Mock(name="session_factory"),
            state_owner_user_id="owner-123",
            generate_entity=self._create_generate_entity(),
        )
        mock_repo = Mock()
        mock_factory = Mock(return_value=mock_repo)
        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
        layer.initialize(MockReadOnlyGraphRuntimeState(workflow_execution_id=None), MockCommandChannel())
        with pytest.raises(AssertionError):
            layer.on_event(TestDataFactory.create_graph_run_paused_event())
        mock_factory.assert_not_called()
        mock_repo.create_workflow_pause.assert_not_called()
def _build_workflow_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
    """Build a WorkflowResumptionContext wrapping a realistic WorkflowAppGenerateEntity.

    Note: the return value is the resumption context, not the bare entity.
    """
    app_config = WorkflowUIBasedAppConfig(
        tenant_id="tenant-roundtrip",
        app_id="app-roundtrip",
        app_mode=AppMode.WORKFLOW,
        workflow_id="workflow-roundtrip",
    )
    entity = WorkflowAppGenerateEntity(
        task_id="workflow-task",
        app_config=app_config,
        inputs={"input_key": "input_value"},
        files=[],
        user_id="user-roundtrip",
        stream=False,
        invoke_from=InvokeFrom.DEBUGGER,
        workflow_execution_id="workflow-exec-roundtrip",
    )
    return WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
        generate_entity=_WorkflowGenerateEntityWrapper(entity=entity),
    )
def _build_advanced_chat_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
    """Build a WorkflowResumptionContext wrapping a realistic AdvancedChatAppGenerateEntity."""
    app_config = WorkflowUIBasedAppConfig(
        tenant_id="tenant-advanced",
        app_id="app-advanced",
        app_mode=AppMode.ADVANCED_CHAT,
        workflow_id="workflow-advanced",
    )
    entity = AdvancedChatAppGenerateEntity(
        task_id="advanced-task",
        app_config=app_config,
        inputs={"topic": "roundtrip"},
        files=[],
        user_id="advanced-user",
        stream=False,
        invoke_from=InvokeFrom.DEBUGGER,
        workflow_run_id="advanced-run-id",
        query="Explain serialization behavior",
    )
    return WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
        generate_entity=_AdvancedChatAppGenerateEntityWrapper(entity=entity),
    )
@pytest.mark.parametrize(
    "state",
    [
        pytest.param(_build_advanced_chat_generate_entity_for_roundtrip(), id="advanced_chat"),
        pytest.param(_build_workflow_generate_entity_for_roundtrip(), id="workflow"),
    ],
)
def test_workflow_resumption_context_dumps_loads_roundtrip(state: WorkflowResumptionContext):
    """dumps()/loads() roundtrip preserves the context and its entity type."""
    restored = WorkflowResumptionContext.loads(state.dumps())
    assert restored == state
    assert restored.serialized_graph_runtime_state == state.serialized_graph_runtime_state
    # The concrete generate-entity class must survive serialization.
    assert isinstance(restored.get_generate_entity(), type(state.generate_entity.entity))

View File

@@ -0,0 +1,54 @@
from core.file import File, FileTransferMethod, FileType
def test_file():
    """Constructing a File preserves every supplied field."""
    file = File(
        id="test-file",
        tenant_id="test-tenant-id",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.TOOL_FILE,
        related_id="test-related-id",
        filename="image.png",
        extension=".png",
        mime_type="image/png",
        size=67,
        storage_key="test-storage-key",
        url="https://example.com/image.png",
    )
    expected = {
        "tenant_id": "test-tenant-id",
        "type": FileType.IMAGE,
        "transfer_method": FileTransferMethod.TOOL_FILE,
        "related_id": "test-related-id",
        "filename": "image.png",
        "extension": ".png",
        "mime_type": "image/png",
        "size": 67,
    }
    for attr, value in expected.items():
        assert getattr(file, attr) == value
def test_file_model_validate_with_legacy_fields():
    """`File.model_validate` tolerates legacy compatibility fields in input."""
    payload = {
        "id": "test-file",
        "tenant_id": "test-tenant-id",
        "type": "image",
        "transfer_method": "tool_file",
        "related_id": "test-related-id",
        "filename": "image.png",
        "extension": ".png",
        "mime_type": "image/png",
        "size": 67,
        "storage_key": "test-storage-key",
        "url": "https://example.com/image.png",
        # Extra legacy fields that older payloads may still carry.
        "tool_file_id": "tool-file-123",
        "upload_file_id": "upload-file-456",
        "datasource_file_id": "datasource-file-789",
    }
    # Validation must succeed despite the unknown legacy keys...
    file = File.model_validate(payload)
    # ...and the resulting object must not expose them as attributes.
    for legacy_field in ("tool_file_id", "upload_file_id", "datasource_file_id"):
        assert not hasattr(file, legacy_field)

View File

@@ -0,0 +1,12 @@
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
def test_get_runner_script():
    """The assembled Node.js runner script must begin with the unmodified default code."""
    code = JavascriptCodeProvider.get_default_code()
    script = NodeJsTemplateTransformer.assemble_runner_script(code, {"arg1": "hello, ", "arg2": "world!"})
    code_lines = code.splitlines()
    # The runner script's leading lines are the user code verbatim.
    assert script.splitlines()[: len(code_lines)] == code_lines

View File

@@ -0,0 +1,12 @@
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
def test_get_runner_script():
    """The assembled Python3 runner script must begin with the unmodified default code."""
    code = Python3CodeProvider.get_default_code()
    script = Python3TemplateTransformer.assemble_runner_script(code, {"arg1": "hello, ", "arg2": "world!"})
    code_lines = code.splitlines()
    # The runner script's leading lines are the user code verbatim.
    assert script.splitlines()[: len(code_lines)] == code_lines

View File

@@ -0,0 +1,280 @@
import base64
import binascii
from unittest.mock import MagicMock, patch
import pytest
from core.helper.encrypter import (
batch_decrypt_token,
decrypt_token,
encrypt_token,
get_decrypt_decoding,
obfuscated_token,
)
from libs.rsa import PrivkeyNotFoundError
class TestObfuscatedToken:
    """Tests for the token obfuscation helper."""

    @pytest.mark.parametrize(
        ("token", "expected"),
        [
            ("", ""),  # empty input stays empty
            ("1234567", "*" * 20),  # short token (<8 chars) fully masked
            ("12345678", "*" * 20),  # boundary: exactly 8 chars still masked
            ("123456789abcdef", "123456" + "*" * 12 + "ef"),  # long token keeps edges
            ("abc!@#$%^&*()def", "abc!@#" + "*" * 12 + "ef"),  # special characters
        ],
    )
    def test_obfuscation_logic(self, token, expected):
        """Core obfuscation behaviour across token lengths."""
        assert obfuscated_token(token) == expected

    def test_sensitive_data_protection(self):
        """Obfuscation must never leak the full original token."""
        secret = "api_key_secret_12345"
        masked = obfuscated_token(secret)
        assert secret not in masked
        assert "*" * 12 in masked
class TestEncryptToken:
    """Tests for encrypt_token."""

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_successful_encryption(self, mock_encrypt, mock_query):
        """Encryption base64-encodes the RSA output for the tenant's key."""
        tenant = MagicMock()
        tenant.encrypt_public_key = "mock_public_key"
        mock_query.return_value.where.return_value.first.return_value = tenant
        mock_encrypt.return_value = b"encrypted_data"
        assert encrypt_token("tenant-123", "test_token") == base64.b64encode(b"encrypted_data").decode()
        mock_encrypt.assert_called_with("test_token", "mock_public_key")

    @patch("models.engine.db.session.query")
    def test_tenant_not_found(self, mock_query):
        """A missing tenant raises ValueError with an explanatory message."""
        mock_query.return_value.where.return_value.first.return_value = None
        with pytest.raises(ValueError) as exc_info:
            encrypt_token("invalid-tenant", "test_token")
        assert "Tenant with id invalid-tenant not found" in str(exc_info.value)
class TestDecryptToken:
    """Tests for decrypt_token (base64 decode + tenant RSA decryption)."""

    @patch("libs.rsa.decrypt")
    def test_successful_decryption(self, mock_decrypt):
        """Test successful token decryption"""
        mock_decrypt.return_value = "decrypted_token"
        # decrypt_token receives base64 text and decodes it before calling RSA.
        encrypted_data = base64.b64encode(b"encrypted_data").decode()
        result = decrypt_token("tenant-123", encrypted_data)
        assert result == "decrypted_token"
        # The decoded raw bytes are what reach libs.rsa.decrypt.
        mock_decrypt.assert_called_once_with(b"encrypted_data", "tenant-123")

    def test_invalid_base64(self):
        """Test handling of invalid base64 input"""
        # Malformed base64 propagates binascii.Error from the decode step.
        with pytest.raises(binascii.Error):
            decrypt_token("tenant-123", "invalid_base64!!!")
class TestBatchDecryptToken:
    """Tests for batch_decrypt_token."""

    # Decorators apply bottom-up: mock_decrypt_with_decoding <-
    # libs.rsa.decrypt_token_with_decoding, mock_get_decoding <-
    # libs.rsa.get_decrypt_decoding.
    @patch("libs.rsa.get_decrypt_decoding")
    @patch("libs.rsa.decrypt_token_with_decoding")
    def test_batch_decryption(self, mock_decrypt_with_decoding, mock_get_decoding):
        """Test batch decryption functionality"""
        mock_rsa_key = MagicMock()
        mock_cipher_rsa = MagicMock()
        mock_get_decoding.return_value = (mock_rsa_key, mock_cipher_rsa)
        # Test multiple tokens
        mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3"]
        tokens = [
            base64.b64encode(b"encrypted1").decode(),
            base64.b64encode(b"encrypted2").decode(),
            base64.b64encode(b"encrypted3").decode(),
        ]
        result = batch_decrypt_token("tenant-123", tokens)
        assert result == ["token1", "token2", "token3"]
        # Key should only be loaded once -- the point of the batch API.
        mock_get_decoding.assert_called_once_with("tenant-123")
class TestGetDecryptDecoding:
    """Tests for get_decrypt_decoding (tenant private-key loading)."""

    @patch("extensions.ext_redis.redis_client.get")
    @patch("extensions.ext_storage.storage.load")
    def test_private_key_not_found(self, mock_storage_load, mock_redis_get):
        """Test error when private key file doesn't exist"""
        # Redis cache miss forces a storage lookup, which also fails.
        mock_redis_get.return_value = None
        mock_storage_load.side_effect = FileNotFoundError()
        with pytest.raises(PrivkeyNotFoundError) as exc_info:
            get_decrypt_decoding("tenant-123")
        assert "Private key not found, tenant_id: tenant-123" in str(exc_info.value)
class TestEncryptDecryptIntegration:
    """Round-trip test wiring encrypt_token output into decrypt_token."""

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    @patch("libs.rsa.decrypt")
    def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
        """Test that encryption and decryption are consistent"""
        # Setup mock tenant
        mock_tenant = MagicMock()
        mock_tenant.encrypt_public_key = "mock_public_key"
        mock_query.return_value.where.return_value.first.return_value = mock_tenant
        # Setup mock encryption/decryption (RSA itself is mocked; this checks
        # the base64 plumbing between the two helpers, not real crypto).
        original_token = "test_token_123"
        mock_encrypt.return_value = b"encrypted_data"
        mock_decrypt.return_value = original_token
        # Test encryption
        encrypted = encrypt_token("tenant-123", original_token)
        # Test decryption
        decrypted = decrypt_token("tenant-123", encrypted)
        assert decrypted == original_token
class TestSecurity:
    """Critical security tests for encryption system"""

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
        """Ensure tokens encrypted for one tenant cannot be used by another"""
        # Setup mock tenant
        mock_tenant = MagicMock()
        mock_tenant.encrypt_public_key = "tenant1_public_key"
        mock_query.return_value.where.return_value.first.return_value = mock_tenant
        mock_encrypt.return_value = b"encrypted_for_tenant1"
        # Encrypt token for tenant1
        encrypted = encrypt_token("tenant-123", "sensitive_data")
        # Attempt to decrypt with different tenant should fail
        with patch("libs.rsa.decrypt") as mock_decrypt:
            mock_decrypt.side_effect = Exception("Invalid tenant key")
            with pytest.raises(Exception, match="Invalid tenant key"):
                decrypt_token("different-tenant", encrypted)

    @patch("libs.rsa.decrypt")
    def test_tampered_ciphertext_rejection(self, mock_decrypt):
        """Detect and reject tampered ciphertext"""
        valid_encrypted = base64.b64encode(b"valid_data").decode()
        # Tamper with ciphertext: flip every bit of the first byte.
        tampered_bytes = bytearray(base64.b64decode(valid_encrypted))
        tampered_bytes[0] ^= 0xFF
        tampered = base64.b64encode(bytes(tampered_bytes)).decode()
        mock_decrypt.side_effect = Exception("Decryption error")
        with pytest.raises(Exception, match="Decryption error"):
            decrypt_token("tenant-123", tampered)

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_encryption_randomness(self, mock_encrypt, mock_query):
        """Ensure same plaintext produces different ciphertext"""
        mock_tenant = MagicMock(encrypt_public_key="key")
        mock_query.return_value.where.return_value.first.return_value = mock_tenant
        # Different outputs for same input (real RSA padding is randomized).
        mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
        results = [encrypt_token("tenant-123", "token") for _ in range(3)]
        # All results should be different
        assert len(set(results)) == 3
class TestEdgeCases:
    """Additional security-focused edge case tests"""

    def test_should_handle_empty_string_in_obfuscation(self):
        """Test handling of empty string in obfuscation"""
        # Test empty string (which is a valid str type)
        assert obfuscated_token("") == ""

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
        """Test encryption of empty token"""
        mock_tenant = MagicMock()
        mock_tenant.encrypt_public_key = "mock_public_key"
        mock_query.return_value.where.return_value.first.return_value = mock_tenant
        mock_encrypt.return_value = b"encrypted_empty"
        result = encrypt_token("tenant-123", "")
        assert result == base64.b64encode(b"encrypted_empty").decode()
        # Empty string must be passed through to RSA unchanged, not rejected.
        mock_encrypt.assert_called_with("", "mock_public_key")

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
        """Test tokens containing special/unicode characters"""
        mock_tenant = MagicMock()
        mock_tenant.encrypt_public_key = "mock_public_key"
        mock_query.return_value.where.return_value.first.return_value = mock_tenant
        mock_encrypt.return_value = b"encrypted_special"
        # Test various special characters
        special_tokens = [
            "token\x00with\x00null",  # Null bytes
            "token_with_emoji_😀🎉",  # Unicode emoji
            "token\nwith\nnewlines",  # Newlines
            "token\twith\ttabs",  # Tabs
            "token_with_中文字符",  # Chinese characters
        ]
        for token in special_tokens:
            result = encrypt_token("tenant-123", token)
            assert result == base64.b64encode(b"encrypted_special").decode()
            # assert_called_with checks the most recent call, i.e. this token.
            mock_encrypt.assert_called_with(token, "mock_public_key")

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
        """Test behavior when token exceeds RSA encryption limits"""
        mock_tenant = MagicMock()
        mock_tenant.encrypt_public_key = "mock_public_key"
        mock_query.return_value.where.return_value.first.return_value = mock_tenant
        # RSA 2048-bit can only encrypt ~245 bytes
        # The actual limit depends on padding scheme
        mock_encrypt.side_effect = ValueError("Message too long for RSA key size")
        # Create a token that would exceed RSA limits
        long_token = "x" * 300
        with pytest.raises(ValueError, match="Message too long for RSA key size"):
            encrypt_token("tenant-123", long_token)

    @patch("libs.rsa.get_decrypt_decoding")
    @patch("libs.rsa.decrypt_token_with_decoding")
    def test_batch_decrypt_loads_key_only_once(self, mock_decrypt_with_decoding, mock_get_decoding):
        """Verify batch decryption optimization - loads key only once"""
        mock_rsa_key = MagicMock()
        mock_cipher_rsa = MagicMock()
        mock_get_decoding.return_value = (mock_rsa_key, mock_cipher_rsa)
        # Test with multiple tokens
        mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3", "token4", "token5"]
        tokens = [base64.b64encode(f"encrypted{i}".encode()).decode() for i in range(5)]
        result = batch_decrypt_token("tenant-123", tokens)
        assert result == ["token1", "token2", "token3", "token4", "token5"]
        # Key should only be loaded once regardless of token count
        mock_get_decoding.assert_called_once_with("tenant-123")
        assert mock_decrypt_with_decoding.call_count == 5

View File

@@ -0,0 +1,52 @@
import secrets
from unittest.mock import MagicMock, patch
import pytest
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
@patch("httpx.Client.request")
def test_successful_request(mock_request):
    """A 200 response is returned unchanged with no retry involved."""
    ok_response = MagicMock()
    ok_response.status_code = 200
    mock_request.return_value = ok_response
    result = make_request("GET", "http://example.com")
    assert result.status_code == 200
@patch("httpx.Client.request")
def test_retry_exceed_max_retries(mock_request):
    """Exhausting the retry budget raises with the retry count in the message."""
    failing = MagicMock()
    failing.status_code = 500
    # Enough retryable failures that the budget is exhausted before success.
    mock_request.side_effect = [failing] * SSRF_DEFAULT_MAX_RETRIES
    with pytest.raises(Exception) as e:
        make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
    assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
@patch("httpx.Client.request")
def test_retry_logic_success(mock_request):
    """Retryable failures followed by a 200 succeed and consume every attempt."""
    responses = []
    for _ in range(SSRF_DEFAULT_MAX_RETRIES):
        failing = MagicMock()
        # Any status in the force-list must trigger a retry.
        failing.status_code = secrets.choice(STATUS_FORCELIST)
        responses.append(failing)
    ok = MagicMock()
    ok.status_code = 200
    responses.append(ok)
    mock_request.side_effect = responses
    response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
    assert response.status_code == 200
    # Initial attempt plus one call per retry.
    assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
    assert mock_request.call_args_list[0][1].get("method") == "GET"

View File

@@ -0,0 +1,86 @@
import pytest
from core.helper.trace_id_helper import extract_external_trace_id_from_args, get_external_trace_id, is_valid_trace_id
class DummyRequest:
    """Minimal stand-in for a Flask-style request exposing only the
    attributes trace_id_helper reads: headers, args, json and is_json."""

    def __init__(self, headers=None, args=None, json=None, is_json=False):
        # `or {}` gives each instance its own empty mapping; `json` may be None.
        self.headers = headers or {}
        self.args = args or {}
        self.json = json
        self.is_json = is_json
class TestTraceIdHelper:
    """Test cases for trace_id_helper.py"""

    @pytest.mark.parametrize(
        ("trace_id", "expected"),
        [
            ("abc123", True),
            ("A-B_C-123", True),
            ("a" * 128, True),
            ("", False),
            ("a" * 129, False),
            ("abc!@#", False),
            ("空格", False),
            ("with space", False),
        ],
    )
    def test_is_valid_trace_id(self, trace_id, expected):
        """Valid ids are 1-128 chars of [A-Za-z0-9_-]; everything else fails."""
        assert is_valid_trace_id(trace_id) is expected

    def test_get_external_trace_id_from_header(self):
        """A valid trace_id in the X-Trace-Id header is picked up."""
        request = DummyRequest(headers={"X-Trace-Id": "abc123"})
        assert get_external_trace_id(request) == "abc123"

    def test_get_external_trace_id_from_args(self):
        """Falls back to query args when the header is absent."""
        request = DummyRequest(args={"trace_id": "abc123"})
        assert get_external_trace_id(request) == "abc123"

    def test_get_external_trace_id_from_json(self):
        """Falls back to the JSON body when header and args are absent."""
        request = DummyRequest(is_json=True, json={"trace_id": "abc123"})
        assert get_external_trace_id(request) == "abc123"

    def test_get_external_trace_id_priority(self):
        """Sources are consulted in order: header, then args, then JSON body."""
        all_sources = DummyRequest(
            headers={"X-Trace-Id": "header_id"},
            args={"trace_id": "args_id"},
            is_json=True,
            json={"trace_id": "json_id"},
        )
        assert get_external_trace_id(all_sources) == "header_id"
        no_header = DummyRequest(args={"trace_id": "args_id"}, is_json=True, json={"trace_id": "json_id"})
        assert get_external_trace_id(no_header) == "args_id"
        json_only = DummyRequest(is_json=True, json={"trace_id": "json_id"})
        assert get_external_trace_id(json_only) == "json_id"

    @pytest.mark.parametrize(
        "req",
        [
            DummyRequest(headers={"X-Trace-Id": "!!!"}),
            DummyRequest(args={"trace_id": "!!!"}),
            DummyRequest(is_json=True, json={"trace_id": "!!!"}),
            DummyRequest(),
        ],
    )
    def test_get_external_trace_id_invalid(self, req):
        """Invalid or missing trace_id yields None from every source."""
        assert get_external_trace_id(req) is None

    @pytest.mark.parametrize(
        ("args", "expected"),
        [
            ({"external_trace_id": "abc123"}, {"external_trace_id": "abc123"}),
            ({"other": "value"}, {}),
            ({}, {}),
        ],
    )
    def test_extract_external_trace_id_from_args(self, args, expected):
        """Only the external_trace_id key is extracted from an args mapping."""
        assert extract_external_trace_id_from_args(args) == expected

View File

@@ -0,0 +1,766 @@
"""Unit tests for MCP OAuth authentication flow."""
from unittest.mock import Mock, patch
import pytest
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.auth.auth_flow import (
OAUTH_STATE_EXPIRY_SECONDS,
OAUTH_STATE_REDIS_KEY_PREFIX,
OAuthCallbackState,
_create_secure_redis_state,
_retrieve_redis_state,
auth,
check_support_resource_discovery,
discover_oauth_metadata,
exchange_authorization,
generate_pkce_challenge,
handle_callback,
refresh_authorization,
register_client,
start_authorization,
)
from core.mcp.entities import AuthActionType, AuthResult
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
ProtectedResourceMetadata,
)
class TestPKCEGeneration:
    """Test PKCE challenge generation."""

    def test_generate_pkce_challenge(self):
        """Verifier and challenge are URL-safe base64 without padding, and long."""
        verifier, challenge = generate_pkce_challenge()
        # URL-safe alphabet: none of the standard-base64-only characters.
        for value in (verifier, challenge):
            assert "=" not in value
            assert "+" not in value
            assert "/" not in value
        # Verify length (verifier ~54 chars, challenge ~43 chars in practice).
        assert len(verifier) > 40
        assert len(challenge) > 40

    def test_generate_pkce_challenge_uniqueness(self):
        """Repeated generation never repeats a (verifier, challenge) pair."""
        pairs = {generate_pkce_challenge() for _ in range(10)}
        assert len(pairs) == 10
class TestRedisStateManagement:
    """Test Redis state management functions."""

    @patch("core.mcp.auth.auth_flow.redis_client")
    def test_create_secure_redis_state(self, mock_redis):
        """Test creating secure Redis state."""
        state_data = OAuthCallbackState(
            provider_id="test-provider",
            tenant_id="test-tenant",
            server_url="https://example.com",
            metadata=None,
            client_information=OAuthClientInformation(client_id="test-client"),
            code_verifier="test-verifier",
            redirect_uri="https://redirect.example.com",
        )
        state_key = _create_secure_redis_state(state_data)
        # Verify state key format
        assert len(state_key) > 20  # Should be a secure random token
        # Verify Redis call: setex(key, ttl, value) -- key carries the OAuth
        # prefix, ttl is the expiry constant, value embeds the serialized state.
        mock_redis.setex.assert_called_once()
        call_args = mock_redis.setex.call_args
        assert call_args[0][0].startswith(OAUTH_STATE_REDIS_KEY_PREFIX)
        assert call_args[0][1] == OAUTH_STATE_EXPIRY_SECONDS
        assert state_data.model_dump_json() in call_args[0][2]

    @patch("core.mcp.auth.auth_flow.redis_client")
    def test_retrieve_redis_state_success(self, mock_redis):
        """Test retrieving state from Redis."""
        state_data = OAuthCallbackState(
            provider_id="test-provider",
            tenant_id="test-tenant",
            server_url="https://example.com",
            metadata=None,
            client_information=OAuthClientInformation(client_id="test-client"),
            code_verifier="test-verifier",
            redirect_uri="https://redirect.example.com",
        )
        mock_redis.get.return_value = state_data.model_dump_json()
        result = _retrieve_redis_state("test-state-key")
        # Verify result
        assert result.provider_id == "test-provider"
        assert result.tenant_id == "test-tenant"
        assert result.server_url == "https://example.com"
        # Verify Redis calls: state is single-use, so it is deleted after read.
        mock_redis.get.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
        mock_redis.delete.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")

    @patch("core.mcp.auth.auth_flow.redis_client")
    def test_retrieve_redis_state_not_found(self, mock_redis):
        """Test retrieving non-existent state from Redis."""
        mock_redis.get.return_value = None
        with pytest.raises(ValueError) as exc_info:
            _retrieve_redis_state("nonexistent-key")
        assert "State parameter has expired or does not exist" in str(exc_info.value)

    @patch("core.mcp.auth.auth_flow.redis_client")
    def test_retrieve_redis_state_invalid_json(self, mock_redis):
        """Test retrieving invalid JSON state from Redis."""
        mock_redis.get.return_value = '{"invalid": json}'
        with pytest.raises(ValueError) as exc_info:
            _retrieve_redis_state("test-key")
        assert "Invalid state parameter" in str(exc_info.value)
        # State should still be deleted even when deserialization fails.
        mock_redis.delete.assert_called_once()
class TestOAuthDiscovery:
    """Test OAuth discovery functions."""

    @patch("core.helper.ssrf_proxy.get")
    def test_check_support_resource_discovery_success(self, mock_get):
        """Test successful resource discovery check."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
        mock_get.return_value = mock_response
        supported, auth_url = check_support_resource_discovery("https://api.example.com/endpoint")
        assert supported is True
        assert auth_url == "https://auth.example.com"
        # The URL path is replaced with the RFC 9728 well-known endpoint.
        mock_get.assert_called_once_with(
            "https://api.example.com/.well-known/oauth-protected-resource",
            headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
        )

    @patch("core.helper.ssrf_proxy.get")
    def test_check_support_resource_discovery_not_supported(self, mock_get):
        """Test resource discovery not supported."""
        mock_response = Mock()
        mock_response.status_code = 404
        mock_get.return_value = mock_response
        supported, auth_url = check_support_resource_discovery("https://api.example.com")
        assert supported is False
        assert auth_url == ""

    @patch("core.helper.ssrf_proxy.get")
    def test_check_support_resource_discovery_with_query_fragment(self, mock_get):
        """Test resource discovery with query and fragment."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
        mock_get.return_value = mock_response
        supported, auth_url = check_support_resource_discovery("https://api.example.com/path?query=1#fragment")
        assert supported is True
        assert auth_url == "https://auth.example.com"
        # Query string and fragment are preserved on the well-known URL.
        mock_get.assert_called_once_with(
            "https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
            headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
        )

    def test_discover_oauth_metadata_with_resource_discovery(self):
        """Test OAuth metadata discovery with resource discovery support."""
        with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
            with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
                # Mock protected resource metadata with auth server URL
                mock_prm.return_value = ProtectedResourceMetadata(
                    resource="https://api.example.com",
                    authorization_servers=["https://auth.example.com"],
                )
                # Mock OAuth authorization server metadata
                mock_asm.return_value = OAuthMetadata(
                    authorization_endpoint="https://auth.example.com/authorize",
                    token_endpoint="https://auth.example.com/token",
                    response_types_supported=["code"],
                )
                oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
                assert oauth_metadata is not None
                assert oauth_metadata.authorization_endpoint == "https://auth.example.com/authorize"
                assert oauth_metadata.token_endpoint == "https://auth.example.com/token"
                assert prm is not None
                assert prm.authorization_servers == ["https://auth.example.com"]
                # Verify the discovery functions were called
                mock_prm.assert_called_once()
                mock_asm.assert_called_once()

    def test_discover_oauth_metadata_without_resource_discovery(self):
        """Test OAuth metadata discovery without resource discovery."""
        with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
            with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
                # Mock no protected resource metadata -- discovery should fall
                # back to the server's own authorization-server metadata.
                mock_prm.return_value = None
                # Mock OAuth authorization server metadata
                mock_asm.return_value = OAuthMetadata(
                    authorization_endpoint="https://api.example.com/oauth/authorize",
                    token_endpoint="https://api.example.com/oauth/token",
                    response_types_supported=["code"],
                )
                oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
                assert oauth_metadata is not None
                assert oauth_metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
                assert prm is None
                # Verify the discovery functions were called
                mock_prm.assert_called_once()
                mock_asm.assert_called_once()

    @patch("core.helper.ssrf_proxy.get")
    def test_discover_oauth_metadata_not_found(self, mock_get):
        """Test OAuth metadata discovery when not found."""
        with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
            mock_check.return_value = (False, "")
            mock_response = Mock()
            mock_response.status_code = 404
            mock_get.return_value = mock_response
            oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
            assert oauth_metadata is None
class TestAuthorizationFlow:
    """Test authorization flow functions."""

    @patch("core.mcp.auth.auth_flow._create_secure_redis_state")
    def test_start_authorization_with_metadata(self, mock_create_state):
        """Test starting authorization with metadata."""
        mock_create_state.return_value = "secure-state-key"
        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["code"],
            code_challenge_methods_supported=["S256"],
        )
        client_info = OAuthClientInformation(client_id="test-client-id")
        auth_url, code_verifier = start_authorization(
            "https://api.example.com",
            metadata,
            client_info,
            "https://redirect.example.com",
            "provider-id",
            "tenant-id",
        )
        # Verify URL format: all PKCE query parameters present, redirect URI
        # percent-encoded, and the Redis state key used as the state param.
        assert auth_url.startswith("https://auth.example.com/authorize?")
        assert "response_type=code" in auth_url
        assert "client_id=test-client-id" in auth_url
        assert "code_challenge=" in auth_url
        assert "code_challenge_method=S256" in auth_url
        assert "redirect_uri=https%3A%2F%2Fredirect.example.com" in auth_url
        assert "state=secure-state-key" in auth_url
        # Verify code verifier
        assert len(code_verifier) > 40
        # Verify state was stored with the identifiers needed by the callback.
        mock_create_state.assert_called_once()
        state_data = mock_create_state.call_args[0][0]
        assert state_data.provider_id == "provider-id"
        assert state_data.tenant_id == "tenant-id"
        assert state_data.code_verifier == code_verifier

    def test_start_authorization_without_metadata(self):
        """Test starting authorization without metadata."""
        with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create_state:
            mock_create_state.return_value = "secure-state-key"
            client_info = OAuthClientInformation(client_id="test-client-id")
            auth_url, code_verifier = start_authorization(
                "https://api.example.com",
                None,
                client_info,
                "https://redirect.example.com",
                "provider-id",
                "tenant-id",
            )
            # Should use default authorization endpoint (/authorize on the server)
            assert auth_url.startswith("https://api.example.com/authorize?")

    def test_start_authorization_invalid_metadata(self):
        """Test starting authorization with invalid metadata."""
        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["token"],  # No "code" support
            code_challenge_methods_supported=["plain"],  # No "S256" support
        )
        client_info = OAuthClientInformation(client_id="test-client-id")
        with pytest.raises(ValueError) as exc_info:
            start_authorization(
                "https://api.example.com",
                metadata,
                client_info,
                "https://redirect.example.com",
                "provider-id",
                "tenant-id",
            )
        assert "does not support response type code" in str(exc_info.value)

    @patch("core.helper.ssrf_proxy.post")
    def test_exchange_authorization_success(self, mock_post):
        """Test successful authorization code exchange."""
        mock_response = Mock()
        mock_response.is_success = True
        mock_response.headers = {"content-type": "application/json"}
        mock_response.json.return_value = {
            "access_token": "new-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token",
        }
        mock_post.return_value = mock_response
        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["code"],
            grant_types_supported=["authorization_code"],
        )
        client_info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
        tokens = exchange_authorization(
            "https://api.example.com",
            metadata,
            client_info,
            "auth-code-123",
            "code-verifier-xyz",
            "https://redirect.example.com",
        )
        assert tokens.access_token == "new-access-token"
        assert tokens.token_type == "Bearer"
        assert tokens.expires_in == 3600
        assert tokens.refresh_token == "new-refresh-token"
        # Verify request: form-encoded authorization_code grant with PKCE verifier.
        mock_post.assert_called_once_with(
            "https://auth.example.com/token",
            data={
                "grant_type": "authorization_code",
                "client_id": "test-client-id",
                "client_secret": "test-secret",
                "code": "auth-code-123",
                "code_verifier": "code-verifier-xyz",
                "redirect_uri": "https://redirect.example.com",
            },
        )

    @patch("core.helper.ssrf_proxy.post")
    def test_exchange_authorization_failure(self, mock_post):
        """Test failed authorization code exchange."""
        mock_response = Mock()
        mock_response.is_success = False
        mock_response.status_code = 400
        mock_post.return_value = mock_response
        client_info = OAuthClientInformation(client_id="test-client-id")
        with pytest.raises(ValueError) as exc_info:
            exchange_authorization(
                "https://api.example.com",
                None,
                client_info,
                "invalid-code",
                "code-verifier",
                "https://redirect.example.com",
            )
        assert "Token exchange failed: HTTP 400" in str(exc_info.value)

    @patch("core.helper.ssrf_proxy.post")
    def test_refresh_authorization_success(self, mock_post):
        """Test successful token refresh."""
        mock_response = Mock()
        mock_response.is_success = True
        mock_response.headers = {"content-type": "application/json"}
        mock_response.json.return_value = {
            "access_token": "refreshed-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token",
        }
        mock_post.return_value = mock_response
        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["code"],
            grant_types_supported=["refresh_token"],
        )
        client_info = OAuthClientInformation(client_id="test-client-id")
        tokens = refresh_authorization("https://api.example.com", metadata, client_info, "old-refresh-token")
        assert tokens.access_token == "refreshed-access-token"
        assert tokens.refresh_token == "new-refresh-token"
        # Verify request: refresh_token grant, no code/verifier fields.
        mock_post.assert_called_once_with(
            "https://auth.example.com/token",
            data={
                "grant_type": "refresh_token",
                "client_id": "test-client-id",
                "refresh_token": "old-refresh-token",
            },
        )

    @patch("core.helper.ssrf_proxy.post")
    def test_register_client_success(self, mock_post):
        """Test successful client registration."""
        mock_response = Mock()
        mock_response.is_success = True
        mock_response.json.return_value = {
            "client_id": "new-client-id",
            "client_secret": "new-client-secret",
            "client_name": "Dify",
            "redirect_uris": ["https://redirect.example.com"],
        }
        mock_post.return_value = mock_response
        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            registration_endpoint="https://auth.example.com/register",
            response_types_supported=["code"],
        )
        client_metadata = OAuthClientMetadata(
            client_name="Dify",
            redirect_uris=["https://redirect.example.com"],
            grant_types=["authorization_code"],
            response_types=["code"],
        )
        client_info = register_client("https://api.example.com", metadata, client_metadata)
        assert isinstance(client_info, OAuthClientInformationFull)
        assert client_info.client_id == "new-client-id"
        assert client_info.client_secret == "new-client-secret"
        # Verify request: dynamic registration posts the metadata as JSON.
        mock_post.assert_called_once_with(
            "https://auth.example.com/register",
            json=client_metadata.model_dump(),
            headers={"Content-Type": "application/json"},
        )

    def test_register_client_no_endpoint(self):
        """Test client registration when no endpoint available."""
        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            registration_endpoint=None,
            response_types_supported=["code"],
        )
        client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://redirect.example.com"])
        with pytest.raises(ValueError) as exc_info:
            register_client("https://api.example.com", metadata, client_metadata)
        assert "does not support dynamic client registration" in str(exc_info.value)
class TestCallbackHandling:
    """Test OAuth callback handling.

    Fix: removed the unused local ``mock_service = Mock()`` (and its stale
    "Setup service" comment) -- ``handle_callback`` no longer takes a service;
    it only retrieves state and exchanges the code.
    """

    # Decorators apply bottom-up: mock_exchange <- exchange_authorization,
    # mock_retrieve_state <- _retrieve_redis_state.
    @patch("core.mcp.auth.auth_flow._retrieve_redis_state")
    @patch("core.mcp.auth.auth_flow.exchange_authorization")
    def test_handle_callback_success(self, mock_exchange, mock_retrieve_state):
        """Test successful callback handling."""
        # Setup state as stored by start_authorization.
        state_data = OAuthCallbackState(
            provider_id="test-provider",
            tenant_id="test-tenant",
            server_url="https://api.example.com",
            metadata=None,
            client_information=OAuthClientInformation(client_id="test-client"),
            code_verifier="test-verifier",
            redirect_uri="https://redirect.example.com",
        )
        mock_retrieve_state.return_value = state_data
        # Setup token exchange
        tokens = OAuthTokens(
            access_token="new-token",
            token_type="Bearer",
            expires_in=3600,
        )
        mock_exchange.return_value = tokens
        state_result, tokens_result = handle_callback("state-key", "auth-code")
        assert state_result == state_data
        assert tokens_result == tokens
        # Verify calls: the stored code_verifier and redirect_uri are replayed
        # into the token exchange.
        mock_retrieve_state.assert_called_once_with("state-key")
        mock_exchange.assert_called_once_with(
            "https://api.example.com",
            None,
            state_data.client_information,
            "auth-code",
            "test-verifier",
            "https://redirect.example.com",
        )
        # Note: handle_callback no longer saves tokens directly, it just returns them
        # The caller (e.g. controller) is responsible for saving via execute_auth_actions
class TestAuthOrchestration:
"""Test the main auth orchestration function."""
@pytest.fixture
def mock_provider(self):
"""Create a mock provider entity."""
provider = Mock(spec=MCPProviderEntity)
provider.id = "provider-id"
provider.tenant_id = "tenant-id"
provider.decrypt_server_url.return_value = "https://api.example.com"
provider.client_metadata = OAuthClientMetadata(
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
)
provider.redirect_url = "https://redirect.example.com"
provider.retrieve_client_information.return_value = None
provider.retrieve_tokens.return_value = None
return provider
@pytest.fixture
def mock_service(self):
"""Create a mock MCP service."""
return Mock()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow.register_client")
@patch("core.mcp.auth.auth_flow.start_authorization")
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
"""Test auth flow for new client registration."""
# Setup
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
mock_register.return_value = OAuthClientInformationFull(
client_id="new-client-id",
client_name="Dify",
redirect_uris=["https://redirect.example.com"],
)
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
result = auth(mock_provider)
# auth() now returns AuthResult
assert isinstance(result, AuthResult)
assert result.response == {"authorization_url": "https://auth.example.com/authorize?..."}
# Verify that the result contains the correct actions
assert len(result.actions) == 2
# Check for SAVE_CLIENT_INFO action
client_info_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CLIENT_INFO)
assert client_info_action.data == {"client_information": mock_register.return_value.model_dump()}
assert client_info_action.provider_id == "provider-id"
assert client_info_action.tenant_id == "tenant-id"
# Check for SAVE_CODE_VERIFIER action
verifier_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CODE_VERIFIER)
assert verifier_action.data == {"code_verifier": "code-verifier"}
assert verifier_action.provider_id == "provider-id"
assert verifier_action.tenant_id == "tenant-id"
# Verify calls
mock_register.assert_called_once()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
@patch("core.mcp.auth.auth_flow.exchange_authorization")
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
    """Exchanging an authorization code yields one SAVE_TOKENS action."""
    mock_discover.return_value = (
        OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["code"],
            grant_types_supported=["authorization_code"],
        ),
        None,
        None,
    )
    # The provider already holds registered client information.
    mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
    # Redis state lookup resolves the callback context for the exchange.
    mock_retrieve_state.return_value = OAuthCallbackState(
        provider_id="provider-id",
        tenant_id="tenant-id",
        server_url="https://api.example.com",
        metadata=None,
        client_information=OAuthClientInformation(client_id="existing-client"),
        code_verifier="test-verifier",
        redirect_uri="https://redirect.example.com",
    )
    issued = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
    mock_exchange.return_value = issued

    outcome = auth(mock_provider, authorization_code="auth-code", state_param="state-key")

    # auth() returns an AuthResult, not a dict.
    assert isinstance(outcome, AuthResult)
    assert outcome.response == {"result": "success"}
    assert len(outcome.actions) == 1
    save_tokens = outcome.actions[0]
    assert save_tokens.action_type == AuthActionType.SAVE_TOKENS
    assert save_tokens.data == issued.model_dump()
    assert save_tokens.provider_id == "provider-id"
    assert save_tokens.tenant_id == "tenant-id"
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
    """Supplying an authorization code without a state parameter must fail."""
    mock_discover.return_value = (
        OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["code"],
            grant_types_supported=["authorization_code"],
        ),
        None,
        None,
    )
    mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")

    with pytest.raises(ValueError) as exc_info:
        auth(mock_provider, authorization_code="auth-code")
    assert "State parameter is required" in str(exc_info.value)
@patch("core.mcp.auth.auth_flow.refresh_authorization")
def test_auth_refresh_token(self, mock_refresh, mock_provider, mock_service):
    """Expired tokens with a refresh token trigger a refresh and SAVE_TOKENS."""
    mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
    mock_provider.retrieve_tokens.return_value = OAuthTokens(
        access_token="old-token",
        token_type="Bearer",
        expires_in=0,  # already expired, forcing the refresh path
        refresh_token="refresh-token",
    )
    refreshed = OAuthTokens(
        access_token="refreshed-token",
        token_type="Bearer",
        expires_in=3600,
        refresh_token="new-refresh-token",
    )
    mock_refresh.return_value = refreshed

    with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
        mock_discover.return_value = (
            OAuthMetadata(
                authorization_endpoint="https://auth.example.com/authorize",
                token_endpoint="https://auth.example.com/token",
                response_types_supported=["code"],
                grant_types_supported=["authorization_code"],
            ),
            None,
            None,
        )
        outcome = auth(mock_provider)

    # auth() returns an AuthResult carrying exactly one SAVE_TOKENS action.
    assert isinstance(outcome, AuthResult)
    assert outcome.response == {"result": "success"}
    assert len(outcome.actions) == 1
    action = outcome.actions[0]
    assert action.action_type == AuthActionType.SAVE_TOKENS
    assert action.data == refreshed.model_dump()
    assert action.provider_id == "provider-id"
    assert action.tenant_id == "tenant-id"
    mock_refresh.assert_called_once()
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
    """An authorization code cannot be exchanged when no client info is stored."""
    mock_discover.return_value = (
        OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["code"],
            grant_types_supported=["authorization_code"],
        ),
        None,
        None,
    )
    # No previously registered client exists for this provider.
    mock_provider.retrieve_client_information.return_value = None

    with pytest.raises(ValueError) as exc_info:
        auth(mock_provider, authorization_code="auth-code")
    assert "Existing OAuth client information is required" in str(exc_info.value)

View File

@@ -0,0 +1,468 @@
import queue
import threading
from typing import Any
from core.mcp import types
from core.mcp.entities import RequestContext
from core.mcp.session.base_session import RequestResponder
from core.mcp.session.client_session import DEFAULT_CLIENT_INFO, ClientSession
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
ClientNotification,
ClientRequest,
Implementation,
InitializedNotification,
InitializeRequest,
InitializeResult,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ServerCapabilities,
ServerResult,
SessionMessage,
)
def test_client_session_initialize():
    """Run the full initialize handshake against an in-thread mock server."""
    # Synchronous queues stand in for the async transport streams.
    to_server: queue.Queue[SessionMessage] = queue.Queue()
    to_client: queue.Queue[SessionMessage] = queue.Queue()
    initialized_notification = None

    def mock_server():
        nonlocal initialized_notification
        # Read and validate the InitializeRequest sent by the client.
        incoming = to_server.get(timeout=5.0)
        rpc = incoming.message
        assert isinstance(rpc.root, JSONRPCRequest)
        request = ClientRequest.model_validate(rpc.root.model_dump(by_alias=True, mode="json", exclude_none=True))
        assert isinstance(request.root, InitializeRequest)
        # Answer with a successful InitializeResult.
        server_result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(
                    logging=None,
                    resources=None,
                    tools=None,
                    experimental=None,
                    prompts=None,
                ),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
                instructions="The server instructions.",
            )
        )
        to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=rpc.root.id,
                        result=server_result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Capture the InitializedNotification the client must send next.
        notification_message = to_server.get(timeout=5.0)
        rpc_notification = notification_message.message
        assert isinstance(rpc_notification.root, JSONRPCNotification)
        initialized_notification = ClientNotification.model_validate(
            rpc_notification.root.model_dump(by_alias=True, mode="json", exclude_none=True)
        )

    def message_handler(
        message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
    ):
        # Surface transport-level failures directly in the test.
        if isinstance(message, Exception):
            raise message

    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        to_client,
        to_server,
        message_handler=message_handler,
    ) as session:
        result = session.initialize()

    server_thread.join(timeout=10.0)

    assert isinstance(result, InitializeResult)
    assert result.protocolVersion == LATEST_PROTOCOL_VERSION
    assert isinstance(result.capabilities, ServerCapabilities)
    assert result.serverInfo == Implementation(name="mock-server", version="0.1.0")
    assert result.instructions == "The server instructions."
    # The client must have completed the handshake with the notification.
    assert initialized_notification
    assert isinstance(initialized_notification.root, InitializedNotification)
def test_client_session_custom_client_info():
    """The clientInfo supplied by the caller is sent verbatim during initialize."""
    to_server: queue.Queue[SessionMessage] = queue.Queue()
    to_client: queue.Queue[SessionMessage] = queue.Queue()
    custom_client_info = Implementation(name="test-client", version="1.2.3")
    received_client_info = None

    def mock_server():
        nonlocal received_client_info
        incoming = to_server.get(timeout=5.0)
        rpc = incoming.message
        assert isinstance(rpc.root, JSONRPCRequest)
        request = ClientRequest.model_validate(rpc.root.model_dump(by_alias=True, mode="json", exclude_none=True))
        assert isinstance(request.root, InitializeRequest)
        # Record the clientInfo the client advertised.
        received_client_info = request.root.params.clientInfo
        server_result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )
        to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=rpc.root.id,
                        result=server_result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Drain the InitializedNotification so the handshake completes.
        to_server.get(timeout=5.0)

    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        to_client,
        to_server,
        client_info=custom_client_info,
    ) as session:
        session.initialize()

    server_thread.join(timeout=10.0)

    assert received_client_info == custom_client_info
def test_client_session_default_client_info():
    """Omitting client_info makes the session advertise DEFAULT_CLIENT_INFO."""
    to_server: queue.Queue[SessionMessage] = queue.Queue()
    to_client: queue.Queue[SessionMessage] = queue.Queue()
    received_client_info = None

    def mock_server():
        nonlocal received_client_info
        incoming = to_server.get(timeout=5.0)
        rpc = incoming.message
        assert isinstance(rpc.root, JSONRPCRequest)
        request = ClientRequest.model_validate(rpc.root.model_dump(by_alias=True, mode="json", exclude_none=True))
        assert isinstance(request.root, InitializeRequest)
        # Record the clientInfo the client advertised.
        received_client_info = request.root.params.clientInfo
        server_result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )
        to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=rpc.root.id,
                        result=server_result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Drain the InitializedNotification so the handshake completes.
        to_server.get(timeout=5.0)

    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        to_client,
        to_server,
    ) as session:
        session.initialize()

    server_thread.join(timeout=10.0)

    assert received_client_info == DEFAULT_CLIENT_INFO
def test_client_session_version_negotiation_success():
    """Initialization succeeds when the server replies with a supported version."""
    to_server: queue.Queue[SessionMessage] = queue.Queue()
    to_client: queue.Queue[SessionMessage] = queue.Queue()

    def mock_server():
        incoming = to_server.get(timeout=5.0)
        rpc = incoming.message
        assert isinstance(rpc.root, JSONRPCRequest)
        request = ClientRequest.model_validate(rpc.root.model_dump(by_alias=True, mode="json", exclude_none=True))
        assert isinstance(request.root, InitializeRequest)
        # Respond with the latest (supported) protocol version.
        server_result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )
        to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=rpc.root.id,
                        result=server_result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Drain the InitializedNotification so the handshake completes.
        to_server.get(timeout=5.0)

    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        to_client,
        to_server,
    ) as session:
        result = session.initialize()

    server_thread.join(timeout=10.0)

    assert isinstance(result, InitializeResult)
    assert result.protocolVersion == LATEST_PROTOCOL_VERSION
def test_client_session_version_negotiation_failure():
    """initialize() raises when the server answers with an unsupported version."""
    # pytest is not imported at module level in this file; keep the import
    # local to the test but at the top, not buried inside the with-block.
    import pytest

    # Synchronous queues stand in for the async transport streams.
    to_server: queue.Queue[SessionMessage] = queue.Queue()
    to_client: queue.Queue[SessionMessage] = queue.Queue()

    def mock_server():
        incoming = to_server.get(timeout=5.0)
        rpc = incoming.message
        assert isinstance(rpc.root, JSONRPCRequest)
        request = ClientRequest.model_validate(rpc.root.model_dump(by_alias=True, mode="json", exclude_none=True))
        assert isinstance(request.root, InitializeRequest)
        # Reply with a protocol version the client cannot accept.
        server_result = ServerResult(
            InitializeResult(
                protocolVersion="99.99.99",  # unsupported on purpose
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )
        to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=rpc.root.id,
                        result=server_result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # No InitializedNotification follows: the client aborts the handshake.

    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        to_client,
        to_server,
    ) as session:
        with pytest.raises(RuntimeError, match="Unsupported protocol version"):
            session.initialize()

    server_thread.join(timeout=10.0)
def test_client_capabilities_default():
    """A session created without callbacks still advertises client capabilities."""
    to_server: queue.Queue[SessionMessage] = queue.Queue()
    to_client: queue.Queue[SessionMessage] = queue.Queue()
    received_capabilities = None

    def mock_server():
        nonlocal received_capabilities
        incoming = to_server.get(timeout=5.0)
        rpc = incoming.message
        assert isinstance(rpc.root, JSONRPCRequest)
        request = ClientRequest.model_validate(rpc.root.model_dump(by_alias=True, mode="json", exclude_none=True))
        assert isinstance(request.root, InitializeRequest)
        # Record the capabilities the client advertised.
        received_capabilities = request.root.params.capabilities
        server_result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )
        to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=rpc.root.id,
                        result=server_result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Drain the InitializedNotification so the handshake completes.
        to_server.get(timeout=5.0)

    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        to_client,
        to_server,
    ) as session:
        session.initialize()

    server_thread.join(timeout=10.0)

    assert received_capabilities is not None
def test_client_capabilities_with_custom_callbacks():
    """Initialization succeeds when sampling/list-roots callbacks are supplied."""
    to_server: queue.Queue[SessionMessage] = queue.Queue()
    to_client: queue.Queue[SessionMessage] = queue.Queue()

    def custom_sampling_callback(
        context: RequestContext["ClientSession", Any],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult | types.ErrorData:
        # Canned sampling reply; never invoked in this handshake-only test.
        return types.CreateMessageResult(
            model="test-model",
            role="assistant",
            content=types.TextContent(type="text", text="Custom response"),
        )

    def custom_list_roots_callback(
        context: RequestContext["ClientSession", Any],
    ) -> types.ListRootsResult | types.ErrorData:
        # Canned empty roots listing.
        return types.ListRootsResult(roots=[])

    def mock_server():
        incoming = to_server.get(timeout=5.0)
        rpc = incoming.message
        assert isinstance(rpc.root, JSONRPCRequest)
        request = ClientRequest.model_validate(rpc.root.model_dump(by_alias=True, mode="json", exclude_none=True))
        assert isinstance(request.root, InitializeRequest)
        server_result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )
        to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=rpc.root.id,
                        result=server_result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Drain the InitializedNotification so the handshake completes.
        to_server.get(timeout=5.0)

    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        to_client,
        to_server,
        sampling_callback=custom_sampling_callback,
        list_roots_callback=custom_list_roots_callback,
    ) as session:
        result = session.initialize()

    server_thread.join(timeout=10.0)

    assert isinstance(result, InitializeResult)
    assert result.protocolVersion == LATEST_PROTOCOL_VERSION

View File

@@ -0,0 +1,324 @@
import contextlib
import json
import queue
import threading
import time
from typing import Any
from unittest.mock import Mock, patch
import httpx
import pytest
from core.mcp import types
from core.mcp.client.sse_client import sse_client
from core.mcp.error import MCPAuthError, MCPConnectionError
SERVER_NAME = "test_server_for_SSE"
def test_sse_message_id_coercion():
    """String IDs that look like integers are parsed as integers.

    See <https://github.com/modelcontextprotocol/python-sdk/pull/851> for more details.
    """
    raw = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
    parsed = types.JSONRPCMessage.model_validate_json(raw)
    expected = types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123))

    # Both sides must be requests before comparing their fields.
    assert isinstance(parsed.root, types.JSONRPCRequest)
    assert isinstance(expected.root, types.JSONRPCRequest)
    assert parsed.root.id == expected.root.id
    assert parsed.root.method == expected.root.method
    assert parsed.root.jsonrpc == expected.root.jsonrpc
class MockSSEClient:
"""Mock SSE client for testing."""
def __init__(self, url: str, headers: dict[str, Any] | None = None):
self.url = url
self.headers = headers or {}
self.connected = False
self.read_queue: queue.Queue = queue.Queue()
self.write_queue: queue.Queue = queue.Queue()
def connect(self):
"""Simulate connection establishment."""
self.connected = True
# Send endpoint event
endpoint_data = "/messages/?session_id=test-session-123"
self.read_queue.put(("endpoint", endpoint_data))
return self.read_queue, self.write_queue
def send_initialize_response(self):
"""Send a mock initialize response."""
response = {
"jsonrpc": "2.0",
"id": 1,
"result": {
"protocolVersion": types.LATEST_PROTOCOL_VERSION,
"capabilities": {
"logging": None,
"resources": None,
"tools": None,
"experimental": None,
"prompts": None,
},
"serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
"instructions": "Test server instructions.",
},
}
self.read_queue.put(("message", json.dumps(response)))
def test_sse_client_message_id_handling():
    """Integer-like string IDs inside SSE message payloads are coerced to int."""
    mock_client = MockSSEClient("http://test.example/sse")
    read_queue, write_queue = mock_client.connect()

    # Enqueue a response whose ID is a string that looks like an integer.
    read_queue.put(("message", json.dumps({"jsonrpc": "2.0", "id": "456", "result": {"test": "data"}})))

    # connect() queued an endpoint event first; discard it.
    read_queue.get(timeout=1.0)

    event_type, data = read_queue.get(timeout=1.0)
    assert event_type == "message"

    parsed = types.JSONRPCMessage.model_validate_json(data)
    assert isinstance(parsed.root, types.JSONRPCResponse)
    assert parsed.root.id == 456  # coerced from "456"
def test_sse_client_connection_validation():
    """The sse_client context manager yields usable read/write queues."""
    url = "http://test.example/sse"

    with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as factory, \
            patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as sse_connect:
        factory.return_value.__enter__.return_value = Mock()

        event_source = Mock()
        event_source.response.raise_for_status.return_value = None
        sse_connect.return_value.__enter__.return_value = event_source

        class FakeSSEEvent:
            """Bare event carrying only the fields the transport reads."""

            def __init__(self, event_type: str, data: str):
                self.event = event_type
                self.data = data

        # The server announces the POST endpoint before any messages flow.
        event_source.iter_sse.return_value = [FakeSSEEvent("endpoint", "/messages/?session_id=test-123")]

        with contextlib.suppress(Exception):
            with sse_client(url) as (read_queue, write_queue):
                assert read_queue is not None
                assert write_queue is not None
def test_sse_client_error_handling():
    """HTTP 401 maps to MCPAuthError; other HTTP errors map to MCPConnectionError."""
    url = "http://test.example/sse"

    # 401 Unauthorized -> MCPAuthError
    with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client"), \
            patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as sse_connect:
        unauthorized = Mock(status_code=401)
        unauthorized.headers = {"WWW-Authenticate": 'Bearer realm="example"'}
        sse_connect.side_effect = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=unauthorized)

        with pytest.raises(MCPAuthError):
            with sse_client(url):
                pass

    # 500 Server Error -> MCPConnectionError
    with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client"), \
            patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as sse_connect:
        server_error = Mock(status_code=500)
        server_error.headers = {}
        sse_connect.side_effect = httpx.HTTPStatusError("Server Error", request=Mock(), response=server_error)

        with pytest.raises(MCPConnectionError):
            with sse_client(url):
                pass
def test_sse_client_timeout_configuration():
    """Custom headers and timeouts are forwarded to the underlying transport."""
    url = "http://test.example/sse"
    connect_timeout = 10.0
    sse_read_timeout = 300.0
    headers = {"Authorization": "Bearer test-token"}

    with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as factory, \
            patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as sse_connect:
        factory.return_value.__enter__.return_value = Mock()

        event_source = Mock()
        event_source.response.raise_for_status.return_value = None
        event_source.iter_sse.return_value = []
        sse_connect.return_value.__enter__.return_value = event_source

        with contextlib.suppress(Exception):
            with sse_client(url, headers=headers, timeout=connect_timeout, sse_read_timeout=sse_read_timeout):
                # Headers must reach the HTTP client factory unchanged.
                factory.assert_called_with(headers=headers)
                # The SSE read timeout must be wired into the connect call.
                call_args = sse_connect.call_args
                assert call_args is not None
                assert call_args[1]["timeout"].read == sse_read_timeout
def test_sse_transport_endpoint_validation():
    """Endpoint URLs must share the SSE connection's origin (scheme + host).

    Note: uses plain truthiness asserts instead of `== True` / `== False`
    comparisons (PEP 8 / ruff E712).
    """
    from core.mcp.client.sse_client import SSETransport

    transport = SSETransport("http://example.com/sse")

    # Same origin: accepted.
    assert transport._validate_endpoint_url("http://example.com/messages/session123")
    # Different host: rejected.
    assert not transport._validate_endpoint_url("http://malicious.com/messages/session123")
    # Different scheme: rejected.
    assert not transport._validate_endpoint_url("https://example.com/messages/session123")
def test_sse_transport_message_parsing():
    """Valid JSON-RPC payloads become SessionMessages; broken JSON becomes an Exception."""
    from core.mcp.client.sse_client import SSETransport

    transport = SSETransport("http://example.com/sse")
    inbox: queue.Queue = queue.Queue()

    # Well-formed JSON-RPC -> a SessionMessage lands on the queue.
    transport._handle_message_event('{"jsonrpc": "2.0", "id": 1, "method": "ping"}', inbox)
    parsed = inbox.get(timeout=1.0)
    assert parsed is not None
    assert hasattr(parsed, "message")

    # Malformed JSON -> the parse error itself lands on the queue.
    transport._handle_message_event('{"invalid": json}', inbox)
    failure = inbox.get(timeout=1.0)
    assert isinstance(failure, Exception)
def test_sse_client_queue_cleanup():
    """The context manager exits cleanly when the connection attempt raises."""
    url = "http://test.example/sse"
    captured_read = None
    captured_write = None

    with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client"), \
            patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as sse_connect:
        sse_connect.side_effect = Exception("Connection failed")

        with contextlib.suppress(Exception):
            with sse_client(url) as (rq, wq):
                captured_read = rq
                captured_write = wq

    # Queues should be cleaned up even on exception.
    # NOTE(review): the real implementation is expected to put None on the
    # queues to signal shutdown; this test only verifies we exit without hanging.
def test_sse_client_headers_propagation():
    """Caller-supplied headers reach the HTTP client factory unchanged."""
    url = "http://test.example/sse"
    headers = {
        "Authorization": "Bearer test-token",
        "X-Custom-Header": "test-value",
        "User-Agent": "test-client/1.0",
    }

    with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as factory, \
            patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as sse_connect:
        factory.return_value.__enter__.return_value = Mock()

        event_source = Mock()
        event_source.response.raise_for_status.return_value = None
        event_source.iter_sse.return_value = []
        sse_connect.return_value.__enter__.return_value = event_source

        with contextlib.suppress(Exception):
            with sse_client(url, headers=headers):
                pass

        # Every custom header must be forwarded verbatim.
        factory.assert_called_with(headers=headers)
def test_sse_client_concurrent_access():
    """Messages survive a producer/consumer handoff across threads."""
    shared: queue.Queue = queue.Queue()

    def producer():
        # Feed ten messages with a small delay to mimic network pacing.
        for i in range(10):
            shared.put(f"message_{i}")
            time.sleep(0.01)

    def consumer():
        collected = []
        for _ in range(10):
            try:
                collected.append(shared.get(timeout=2.0))
            except queue.Empty:
                break
        return collected

    feeder = threading.Thread(target=producer, daemon=True)
    feeder.start()

    received_messages = consumer()
    feeder.join(timeout=5.0)

    # Every produced message must come out the other side.
    assert len(received_messages) == 10
    for i in range(10):
        assert f"message_{i}" in received_messages

View File

@@ -0,0 +1,450 @@
"""
Tests for the StreamableHTTP client transport.
Contains tests for only the client side of the StreamableHTTP transport.
"""
import queue
import threading
import time
from typing import Any
from unittest.mock import Mock, patch
from core.mcp import types
from core.mcp.client.streamable_client import streamablehttp_client
# Test constants
SERVER_NAME = "test_streamable_http_server"  # serverInfo.name the mock server reports
TEST_SESSION_ID = "test-session-id-12345"  # session id handed out by MockStreamableHTTPClient
# Canonical JSON-RPC initialize request payload used as a fixture.
# Note the string "id": the transport is expected to round-trip it unchanged.
INIT_REQUEST = {
    "jsonrpc": "2.0",
    "method": "initialize",
    "params": {
        "clientInfo": {"name": "test-client", "version": "1.0"},
        "protocolVersion": "2025-03-26",
        "capabilities": {},
    },
    "id": "init-1",
}
class MockStreamableHTTPClient:
    """In-memory stand-in for a StreamableHTTP client connection."""

    def __init__(self, url: str, headers: dict[str, Any] | None = None):
        self.url = url
        self.headers = headers or {}
        self.connected = False
        self.read_queue: queue.Queue = queue.Queue()
        self.write_queue: queue.Queue = queue.Queue()
        self.session_id = TEST_SESSION_ID

    def connect(self):
        """Mark the client connected; return (read queue, write queue, session-id getter)."""
        self.connected = True
        return self.read_queue, self.write_queue, lambda: self.session_id

    def send_initialize_response(self):
        """Queue a canned successful initialize response (id matches INIT_REQUEST)."""
        self.read_queue.put(
            types.SessionMessage(
                message=types.JSONRPCMessage(
                    root=types.JSONRPCResponse(
                        jsonrpc="2.0",
                        id="init-1",
                        result={
                            "protocolVersion": types.LATEST_PROTOCOL_VERSION,
                            "capabilities": {
                                "logging": None,
                                "resources": None,
                                "tools": None,
                                "experimental": None,
                                "prompts": None,
                            },
                            "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
                            "instructions": "Test server instructions.",
                        },
                    )
                )
            )
        )

    def send_tools_response(self):
        """Queue a canned tools/list response containing a single test tool."""
        self.read_queue.put(
            types.SessionMessage(
                message=types.JSONRPCMessage(
                    root=types.JSONRPCResponse(
                        jsonrpc="2.0",
                        id="tools-1",
                        result={
                            "tools": [
                                {
                                    "name": "test_tool",
                                    "description": "A test tool",
                                    "inputSchema": {"type": "object", "properties": {}},
                                }
                            ],
                        },
                    )
                )
            )
        )
def test_streamablehttp_client_message_id_handling():
    """Integer-like string IDs are coerced to int (union_mode='left_to_right')."""
    mock_client = MockStreamableHTTPClient("http://test.example/mcp")
    read_queue, write_queue, get_session_id = mock_client.connect()

    # Enqueue a response whose ID is the string "789".
    read_queue.put(
        types.SessionMessage(
            message=types.JSONRPCMessage(root=types.JSONRPCResponse(jsonrpc="2.0", id="789", result={"test": "data"}))
        )
    )

    received = read_queue.get(timeout=1.0)
    assert received is not None
    assert isinstance(received, types.SessionMessage)
    assert isinstance(received.message.root, types.JSONRPCResponse)
    assert received.message.root.id == 789  # coerced from "789"
def test_streamablehttp_client_connection_validation():
    """streamablehttp_client yields queues and a session-id getter."""
    url = "http://test.example/mcp"

    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client

        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.headers = {"content-type": "application/json"}
        ok_response.raise_for_status.return_value = None
        http_client.post.return_value = ok_response

        try:
            with streamablehttp_client(url) as (read_queue, write_queue, get_session_id):
                assert read_queue is not None
                assert write_queue is not None
                assert get_session_id is not None
        except Exception:
            # The mocked transport may abort the handshake; only the
            # validation logic above is under test.
            pass
def test_streamablehttp_client_timeout_configuration():
    """Test StreamableHTTP client timeout configuration."""
    url = "http://test.example/mcp"
    auth_headers = {"Authorization": "Bearer test-token"}
    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        # Successful mocked connection.
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client
        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.headers = {"content-type": "application/json"}
        ok_response.raise_for_status.return_value = None
        http_client.post.return_value = ok_response

        try:
            with streamablehttp_client(url, headers=auth_headers) as (rq, wq, get_sid):
                # NOTE(review): an AssertionError raised here is swallowed by the
                # except below, so this check cannot actually fail the test — confirm intent.
                factory.assert_called_with(headers=auth_headers)
        except Exception:
            # Connection might fail due to mocking, but we tested the configuration
            pass
def test_streamablehttp_client_session_id_handling():
    """Test StreamableHTTP client properly handles session IDs."""
    client = MockStreamableHTTPClient("http://test.example/mcp")
    _inbound, _outbound, get_session_id = client.connect()

    sid = get_session_id()
    # The mock transport must expose the canned session id...
    assert sid == TEST_SESSION_ID
    # ...and it must be a usable, non-empty value for follow-up requests.
    assert sid is not None
    assert len(sid) > 0
def test_streamablehttp_client_message_parsing():
    """Test StreamableHTTP client properly parses different message types."""
    client = MockStreamableHTTPClient("http://test.example/mcp")
    inbound, _outbound, _get_sid = client.connect()

    # A canned initialize response should surface as a SessionMessage
    # wrapping a JSONRPCResponse.
    client.send_initialize_response()
    init_msg = inbound.get(timeout=1.0)
    assert init_msg is not None
    assert isinstance(init_msg, types.SessionMessage)
    assert isinstance(init_msg.message.root, types.JSONRPCResponse)

    # A tools/list response should parse the same way.
    client.send_tools_response()
    tools_msg = inbound.get(timeout=1.0)
    assert tools_msg is not None
    assert isinstance(tools_msg, types.SessionMessage)
def test_streamablehttp_client_queue_cleanup():
    """Test that StreamableHTTP client properly cleans up queues on exit."""
    url = "http://test.example/mcp"
    captured_read = None
    captured_write = None
    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        # Make the underlying HTTP client factory blow up immediately.
        factory.side_effect = Exception("Connection failed")

        try:
            with streamablehttp_client(url) as (rq, wq, get_session_id):
                captured_read = rq
                captured_write = wq
        except Exception:
            pass  # Expected to fail
        # Queues should be cleaned up even on exception
        # Note: In real implementation, cleanup should put None to signal shutdown
def test_streamablehttp_client_headers_propagation():
    """Test that custom headers are properly propagated in StreamableHTTP client."""
    url = "http://test.example/mcp"
    extra_headers = {
        "Authorization": "Bearer test-token",
        "X-Custom-Header": "test-value",
        "User-Agent": "test-client/1.0",
    }
    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        # Stub the client factory so we can capture the headers it receives.
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client
        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.headers = {"content-type": "application/json"}
        ok_response.raise_for_status.return_value = None
        http_client.post.return_value = ok_response

        try:
            with streamablehttp_client(url, headers=extra_headers):
                pass
        except Exception:
            pass  # Expected due to mocking

        # The factory must have been invoked exactly once, with every custom
        # header present in its keyword arguments.
        factory.assert_called_once()
        kwargs = factory.call_args.kwargs
        assert "headers" in kwargs
        forwarded = kwargs["headers"]
        for name, value in extra_headers.items():
            assert name in forwarded
            assert forwarded[name] == value
def test_streamablehttp_client_concurrent_access():
    """Test StreamableHTTP client behavior with concurrent queue access."""
    inbound: queue.Queue = queue.Queue()
    outbound: queue.Queue = queue.Queue()

    def produce():
        # Feed ten messages with a tiny delay to mimic network pacing.
        for idx in range(10):
            inbound.put(f"message_{idx}")
            time.sleep(0.01)  # Small delay to simulate real conditions

    def drain():
        collected = []
        for _ in range(10):
            try:
                collected.append(inbound.get(timeout=2.0))
            except queue.Empty:
                break
        return collected

    # Produce from a background thread while the main thread consumes.
    worker = threading.Thread(target=produce, daemon=True)
    worker.start()
    received = drain()
    worker.join(timeout=5.0)

    # Every produced message must have been consumed.
    assert len(received) == 10
    for idx in range(10):
        assert f"message_{idx}" in received
def test_streamablehttp_client_json_vs_sse_mode():
    """Test StreamableHTTP client handling of JSON vs SSE response modes."""
    url = "http://test.example/mcp"
    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client

        # A response advertising plain JSON...
        json_response = Mock()
        json_response.status_code = 200
        json_response.headers = {"content-type": "application/json"}
        json_response.json.return_value = {"result": "json_mode"}
        json_response.raise_for_status.return_value = None

        # ...and one advertising an SSE stream.
        sse_response = Mock()
        sse_response.status_code = 200
        sse_response.headers = {"content-type": "text/event-stream"}
        sse_response.raise_for_status.return_value = None

        # Either content type should still yield working queues.
        for canned in (json_response, sse_response):
            http_client.post.return_value = canned
            try:
                with streamablehttp_client(url) as (rq, wq, get_session_id):
                    assert rq is not None
                    assert wq is not None
            except Exception:
                pass  # Expected due to mocking
def test_streamablehttp_client_terminate_on_close():
    """Test StreamableHTTP client terminate_on_close parameter."""
    url = "http://test.example/mcp"
    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client
        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.headers = {"content-type": "application/json"}
        ok_response.raise_for_status.return_value = None
        http_client.post.return_value = ok_response
        # DELETE is used for session termination, so stub it as well.
        http_client.delete.return_value = ok_response

        # Exercise both the default (True) and the explicit False setting.
        for flag in (True, False):
            try:
                with streamablehttp_client(url, terminate_on_close=flag) as (rq, wq, get_session_id):
                    pass
            except Exception:
                pass  # Expected due to mocking
def test_streamablehttp_client_protocol_version_handling():
    """Test StreamableHTTP client protocol version handling."""
    client = MockStreamableHTTPClient("http://test.example/mcp")
    inbound, _outbound, _get_sid = client.connect()

    # Queue an initialize response pinned to a specific protocol version.
    init_result = {
        "protocolVersion": "2024-11-05",
        "capabilities": {},
        "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
    }
    inbound.put(
        types.SessionMessage(
            message=types.JSONRPCMessage(
                root=types.JSONRPCResponse(jsonrpc="2.0", id="init-1", result=init_result)
            )
        )
    )

    received = inbound.get(timeout=1.0)
    assert received is not None
    assert isinstance(received.message.root, types.JSONRPCResponse)
    assert received.message.root.result["protocolVersion"] == "2024-11-05"
def test_streamablehttp_client_error_response_handling():
    """Test StreamableHTTP client handling of error responses."""
    client = MockStreamableHTTPClient("http://test.example/mcp")
    inbound, _outbound, _get_sid = client.connect()

    # Queue a JSON-RPC "method not found" error response.
    error_payload = types.ErrorData(code=-32601, message="Method not found", data=None)
    inbound.put(
        types.SessionMessage(
            message=types.JSONRPCMessage(
                root=types.JSONRPCError(jsonrpc="2.0", id="test-1", error=error_payload)
            )
        )
    )

    received = inbound.get(timeout=1.0)
    assert received is not None
    assert isinstance(received.message.root, types.JSONRPCError)
    assert received.message.root.error.code == -32601
    assert received.message.root.error.message == "Method not found"
def test_streamablehttp_client_resumption_token_handling():
    """Test StreamableHTTP client resumption token functionality."""
    url = "http://test.example/mcp"
    resumption_token = "resume-token-123"
    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client
        # Advertise the resumption token via the last-event-id response header.
        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.headers = {"content-type": "application/json", "last-event-id": resumption_token}
        ok_response.raise_for_status.return_value = None
        http_client.post.return_value = ok_response

        try:
            with streamablehttp_client(url) as (rq, wq, get_session_id):
                # The transport should still hand back live queues while the
                # resumption token rides along in the response headers.
                assert rq is not None
                assert wq is not None
        except Exception:
            pass  # Expected due to mocking

View File

@@ -0,0 +1 @@
# MCP server tests

View File

@@ -0,0 +1,512 @@
import json
from unittest.mock import Mock, patch
import jsonschema
import pytest
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types
from core.mcp.server.streamable_http import (
build_parameter_schema,
convert_input_form_to_parameters,
extract_answer_from_response,
handle_call_tool,
handle_initialize,
handle_list_tools,
handle_mcp_request,
handle_ping,
prepare_tool_arguments,
process_mapping_response,
)
from models.model import App, AppMCPServer, AppMode, EndUser
class TestHandleMCPRequest:
    """Test handle_mcp_request function"""

    def setup_method(self):
        """Setup test fixtures"""
        self.app = Mock(spec=App)
        self.app.name = "test_app"
        self.app.mode = AppMode.CHAT

        self.mcp_server = Mock(spec=AppMCPServer)
        self.mcp_server.description = "Test server"
        self.mcp_server.parameters_dict = {}

        self.end_user = Mock(spec=EndUser)
        self.user_input_form = []

        # A generic JSON-RPC request wrapper whose .root gets swapped per test.
        self.mock_request = Mock()
        self.mock_request.root = Mock()
        self.mock_request.root.id = 123

    def _dispatch(self, end_user):
        """Invoke handle_mcp_request with the shared fixtures and the given end user."""
        return handle_mcp_request(
            self.app, self.mock_request, self.user_input_form, self.mcp_server, end_user, 123
        )

    def test_handle_ping_request(self):
        """Test handling ping request"""
        self.mock_request.root = Mock(spec=types.PingRequest)
        self.mock_request.root.id = 123
        # Force the module's type() lookup to report PingRequest.
        fake_type = Mock(return_value=types.PingRequest)
        with patch("core.mcp.server.streamable_http.type", fake_type):
            outcome = self._dispatch(self.end_user)
            assert isinstance(outcome, types.JSONRPCResponse)
            assert outcome.jsonrpc == "2.0"
            assert outcome.id == 123

    def test_handle_initialize_request(self):
        """Test handling initialize request"""
        self.mock_request.root = Mock(spec=types.InitializeRequest)
        self.mock_request.root.id = 123
        fake_type = Mock(return_value=types.InitializeRequest)
        with patch("core.mcp.server.streamable_http.type", fake_type):
            outcome = self._dispatch(self.end_user)
            assert isinstance(outcome, types.JSONRPCResponse)
            assert outcome.jsonrpc == "2.0"
            assert outcome.id == 123

    def test_handle_list_tools_request(self):
        """Test handling list tools request"""
        self.mock_request.root = Mock(spec=types.ListToolsRequest)
        self.mock_request.root.id = 123
        fake_type = Mock(return_value=types.ListToolsRequest)
        with patch("core.mcp.server.streamable_http.type", fake_type):
            outcome = self._dispatch(self.end_user)
            assert isinstance(outcome, types.JSONRPCResponse)
            assert outcome.jsonrpc == "2.0"
            assert outcome.id == 123

    @patch("core.mcp.server.streamable_http.AppGenerateService")
    def test_handle_call_tool_request(self, mock_app_generate):
        """Test handling call tool request"""
        call_request = Mock(spec=types.CallToolRequest)
        call_request.params = Mock()
        call_request.params.arguments = {"query": "test question"}
        call_request.id = 123
        self.mock_request.root = call_request
        fake_type = Mock(return_value=types.CallToolRequest)
        # Canned generation result returned by the app service.
        mock_app_generate.generate.return_value = {"answer": "test answer"}
        with patch("core.mcp.server.streamable_http.type", fake_type):
            outcome = self._dispatch(self.end_user)
            assert isinstance(outcome, types.JSONRPCResponse)
            assert outcome.jsonrpc == "2.0"
            assert outcome.id == 123
            # The tool call must have reached the generation service.
            mock_app_generate.generate.assert_called_once()

    def test_handle_unknown_request_type(self):
        """Test handling unknown request type"""

        class UnknownRequest:
            pass

        self.mock_request.root = Mock(spec=UnknownRequest)
        self.mock_request.root.id = 123
        fake_type = Mock(return_value=UnknownRequest)
        with patch("core.mcp.server.streamable_http.type", fake_type):
            outcome = self._dispatch(self.end_user)
            assert isinstance(outcome, types.JSONRPCError)
            assert outcome.jsonrpc == "2.0"
            assert outcome.id == 123
            assert outcome.error.code == types.METHOD_NOT_FOUND

    def test_handle_value_error(self):
        """Test handling ValueError"""
        self.mock_request.root = Mock(spec=types.CallToolRequest)
        self.mock_request.root.params = Mock()
        self.mock_request.root.params.arguments = {}
        fake_type = Mock(return_value=types.CallToolRequest)
        # Passing end_user=None makes the call-tool handler raise ValueError.
        with patch("core.mcp.server.streamable_http.type", fake_type):
            outcome = self._dispatch(None)
            assert isinstance(outcome, types.JSONRPCError)
            assert outcome.error.code == types.INVALID_PARAMS

    def test_handle_generic_exception(self):
        """Test handling generic exception"""
        self.mock_request.root = Mock(spec=types.PingRequest)
        self.mock_request.root.id = 123
        # Make the ping handler itself blow up so the dispatcher's catch-all runs.
        with patch("core.mcp.server.streamable_http.handle_ping", side_effect=Exception("Test error")):
            with patch("core.mcp.server.streamable_http.type", return_value=types.PingRequest):
                outcome = self._dispatch(self.end_user)
                assert isinstance(outcome, types.JSONRPCError)
                assert outcome.error.code == types.INTERNAL_ERROR
class TestIndividualHandlers:
    """Test individual handler functions"""

    def test_handle_ping(self):
        """Test ping handler"""
        assert isinstance(handle_ping(), types.EmptyResult)

    def test_handle_initialize(self):
        """Test initialize handler"""
        with patch("core.mcp.server.streamable_http.dify_config") as mock_config:
            mock_config.project.version = "1.0.0"
            init_result = handle_initialize("Test server")
        assert isinstance(init_result, types.InitializeResult)
        assert init_result.protocolVersion == types.SERVER_LATEST_PROTOCOL_VERSION
        # The server description is surfaced as the instructions field.
        assert init_result.instructions == "Test server"

    def test_handle_list_tools(self):
        """Test list tools handler"""
        empty_form: list[VariableEntity] = []
        empty_params: dict[str, str] = {}
        tools_result = handle_list_tools("test_app", AppMode.CHAT, empty_form, "Test server", empty_params)
        assert isinstance(tools_result, types.ListToolsResult)
        # Exactly one tool, named after the app and described by the server.
        assert len(tools_result.tools) == 1
        assert tools_result.tools[0].name == "test_app"
        assert tools_result.tools[0].description == "Test server"

    @patch("core.mcp.server.streamable_http.AppGenerateService")
    def test_handle_call_tool(self, mock_app_generate):
        """Test call tool handler"""
        app = Mock(spec=App)
        app.mode = AppMode.CHAT

        # Build a mock CallToolRequest wrapper.
        call_request = Mock(spec=types.CallToolRequest)
        call_request.params = Mock()
        call_request.params.arguments = {"query": "test question"}
        request_wrapper = Mock()
        request_wrapper.root = call_request

        end_user = Mock(spec=EndUser)
        empty_form: list[VariableEntity] = []
        mock_app_generate.generate.return_value = {"answer": "test answer"}

        outcome = handle_call_tool(app, request_wrapper, empty_form, end_user)
        assert isinstance(outcome, types.CallToolResult)
        assert len(outcome.content) == 1
        # The content union should expose the generated text.
        first = outcome.content[0]
        assert hasattr(first, "text")
        assert first.text == "test answer"

    def test_handle_call_tool_no_end_user(self):
        """Test call tool handler without end user"""
        app = Mock(spec=App)
        empty_form: list[VariableEntity] = []
        with pytest.raises(ValueError, match="End user not found"):
            handle_call_tool(app, Mock(), empty_form, None)
class TestUtilityFunctions:
    """Test utility functions"""

    def test_build_parameter_schema_chat_mode(self):
        """Test building parameter schema for chat mode"""
        form = [
            VariableEntity(
                type=VariableEntityType.TEXT_INPUT,
                variable="name",
                description="User name",
                label="Name",
                required=True,
            )
        ]
        schema = build_parameter_schema(AppMode.CHAT, form, {"name": "Enter your name"})
        assert schema["type"] == "object"
        # Chat apps expose an implicit required "query" alongside the form fields.
        assert "query" in schema["properties"]
        assert "name" in schema["properties"]
        assert "query" in schema["required"]
        assert "name" in schema["required"]

    def test_build_parameter_schema_workflow_mode(self):
        """Test building parameter schema for workflow mode"""
        form = [
            VariableEntity(
                type=VariableEntityType.TEXT_INPUT,
                variable="input_text",
                description="Input text",
                label="Input",
                required=True,
            )
        ]
        schema = build_parameter_schema(AppMode.WORKFLOW, form, {"input_text": "Enter text"})
        assert schema["type"] == "object"
        # Workflows get no implicit query parameter.
        assert "query" not in schema["properties"]
        assert "input_text" in schema["properties"]
        assert "input_text" in schema["required"]

    def test_prepare_tool_arguments_chat_mode(self):
        """Test preparing tool arguments for chat mode"""
        app = Mock(spec=App)
        app.mode = AppMode.CHAT
        raw = {"query": "test question", "name": "John"}
        prepared = prepare_tool_arguments(app, raw)
        assert prepared["query"] == "test question"
        assert prepared["inputs"]["name"] == "John"
        # Original arguments should not be modified
        assert raw["query"] == "test question"

    def test_prepare_tool_arguments_workflow_mode(self):
        """Test preparing tool arguments for workflow mode"""
        app = Mock(spec=App)
        app.mode = AppMode.WORKFLOW
        prepared = prepare_tool_arguments(app, {"input_text": "test input"})
        assert "inputs" in prepared
        assert prepared["inputs"]["input_text"] == "test input"

    def test_prepare_tool_arguments_completion_mode(self):
        """Test preparing tool arguments for completion mode"""
        app = Mock(spec=App)
        app.mode = AppMode.COMPLETION
        prepared = prepare_tool_arguments(app, {"name": "John"})
        # Completion mode synthesizes an empty query string.
        assert prepared["query"] == ""
        assert prepared["inputs"]["name"] == "John"

    def test_extract_answer_from_mapping_response_chat(self):
        """Test extracting answer from mapping response for chat mode"""
        app = Mock(spec=App)
        app.mode = AppMode.CHAT
        answer = extract_answer_from_response(app, {"answer": "test answer", "other": "data"})
        assert answer == "test answer"

    def test_extract_answer_from_mapping_response_workflow(self):
        """Test extracting answer from mapping response for workflow mode"""
        app = Mock(spec=App)
        app.mode = AppMode.WORKFLOW
        answer = extract_answer_from_response(app, {"data": {"outputs": {"result": "test result"}}})
        # Workflow answers are the outputs mapping serialized as JSON.
        assert answer == json.dumps({"result": "test result"}, ensure_ascii=False)

    def test_extract_answer_from_streaming_response(self):
        """Test extracting answer from streaming response"""
        app = Mock(spec=App)
        stream = Mock(spec=RateLimitGenerator)
        stream.generator = [
            'data: {"event": "agent_thought", "thought": "thinking..."}',
            'data: {"event": "agent_thought", "thought": "more thinking"}',
            'data: {"event": "other", "content": "ignore this"}',
            "not data format",
        ]
        # Only agent_thought events contribute to the concatenated answer.
        assert extract_answer_from_response(app, stream) == "thinking...more thinking"

    def test_process_mapping_response_invalid_mode(self):
        """Test processing mapping response with invalid app mode"""
        app = Mock(spec=App)
        app.mode = "invalid_mode"
        with pytest.raises(ValueError, match="Invalid app mode"):
            process_mapping_response(app, {"answer": "test"})

    def test_convert_input_form_to_parameters(self):
        """Test converting input form to parameters"""
        form = [
            VariableEntity(
                type=VariableEntityType.TEXT_INPUT,
                variable="name",
                description="User name",
                label="Name",
                required=True,
            ),
            VariableEntity(
                type=VariableEntityType.SELECT,
                variable="category",
                description="Category",
                label="Category",
                required=False,
                options=["A", "B", "C"],
            ),
            VariableEntity(
                type=VariableEntityType.NUMBER,
                variable="count",
                description="Count",
                label="Count",
                required=True,
            ),
            VariableEntity(
                type=VariableEntityType.FILE,
                variable="upload",
                description="File upload",
                label="Upload",
                required=False,
            ),
        ]
        descriptions: dict[str, str] = {
            "name": "Enter your name",
            "category": "Select category",
            "count": "Enter count",
        }
        parameters, required = convert_input_form_to_parameters(form, descriptions)

        # Text input maps to a string property with its description.
        assert "name" in parameters
        assert parameters["name"]["type"] == "string"
        assert parameters["name"]["description"] == "Enter your name"
        # Select maps to a string property with the options as an enum.
        assert "category" in parameters
        assert parameters["category"]["type"] == "string"
        assert parameters["category"]["enum"] == ["A", "B", "C"]
        # Number maps to the JSON-Schema "number" type.
        assert "count" in parameters
        assert parameters["count"]["type"] == "number"
        # FILE-type variables are skipped; at most an empty placeholder remains.
        if "upload" in parameters:
            assert parameters["upload"] == {}
        # Required list reflects each variable's required flag.
        assert "name" in required
        assert "count" in required
        assert "category" not in required

    # Note: _get_request_id function has been removed as request_id is now passed as parameter
    def test_convert_input_form_to_parameters_jsonschema_validation_ok(self):
        """Current schema uses 'number' for numeric fields; it should be a valid JSON Schema."""
        form = [
            VariableEntity(
                type=VariableEntityType.NUMBER,
                variable="count",
                description="Count",
                label="Count",
                required=True,
            ),
            VariableEntity(
                type=VariableEntityType.TEXT_INPUT,
                variable="name",
                description="User name",
                label="Name",
                required=False,
            ),
        ]
        parameters, required = convert_input_form_to_parameters(
            form, {"count": "Enter count", "name": "Enter your name"}
        )
        schema = {"type": "object", "properties": parameters, "required": required}
        # 1) The schema itself must be valid
        jsonschema.Draft202012Validator.check_schema(schema)
        # 2) Both float and integer instances should pass validation
        jsonschema.validate(instance={"count": 3.14, "name": "alice"}, schema=schema)
        jsonschema.validate(instance={"count": 2, "name": "bob"}, schema=schema)

    def test_legacy_float_type_schema_is_invalid(self):
        """Legacy/buggy behavior: using 'float' should produce an invalid JSON Schema."""
        # Manually construct a legacy/incorrect schema (simulating old behavior)
        bad_schema = {
            "type": "object",
            "properties": {
                "count": {
                    "type": "float",  # Invalid type: JSON Schema does not support 'float'
                    "description": "Enter count",
                }
            },
            "required": ["count"],
        }
        # Both schema checking and instance validation must reject the bogus type.
        with pytest.raises(jsonschema.exceptions.SchemaError):
            jsonschema.Draft202012Validator.check_schema(bad_schema)
        with pytest.raises(jsonschema.exceptions.SchemaError):
            jsonschema.validate(instance={"count": 1.23}, schema=bad_schema)

View File

@@ -0,0 +1,239 @@
"""Unit tests for MCP entities module."""
from unittest.mock import Mock
from core.mcp.entities import (
SUPPORTED_PROTOCOL_VERSIONS,
LifespanContextT,
RequestContext,
SessionT,
)
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
class TestProtocolVersions:
    """Test protocol version constants."""

    def test_supported_protocol_versions(self):
        """Test supported protocol versions list."""
        versions = SUPPORTED_PROTOCOL_VERSIONS
        assert isinstance(versions, list)
        assert len(versions) >= 3
        # The two dated revisions and the current latest must all be present.
        for expected in ("2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION):
            assert expected in versions

    def test_latest_protocol_version_is_supported(self):
        """Test that latest protocol version is in supported versions."""
        assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
class TestRequestContext:
    """Test RequestContext dataclass."""

    def test_request_context_creation(self):
        """Test creating a RequestContext instance."""
        session = Mock(spec=BaseSession)
        lifespan = {"key": "value"}
        meta = RequestParams.Meta(progressToken="test-token")
        ctx = RequestContext(
            request_id="test-request-123",
            meta=meta,
            session=session,
            lifespan_context=lifespan,
        )
        assert ctx.request_id == "test-request-123"
        assert ctx.meta == meta
        assert ctx.session == session
        assert ctx.lifespan_context == lifespan

    def test_request_context_with_none_meta(self):
        """Test creating RequestContext with None meta."""
        session = Mock(spec=BaseSession)
        ctx = RequestContext(
            request_id=42,  # Can be int or string
            meta=None,
            session=session,
            lifespan_context=None,
        )
        assert ctx.request_id == 42
        assert ctx.meta is None
        assert ctx.session == session
        assert ctx.lifespan_context is None

    def test_request_context_attributes(self):
        """Test RequestContext attributes are accessible."""
        session = Mock(spec=BaseSession)
        ctx = RequestContext(
            request_id="test-123",
            meta=None,
            session=session,
            lifespan_context=None,
        )
        # Every dataclass field must be present...
        for field_name in ("request_id", "meta", "session", "lifespan_context"):
            assert hasattr(ctx, field_name)
        # ...and hold the value supplied at construction.
        assert ctx.request_id == "test-123"
        assert ctx.meta is None
        assert ctx.session == session
        assert ctx.lifespan_context is None

    def test_request_context_generic_typing(self):
        """Test RequestContext with different generic types."""
        session = Mock(spec=BaseSession)

        # String-typed lifespan context.
        with_str = RequestContext[BaseSession, str](
            request_id="test-1",
            meta=None,
            session=session,
            lifespan_context="string-context",
        )
        assert isinstance(with_str.lifespan_context, str)

        # Dict-typed lifespan context.
        with_dict = RequestContext[BaseSession, dict](
            request_id="test-2",
            meta=None,
            session=session,
            lifespan_context={"key": "value"},
        )
        assert isinstance(with_dict.lifespan_context, dict)

        # Arbitrary user-defined lifespan context type.
        class CustomLifespan:
            def __init__(self, data):
                self.data = data

        with_custom = RequestContext[BaseSession, CustomLifespan](
            request_id="test-3",
            meta=None,
            session=session,
            lifespan_context=CustomLifespan("test-data"),
        )
        assert isinstance(with_custom.lifespan_context, CustomLifespan)
        assert with_custom.lifespan_context.data == "test-data"

    def test_request_context_with_progress_meta(self):
        """Test RequestContext with progress metadata."""
        session = Mock(spec=BaseSession)
        ctx = RequestContext(
            request_id="req-456",
            meta=RequestParams.Meta(progressToken="progress-123"),
            session=session,
            lifespan_context=None,
        )
        assert ctx.meta is not None
        assert ctx.meta.progressToken == "progress-123"

    def test_request_context_equality(self):
        """Test RequestContext equality comparison."""
        session_a = Mock(spec=BaseSession)
        session_b = Mock(spec=BaseSession)

        def build(request_id, session):
            # Helper: contexts varying only in the fields under test.
            return RequestContext(
                request_id=request_id,
                meta=None,
                session=session,
                lifespan_context="context",
            )

        # Same values should be equal.
        assert build("test-123", session_a) == build("test-123", session_a)
        # A different request_id breaks equality.
        assert build("test-123", session_a) != build("test-456", session_a)
        # A different session breaks equality too.
        assert build("test-123", session_a) != build("test-123", session_b)

    def test_request_context_repr(self):
        """Test RequestContext string representation."""
        session = Mock(spec=BaseSession)
        session.__repr__ = Mock(return_value="<MockSession>")
        ctx = RequestContext(
            request_id="test-123",
            meta=None,
            session=session,
            lifespan_context={"data": "test"},
        )
        rendered = repr(ctx)
        assert "RequestContext" in rendered
        assert "test-123" in rendered
        assert "MockSession" in rendered
class TestTypeVariables:
    """Test type variables defined in the module."""

    def test_session_type_var(self):
        """Test SessionT type variable."""

        class CustomSession(BaseSession):
            pass

        # An identity function parameterized over SessionT should round-trip.
        def echo_session(session: SessionT) -> SessionT:
            return session

        fake_session = Mock(spec=CustomSession)
        assert echo_session(fake_session) == fake_session

    def test_lifespan_context_type_var(self):
        """Test LifespanContextT type variable."""

        def echo_context(context: LifespanContextT) -> LifespanContextT:
            return context

        # str, dict, and arbitrary objects should all pass through unchanged.
        assert echo_context("string-context") == "string-context"
        assert echo_context({"key": "value"}) == {"key": "value"}

        class CustomContext:
            pass

        box = CustomContext()
        assert echo_context(box) == box

View File

@@ -0,0 +1,205 @@
"""Unit tests for MCP error classes."""
import pytest
from core.mcp.error import MCPAuthError, MCPConnectionError, MCPError
class TestMCPError:
    """Test MCPError base exception class."""

    def test_mcp_error_creation(self):
        """Test creating MCPError instance."""
        err = MCPError("Test error message")
        assert str(err) == "Test error message"
        assert isinstance(err, Exception)

    def test_mcp_error_inheritance(self):
        """Test MCPError inherits from Exception."""
        err = MCPError()
        assert isinstance(err, Exception)
        assert type(err).__name__ == "MCPError"

    def test_mcp_error_with_empty_message(self):
        """Test MCPError with empty message."""
        # No-arg construction stringifies to the empty string.
        assert str(MCPError()) == ""

    def test_mcp_error_raise(self):
        """Test raising MCPError."""
        with pytest.raises(MCPError) as exc_info:
            raise MCPError("Something went wrong")
        assert str(exc_info.value) == "Something went wrong"
class TestMCPConnectionError:
    """Test MCPConnectionError exception class."""

    def test_mcp_connection_error_creation(self):
        """Test creating MCPConnectionError instance."""
        error = MCPConnectionError("Connection failed")
        assert str(error) == "Connection failed"
        assert isinstance(error, MCPError)
        assert isinstance(error, Exception)

    def test_mcp_connection_error_inheritance(self):
        """Test MCPConnectionError inheritance chain."""
        error = MCPConnectionError()
        # Full chain: MCPConnectionError -> MCPError -> Exception.
        assert isinstance(error, MCPConnectionError)
        assert isinstance(error, MCPError)
        assert isinstance(error, Exception)

    def test_mcp_connection_error_raise(self):
        """Test raising MCPConnectionError."""
        with pytest.raises(MCPConnectionError) as exc_info:
            raise MCPConnectionError("Unable to connect to server")
        assert str(exc_info.value) == "Unable to connect to server"

    def test_mcp_connection_error_catch_as_mcp_error(self):
        """Test catching MCPConnectionError as MCPError."""
        # A subclass must be catchable through its base class.
        with pytest.raises(MCPError) as exc_info:
            raise MCPConnectionError("Connection issue")
        assert isinstance(exc_info.value, MCPConnectionError)
        assert str(exc_info.value) == "Connection issue"
class TestMCPAuthError:
    """Test MCPAuthError exception class."""

    def test_mcp_auth_error_creation(self):
        """Test creating MCPAuthError instance."""
        error = MCPAuthError("Authentication failed")
        assert str(error) == "Authentication failed"
        # Auth errors are a specialization of connection errors.
        assert isinstance(error, MCPConnectionError)
        assert isinstance(error, MCPError)
        assert isinstance(error, Exception)

    def test_mcp_auth_error_inheritance(self):
        """Test MCPAuthError inheritance chain."""
        error = MCPAuthError()
        # Full chain: MCPAuthError -> MCPConnectionError -> MCPError -> Exception.
        assert isinstance(error, MCPAuthError)
        assert isinstance(error, MCPConnectionError)
        assert isinstance(error, MCPError)
        assert isinstance(error, Exception)

    def test_mcp_auth_error_raise(self):
        """Test raising MCPAuthError."""
        with pytest.raises(MCPAuthError) as exc_info:
            raise MCPAuthError("Invalid credentials")
        assert str(exc_info.value) == "Invalid credentials"

    def test_mcp_auth_error_catch_hierarchy(self):
        """Test catching MCPAuthError at different levels."""
        # Catch as MCPAuthError
        with pytest.raises(MCPAuthError) as exc_info:
            raise MCPAuthError("Auth specific error")
        assert str(exc_info.value) == "Auth specific error"
        # Catch as MCPConnectionError
        with pytest.raises(MCPConnectionError) as exc_info:
            raise MCPAuthError("Auth connection error")
        assert isinstance(exc_info.value, MCPAuthError)
        assert str(exc_info.value) == "Auth connection error"
        # Catch as MCPError
        with pytest.raises(MCPError) as exc_info:
            raise MCPAuthError("Auth base error")
        assert isinstance(exc_info.value, MCPAuthError)
        assert str(exc_info.value) == "Auth base error"
class TestErrorHierarchy:
    """Test the complete error hierarchy."""

    def test_exception_hierarchy(self):
        """Test the complete exception hierarchy."""
        # Create instances
        base_error = MCPError("base")
        connection_error = MCPConnectionError("connection")
        auth_error = MCPAuthError("auth")
        # Test type relationships: the base is not an instance of any
        # subclass, and each subclass is an instance of all its ancestors.
        assert not isinstance(base_error, MCPConnectionError)
        assert not isinstance(base_error, MCPAuthError)
        assert isinstance(connection_error, MCPError)
        assert not isinstance(connection_error, MCPAuthError)
        assert isinstance(auth_error, MCPError)
        assert isinstance(auth_error, MCPConnectionError)

    def test_error_handling_patterns(self):
        """Test common error handling patterns."""

        def raise_auth_error():
            raise MCPAuthError("401 Unauthorized")

        def raise_connection_error():
            raise MCPConnectionError("Connection timeout")

        def raise_base_error():
            raise MCPError("Generic error")

        # Pattern 1: catch the most specific errors first so subclasses
        # are not swallowed by a broader except clause.
        errors_caught = []
        for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
            try:
                error_func()
            except MCPAuthError:
                errors_caught.append("auth")
            except MCPConnectionError:
                errors_caught.append("connection")
            except MCPError:
                errors_caught.append("base")
        assert errors_caught == ["auth", "connection", "base"]
        # Pattern 2: catch all as the base error.
        for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
            with pytest.raises(MCPError) as exc_info:
                error_func()
            assert isinstance(exc_info.value, MCPError)

    def test_error_with_cause(self):
        """Test errors with cause (chained exceptions)."""
        original_error = ValueError("Original error")

        def raise_chained_error():
            try:
                raise original_error
            except ValueError as e:
                raise MCPConnectionError("Connection failed") from e

        with pytest.raises(MCPConnectionError) as exc_info:
            raise_chained_error()
        assert str(exc_info.value) == "Connection failed"
        # `raise ... from ...` records the original exception object as
        # __cause__, so identity (not just equality) must hold.
        assert exc_info.value.__cause__ is original_error

    def test_error_comparison(self):
        """Test error instance comparison."""
        error1 = MCPError("Test message")
        error2 = MCPError("Test message")
        error3 = MCPError("Different message")
        # Exceptions use default identity-based equality, so distinct
        # instances are never equal even with the same message.
        assert error1 != error2
        assert error1 != error3
        # But they share one concrete type; compare types with `is`
        # (identity) rather than `==` (flake8 E721).
        assert type(error1) is type(error2) is type(error3)

    def test_error_representation(self):
        """Test error string representation."""
        base_error = MCPError("Base error message")
        connection_error = MCPConnectionError("Connection error message")
        auth_error = MCPAuthError("Auth error message")
        # Default Exception.__repr__ is ClassName('message').
        assert repr(base_error) == "MCPError('Base error message')"
        assert repr(connection_error) == "MCPConnectionError('Connection error message')"
        assert repr(auth_error) == "MCPAuthError('Auth error message')"

View File

@@ -0,0 +1,382 @@
"""Unit tests for MCP client."""
from contextlib import ExitStack
from types import TracebackType
from unittest.mock import Mock, patch
import pytest
from core.mcp.error import MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations
class TestMCPClient:
    """Test suite for MCPClient.

    The transport layers (streamablehttp_client / sse_client) and the
    ClientSession are always mocked; these tests exercise only MCPClient's
    own wiring: URL-based transport selection, fallback between transports,
    context-manager lifecycle, and session guard checks.
    """

    def test_init(self):
        """Test client initialization."""
        client = MCPClient(
            server_url="http://test.example.com/mcp",
            headers={"Authorization": "Bearer test"},
            timeout=30.0,
            sse_read_timeout=60.0,
        )
        assert client.server_url == "http://test.example.com/mcp"
        assert client.headers == {"Authorization": "Bearer test"}
        assert client.timeout == 30.0
        assert client.sse_read_timeout == 60.0
        # No connection is made at construction time.
        assert client._session is None
        assert isinstance(client._exit_stack, ExitStack)
        assert client._initialized is False

    def test_init_defaults(self):
        """Test client initialization with defaults."""
        client = MCPClient(server_url="http://test.example.com")
        assert client.server_url == "http://test.example.com"
        assert client.headers == {}
        assert client.timeout is None
        assert client.sse_read_timeout is None

    @patch("core.mcp.mcp_client.streamablehttp_client")
    @patch("core.mcp.mcp_client.ClientSession")
    def test_initialize_with_mcp_url(self, mock_client_session, mock_streamable_client):
        """Test initialization with MCP URL."""
        # Setup mocks
        # streamablehttp_client's context manager yields a 3-tuple:
        # (read_stream, write_stream, client_context).
        mock_read_stream = Mock()
        mock_write_stream = Mock()
        mock_client_context = Mock()
        mock_streamable_client.return_value.__enter__.return_value = (
            mock_read_stream,
            mock_write_stream,
            mock_client_context,
        )
        mock_session = Mock()
        mock_client_session.return_value.__enter__.return_value = mock_session
        client = MCPClient(server_url="http://test.example.com/mcp")
        client._initialize()
        # Verify streamable client was called
        mock_streamable_client.assert_called_once_with(
            url="http://test.example.com/mcp",
            headers={},
            timeout=None,
            sse_read_timeout=None,
        )
        # Verify session was created
        mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
        mock_session.initialize.assert_called_once()
        assert client._session == mock_session

    @patch("core.mcp.mcp_client.sse_client")
    @patch("core.mcp.mcp_client.ClientSession")
    def test_initialize_with_sse_url(self, mock_client_session, mock_sse_client):
        """Test initialization with SSE URL."""
        # Setup mocks
        # sse_client's context manager yields only (read_stream, write_stream).
        mock_read_stream = Mock()
        mock_write_stream = Mock()
        mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
        mock_session = Mock()
        mock_client_session.return_value.__enter__.return_value = mock_session
        client = MCPClient(server_url="http://test.example.com/sse")
        client._initialize()
        # Verify SSE client was called
        mock_sse_client.assert_called_once_with(
            url="http://test.example.com/sse",
            headers={},
            timeout=None,
            sse_read_timeout=None,
        )
        # Verify session was created
        mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
        mock_session.initialize.assert_called_once()
        assert client._session == mock_session

    @patch("core.mcp.mcp_client.sse_client")
    @patch("core.mcp.mcp_client.streamablehttp_client")
    @patch("core.mcp.mcp_client.ClientSession")
    def test_initialize_with_unknown_method_fallback_to_sse(
        self, mock_client_session, mock_streamable_client, mock_sse_client
    ):
        """Test initialization with unknown method falls back to SSE."""
        # Setup mocks
        mock_read_stream = Mock()
        mock_write_stream = Mock()
        mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
        mock_session = Mock()
        mock_client_session.return_value.__enter__.return_value = mock_session
        client = MCPClient(server_url="http://test.example.com/unknown")
        client._initialize()
        # Verify SSE client was tried (and, since it succeeds, the
        # streamable HTTP transport is never attempted).
        mock_sse_client.assert_called_once()
        mock_streamable_client.assert_not_called()
        # Verify session was created
        assert client._session == mock_session

    @patch("core.mcp.mcp_client.sse_client")
    @patch("core.mcp.mcp_client.streamablehttp_client")
    @patch("core.mcp.mcp_client.ClientSession")
    def test_initialize_fallback_from_sse_to_mcp(self, mock_client_session, mock_streamable_client, mock_sse_client):
        """Test initialization falls back from SSE to MCP on connection error."""
        # Setup SSE to fail
        mock_sse_client.side_effect = MCPConnectionError("SSE connection failed")
        # Setup MCP to succeed
        mock_read_stream = Mock()
        mock_write_stream = Mock()
        mock_client_context = Mock()
        mock_streamable_client.return_value.__enter__.return_value = (
            mock_read_stream,
            mock_write_stream,
            mock_client_context,
        )
        mock_session = Mock()
        mock_client_session.return_value.__enter__.return_value = mock_session
        client = MCPClient(server_url="http://test.example.com/unknown")
        client._initialize()
        # Verify both were tried
        mock_sse_client.assert_called_once()
        mock_streamable_client.assert_called_once()
        # Verify session was created with MCP
        assert client._session == mock_session

    @patch("core.mcp.mcp_client.streamablehttp_client")
    @patch("core.mcp.mcp_client.ClientSession")
    def test_connect_server_mcp(self, mock_client_session, mock_streamable_client):
        """Test connect_server with MCP method."""
        # Setup mocks
        mock_read_stream = Mock()
        mock_write_stream = Mock()
        mock_client_context = Mock()
        mock_streamable_client.return_value.__enter__.return_value = (
            mock_read_stream,
            mock_write_stream,
            mock_client_context,
        )
        mock_session = Mock()
        mock_client_session.return_value.__enter__.return_value = mock_session
        client = MCPClient(server_url="http://test.example.com")
        client.connect_server(mock_streamable_client, "mcp")
        # Verify correct streams were passed (the third element of the
        # streamable tuple is not forwarded to the session).
        mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
        mock_session.initialize.assert_called_once()

    @patch("core.mcp.mcp_client.sse_client")
    @patch("core.mcp.mcp_client.ClientSession")
    def test_connect_server_sse(self, mock_client_session, mock_sse_client):
        """Test connect_server with SSE method."""
        # Setup mocks
        mock_read_stream = Mock()
        mock_write_stream = Mock()
        mock_sse_client.return_value.__enter__.return_value = (mock_read_stream, mock_write_stream)
        mock_session = Mock()
        mock_client_session.return_value.__enter__.return_value = mock_session
        client = MCPClient(server_url="http://test.example.com")
        client.connect_server(mock_sse_client, "sse")
        # Verify correct streams were passed
        mock_client_session.assert_called_once_with(mock_read_stream, mock_write_stream)
        mock_session.initialize.assert_called_once()

    def test_context_manager_enter(self):
        """Test context manager enter."""
        client = MCPClient(server_url="http://test.example.com")
        with patch.object(client, "_initialize") as mock_initialize:
            result = client.__enter__()
            # __enter__ returns the client itself and flags it initialized.
            assert result == client
            assert client._initialized is True
            mock_initialize.assert_called_once()

    def test_context_manager_exit(self):
        """Test context manager exit."""
        client = MCPClient(server_url="http://test.example.com")
        with patch.object(client, "cleanup") as mock_cleanup:
            exc_type: type[BaseException] | None = None
            exc_val: BaseException | None = None
            exc_tb: TracebackType | None = None
            client.__exit__(exc_type, exc_val, exc_tb)
            mock_cleanup.assert_called_once()

    def test_list_tools_not_initialized(self):
        """Test list_tools when session not initialized."""
        client = MCPClient(server_url="http://test.example.com")
        with pytest.raises(ValueError) as exc_info:
            client.list_tools()
        assert "Session not initialized" in str(exc_info.value)

    def test_list_tools_success(self):
        """Test successful list_tools call."""
        client = MCPClient(server_url="http://test.example.com")
        # Setup mock session
        mock_session = Mock()
        expected_tools = [
            Tool(
                name="test-tool",
                description="A test tool",
                inputSchema={"type": "object", "properties": {}},
                annotations=ToolAnnotations(title="Test Tool"),
            )
        ]
        mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
        # Inject the session directly to bypass the transport layer.
        client._session = mock_session
        result = client.list_tools()
        assert result == expected_tools
        mock_session.list_tools.assert_called_once()

    def test_invoke_tool_not_initialized(self):
        """Test invoke_tool when session not initialized."""
        client = MCPClient(server_url="http://test.example.com")
        with pytest.raises(ValueError) as exc_info:
            client.invoke_tool("test-tool", {"arg": "value"})
        assert "Session not initialized" in str(exc_info.value)

    def test_invoke_tool_success(self):
        """Test successful invoke_tool call."""
        client = MCPClient(server_url="http://test.example.com")
        # Setup mock session
        mock_session = Mock()
        expected_result = CallToolResult(
            content=[TextContent(type="text", text="Tool executed successfully")],
            isError=False,
        )
        mock_session.call_tool.return_value = expected_result
        client._session = mock_session
        result = client.invoke_tool("test-tool", {"arg": "value"})
        assert result == expected_result
        # Tool name and arguments are forwarded verbatim to the session.
        mock_session.call_tool.assert_called_once_with("test-tool", {"arg": "value"})

    def test_cleanup(self):
        """Test cleanup method."""
        client = MCPClient(server_url="http://test.example.com")
        mock_exit_stack = Mock(spec=ExitStack)
        client._exit_stack = mock_exit_stack
        client._session = Mock()
        client._initialized = True
        client.cleanup()
        # Cleanup closes the exit stack and resets all connection state.
        mock_exit_stack.close.assert_called_once()
        assert client._session is None
        assert client._initialized is False

    def test_cleanup_with_error(self):
        """Test cleanup method with error."""
        client = MCPClient(server_url="http://test.example.com")
        mock_exit_stack = Mock(spec=ExitStack)
        mock_exit_stack.close.side_effect = Exception("Cleanup error")
        client._exit_stack = mock_exit_stack
        client._session = Mock()
        client._initialized = True
        with pytest.raises(ValueError) as exc_info:
            client.cleanup()
        assert "Error during cleanup: Cleanup error" in str(exc_info.value)
        # Even when closing fails, the state must still be reset.
        assert client._session is None
        assert client._initialized is False

    @patch("core.mcp.mcp_client.streamablehttp_client")
    @patch("core.mcp.mcp_client.ClientSession")
    def test_full_context_manager_flow(self, mock_client_session, mock_streamable_client):
        """Test full context manager flow."""
        # Setup mocks
        mock_read_stream = Mock()
        mock_write_stream = Mock()
        mock_client_context = Mock()
        mock_streamable_client.return_value.__enter__.return_value = (
            mock_read_stream,
            mock_write_stream,
            mock_client_context,
        )
        mock_session = Mock()
        mock_client_session.return_value.__enter__.return_value = mock_session
        expected_tools = [Tool(name="test-tool", description="Test", inputSchema={})]
        mock_session.list_tools.return_value = ListToolsResult(tools=expected_tools)
        with MCPClient(server_url="http://test.example.com/mcp") as client:
            assert client._initialized is True
            assert client._session == mock_session
            # Test tool operations
            tools = client.list_tools()
            assert tools == expected_tools
        # After exit, should be cleaned up
        assert client._initialized is False
        assert client._session is None

    def test_headers_passed_to_clients(self):
        """Test that headers are properly passed to underlying clients."""
        custom_headers = {
            "Authorization": "Bearer test-token",
            "X-Custom-Header": "test-value",
        }
        with patch("core.mcp.mcp_client.streamablehttp_client") as mock_streamable_client:
            with patch("core.mcp.mcp_client.ClientSession") as mock_client_session:
                # Setup mocks
                mock_read_stream = Mock()
                mock_write_stream = Mock()
                mock_client_context = Mock()
                mock_streamable_client.return_value.__enter__.return_value = (
                    mock_read_stream,
                    mock_write_stream,
                    mock_client_context,
                )
                mock_session = Mock()
                mock_client_session.return_value.__enter__.return_value = mock_session
                client = MCPClient(
                    server_url="http://test.example.com/mcp",
                    headers=custom_headers,
                    timeout=30.0,
                    sse_read_timeout=60.0,
                )
                client._initialize()
                # Verify headers were passed
                mock_streamable_client.assert_called_once_with(
                    url="http://test.example.com/mcp",
                    headers=custom_headers,
                    timeout=30.0,
                    sse_read_timeout=60.0,
                )

View File

@@ -0,0 +1,492 @@
"""Unit tests for MCP types module."""
import pytest
from pydantic import ValidationError
from core.mcp.types import (
INTERNAL_ERROR,
INVALID_PARAMS,
INVALID_REQUEST,
LATEST_PROTOCOL_VERSION,
METHOD_NOT_FOUND,
PARSE_ERROR,
SERVER_LATEST_PROTOCOL_VERSION,
Annotations,
CallToolRequest,
CallToolRequestParams,
CallToolResult,
ClientCapabilities,
CompleteRequest,
CompleteRequestParams,
CompleteResult,
Completion,
CompletionArgument,
CompletionContext,
ErrorData,
ImageContent,
Implementation,
InitializeRequest,
InitializeRequestParams,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ListToolsRequest,
ListToolsResult,
OAuthClientInformation,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
PingRequest,
ProgressNotification,
ProgressNotificationParams,
PromptReference,
RequestParams,
ResourceTemplateReference,
Result,
ServerCapabilities,
TextContent,
Tool,
ToolAnnotations,
)
class TestConstants:
    """Test module constants."""

    def test_protocol_versions(self):
        """Test protocol version constants."""
        assert LATEST_PROTOCOL_VERSION == "2025-06-18"
        assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"

    def test_error_codes(self):
        """Test JSON-RPC error code constants."""
        # Standard JSON-RPC 2.0 error codes: parse error, invalid request,
        # method not found, invalid params, internal error.
        codes = [PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR]
        assert codes == [-32700, -32600, -32601, -32602, -32603]
class TestRequestParams:
    """Test RequestParams and related classes."""

    def test_request_params_basic(self):
        """Test basic RequestParams creation."""
        params = RequestParams()
        assert params.meta is None

    def test_request_params_with_meta(self):
        """Test RequestParams with meta."""
        # The field is named `meta` but is populated via the `_meta` alias.
        meta = RequestParams.Meta(progressToken="test-token")
        params = RequestParams(_meta=meta)
        assert params.meta is not None
        assert params.meta.progressToken == "test-token"

    def test_request_params_meta_extra_fields(self):
        """Test RequestParams.Meta allows extra fields."""
        meta = RequestParams.Meta(progressToken="token", customField="value")
        assert meta.progressToken == "token"
        assert meta.customField == "value"  # type: ignore

    def test_request_params_serialization(self):
        """Test RequestParams serialization with _meta alias."""
        meta = RequestParams.Meta(progressToken="test")
        params = RequestParams(_meta=meta)
        # Model dump should use the alias ("_meta"), not the field name ("meta").
        dumped = params.model_dump(by_alias=True)
        assert "_meta" in dumped
        assert dumped["_meta"] is not None
        assert dumped["_meta"]["progressToken"] == "test"
class TestJSONRPCMessages:
    """Test JSON-RPC message types."""

    def test_jsonrpc_request(self):
        """Test JSONRPCRequest creation and validation."""
        request = JSONRPCRequest(jsonrpc="2.0", id="test-123", method="test_method", params={"key": "value"})
        assert request.jsonrpc == "2.0"
        assert request.id == "test-123"
        assert request.method == "test_method"
        assert request.params == {"key": "value"}

    def test_jsonrpc_request_numeric_id(self):
        """Test JSONRPCRequest with numeric ID."""
        # JSON-RPC allows both string and numeric request IDs.
        request = JSONRPCRequest(jsonrpc="2.0", id=123, method="test", params=None)
        assert request.id == 123

    def test_jsonrpc_notification(self):
        """Test JSONRPCNotification creation."""
        notification = JSONRPCNotification(jsonrpc="2.0", method="notification_method", params={"data": "test"})
        assert notification.jsonrpc == "2.0"
        assert notification.method == "notification_method"
        assert not hasattr(notification, "id")  # Notifications don't have ID

    def test_jsonrpc_response(self):
        """Test JSONRPCResponse creation."""
        response = JSONRPCResponse(jsonrpc="2.0", id="req-123", result={"success": True})
        assert response.jsonrpc == "2.0"
        assert response.id == "req-123"
        assert response.result == {"success": True}

    def test_jsonrpc_error(self):
        """Test JSONRPCError creation."""
        error_data = ErrorData(code=INVALID_PARAMS, message="Invalid parameters", data={"field": "missing"})
        error = JSONRPCError(jsonrpc="2.0", id="req-123", error=error_data)
        assert error.jsonrpc == "2.0"
        assert error.id == "req-123"
        assert error.error.code == INVALID_PARAMS
        assert error.error.message == "Invalid parameters"
        assert error.error.data == {"field": "missing"}

    def test_jsonrpc_message_parsing(self):
        """Test JSONRPCMessage parsing different message types."""
        # The root model should discriminate the variant
        # (request / response / error) from the JSON shape alone.
        # Parse request
        request_json = '{"jsonrpc": "2.0", "id": 1, "method": "test", "params": null}'
        msg = JSONRPCMessage.model_validate_json(request_json)
        assert isinstance(msg.root, JSONRPCRequest)
        # Parse response
        response_json = '{"jsonrpc": "2.0", "id": 1, "result": {"data": "test"}}'
        msg = JSONRPCMessage.model_validate_json(response_json)
        assert isinstance(msg.root, JSONRPCResponse)
        # Parse error
        error_json = '{"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "Invalid Request"}}'
        msg = JSONRPCMessage.model_validate_json(error_json)
        assert isinstance(msg.root, JSONRPCError)
class TestCapabilities:
    """Test capability classes."""

    def test_client_capabilities(self):
        """Test ClientCapabilities creation."""
        caps = ClientCapabilities(
            experimental={"feature": {"enabled": True}},
            sampling={"model_config": {"extra": "allow"}},
            roots={"listChanged": True},
        )
        assert caps.experimental == {"feature": {"enabled": True}}
        assert caps.sampling is not None
        # Plain dicts are coerced into the nested capability models.
        assert caps.roots.listChanged is True  # type: ignore

    def test_server_capabilities(self):
        """Test ServerCapabilities creation."""
        caps = ServerCapabilities(
            tools={"listChanged": True},
            resources={"subscribe": True, "listChanged": False},
            prompts={"listChanged": True},
            logging={},
            completions={},
        )
        assert caps.tools.listChanged is True  # type: ignore
        assert caps.resources.subscribe is True  # type: ignore
        assert caps.resources.listChanged is False  # type: ignore
class TestInitialization:
    """Test initialization request/response types."""

    def test_initialize_request(self):
        """Test InitializeRequest creation."""
        client_info = Implementation(name="test-client", version="1.0.0")
        capabilities = ClientCapabilities()
        params = InitializeRequestParams(
            protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=capabilities, clientInfo=client_info
        )
        request = InitializeRequest(params=params)
        # The method name is fixed by the model, not supplied by the caller.
        assert request.method == "initialize"
        assert request.params.protocolVersion == LATEST_PROTOCOL_VERSION
        assert request.params.clientInfo.name == "test-client"

    def test_initialize_result(self):
        """Test InitializeResult creation."""
        server_info = Implementation(name="test-server", version="1.0.0")
        capabilities = ServerCapabilities()
        result = InitializeResult(
            protocolVersion=LATEST_PROTOCOL_VERSION,
            capabilities=capabilities,
            serverInfo=server_info,
            instructions="Welcome to test server",
        )
        assert result.protocolVersion == LATEST_PROTOCOL_VERSION
        assert result.serverInfo.name == "test-server"
        assert result.instructions == "Welcome to test server"
class TestTools:
    """Test tool-related types."""

    def test_tool_creation(self):
        """Test Tool creation with all fields."""
        tool = Tool(
            name="test_tool",
            title="Test Tool",
            description="A tool for testing",
            inputSchema={"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]},
            outputSchema={"type": "object", "properties": {"result": {"type": "string"}}},
            annotations=ToolAnnotations(
                title="Test Tool", readOnlyHint=False, destructiveHint=False, idempotentHint=True
            ),
        )
        assert tool.name == "test_tool"
        assert tool.title == "Test Tool"
        assert tool.description == "A tool for testing"
        assert tool.inputSchema["properties"]["input"]["type"] == "string"
        assert tool.annotations.idempotentHint is True

    def test_call_tool_request(self):
        """Test CallToolRequest creation."""
        params = CallToolRequestParams(name="test_tool", arguments={"input": "test value"})
        request = CallToolRequest(params=params)
        # The method name is fixed by the model.
        assert request.method == "tools/call"
        assert request.params.name == "test_tool"
        assert request.params.arguments == {"input": "test value"}

    def test_call_tool_result(self):
        """Test CallToolResult creation."""
        result = CallToolResult(
            content=[TextContent(type="text", text="Tool executed successfully")],
            structuredContent={"status": "success", "data": "test"},
            isError=False,
        )
        assert len(result.content) == 1
        assert result.content[0].text == "Tool executed successfully"  # type: ignore
        assert result.structuredContent == {"status": "success", "data": "test"}
        assert result.isError is False

    def test_list_tools_request(self):
        """Test ListToolsRequest creation."""
        request = ListToolsRequest()
        assert request.method == "tools/list"

    def test_list_tools_result(self):
        """Test ListToolsResult creation."""
        tool1 = Tool(name="tool1", inputSchema={})
        tool2 = Tool(name="tool2", inputSchema={})
        result = ListToolsResult(tools=[tool1, tool2])
        # Order of the tools list is preserved.
        assert len(result.tools) == 2
        assert result.tools[0].name == "tool1"
        assert result.tools[1].name == "tool2"
class TestContent:
    """Test content types."""

    def test_text_content(self):
        """Test TextContent creation."""
        note = Annotations(audience=["user"], priority=0.8)
        content = TextContent(type="text", text="Hello, world!", annotations=note)
        assert content.type == "text"
        assert content.text == "Hello, world!"
        # Annotations attach audience/priority metadata to the content.
        assert content.annotations is not None
        assert content.annotations.priority == 0.8

    def test_image_content(self):
        """Test ImageContent creation."""
        image = ImageContent(type="image", data="base64encodeddata", mimeType="image/png")
        assert image.type == "image"
        assert image.data == "base64encodeddata"
        assert image.mimeType == "image/png"
class TestOAuth:
    """Test OAuth-related types."""

    def test_oauth_client_metadata(self):
        """Test OAuthClientMetadata creation."""
        metadata = OAuthClientMetadata(
            client_name="Test Client",
            redirect_uris=["https://example.com/callback"],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="none",
            client_uri="https://example.com",
            scope="read write",
        )
        assert metadata.client_name == "Test Client"
        assert len(metadata.redirect_uris) == 1
        assert "authorization_code" in metadata.grant_types

    def test_oauth_client_information(self):
        """Test OAuthClientInformation creation."""
        info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")
        assert info.client_id == "test-client-id"
        assert info.client_secret == "test-secret"

    def test_oauth_client_information_without_secret(self):
        """Test OAuthClientInformation without secret."""
        # Public clients carry no secret; the field defaults to None.
        info = OAuthClientInformation(client_id="public-client")
        assert info.client_id == "public-client"
        assert info.client_secret is None

    def test_oauth_tokens(self):
        """Test OAuthTokens creation."""
        tokens = OAuthTokens(
            access_token="access-token-123",
            token_type="Bearer",
            expires_in=3600,
            refresh_token="refresh-token-456",
            scope="read write",
        )
        assert tokens.access_token == "access-token-123"
        assert tokens.token_type == "Bearer"
        assert tokens.expires_in == 3600
        assert tokens.refresh_token == "refresh-token-456"
        assert tokens.scope == "read write"

    def test_oauth_metadata(self):
        """Test OAuthMetadata creation."""
        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            registration_endpoint="https://auth.example.com/register",
            response_types_supported=["code", "token"],
            grant_types_supported=["authorization_code", "refresh_token"],
            code_challenge_methods_supported=["plain", "S256"],
        )
        assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
        assert "code" in metadata.response_types_supported
        assert "S256" in metadata.code_challenge_methods_supported
class TestNotifications:
    """Test notification types."""

    def test_progress_notification(self):
        """Test ProgressNotification creation."""
        notification = ProgressNotification(
            params=ProgressNotificationParams(
                progressToken="progress-123", progress=50.0, total=100.0, message="Processing... 50%"
            )
        )
        # The method name is fixed by the model.
        assert notification.method == "notifications/progress"
        assert notification.params.progressToken == "progress-123"
        assert notification.params.progress == 50.0
        assert notification.params.total == 100.0
        assert notification.params.message == "Processing... 50%"

    def test_ping_request(self):
        """Test PingRequest creation."""
        ping = PingRequest()
        # Ping carries no parameters, only its fixed method name.
        assert ping.method == "ping"
        assert ping.params is None
class TestCompletion:
    """Test completion-related types."""

    def test_completion_context(self):
        """Test CompletionContext creation."""
        context = CompletionContext(arguments={"template_var": "value"})
        assert context.arguments == {"template_var": "value"}

    def test_resource_template_reference(self):
        """Test ResourceTemplateReference creation."""
        ref = ResourceTemplateReference(type="ref/resource", uri="file:///path/to/(unknown)")
        assert ref.type == "ref/resource"
        assert ref.uri == "file:///path/to/(unknown)"

    def test_prompt_reference(self):
        """Test PromptReference creation."""
        ref = PromptReference(type="ref/prompt", name="test_prompt")
        assert ref.type == "ref/prompt"
        assert ref.name == "test_prompt"

    def test_complete_request(self):
        """Test CompleteRequest creation."""
        ref = PromptReference(type="ref/prompt", name="test_prompt")
        arg = CompletionArgument(name="arg1", value="val")
        params = CompleteRequestParams(ref=ref, argument=arg, context=CompletionContext(arguments={"key": "value"}))
        request = CompleteRequest(params=params)
        # The method name is fixed by the model.
        assert request.method == "completion/complete"
        assert request.params.ref.name == "test_prompt"  # type: ignore
        assert request.params.argument.name == "arg1"

    def test_complete_result(self):
        """Test CompleteResult creation."""
        completion = Completion(values=["option1", "option2", "option3"], total=10, hasMore=True)
        result = CompleteResult(completion=completion)
        # `total`/`hasMore` describe the full result set beyond the page returned.
        assert len(result.completion.values) == 3
        assert result.completion.total == 10
        assert result.completion.hasMore is True
class TestValidation:
    """Test validation of various types."""

    def test_invalid_jsonrpc_version(self):
        """Test invalid JSON-RPC version validation."""
        with pytest.raises(ValidationError):
            JSONRPCRequest(
                jsonrpc="1.0",  # Invalid version
                id=1,
                method="test",
            )

    def test_tool_annotations_validation(self):
        """Test ToolAnnotations with invalid values."""
        # Valid annotations
        annotations = ToolAnnotations(
            title="Test", readOnlyHint=True, destructiveHint=False, idempotentHint=True, openWorldHint=False
        )
        assert annotations.title == "Test"

    def test_extra_fields_allowed(self):
        """Test that extra fields are allowed in models."""
        # Most models should allow extra fields
        tool = Tool(
            name="test",
            inputSchema={},
            customField="allowed",  # type: ignore
        )
        assert tool.customField == "allowed"  # type: ignore

    def test_result_meta_alias(self):
        """Test Result model with _meta alias."""
        # Construct via the "_meta" alias, which populates the `meta` field
        result = Result(_meta={"key": "value"})
        # Verify the field is set correctly
        assert result.meta == {"key": "value"}
        # Dump with alias
        dumped = result.model_dump(by_alias=True)
        assert "_meta" in dumped
        assert dumped["_meta"] == {"key": "value"}

View File

@@ -0,0 +1,355 @@
"""Unit tests for MCP utils module."""
import json
from collections.abc import Generator
from unittest.mock import MagicMock, Mock, patch
import httpx
import httpx_sse
import pytest
from core.mcp.utils import (
STATUS_FORCELIST,
create_mcp_error_response,
create_ssrf_proxy_mcp_http_client,
ssrf_proxy_sse_connect,
)
class TestConstants:
    """Test module constants."""

    def test_status_forcelist(self):
        """STATUS_FORCELIST should list exactly the retryable HTTP status codes."""
        assert STATUS_FORCELIST == [429, 500, 502, 503, 504]
        retryable_codes = (
            429,  # Too Many Requests
            500,  # Internal Server Error
            502,  # Bad Gateway
            503,  # Service Unavailable
            504,  # Gateway Timeout
        )
        for code in retryable_codes:
            assert code in STATUS_FORCELIST
class TestCreateSSRFProxyMCPHTTPClient:
    """Test create_ssrf_proxy_mcp_http_client function."""

    @patch("core.mcp.utils.dify_config")
    def test_create_client_with_all_url_proxy(self, mock_config):
        """Test client creation with SSRF_PROXY_ALL_URL configured."""
        mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
        mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
        client = create_ssrf_proxy_mcp_http_client(
            headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30.0)
        )
        assert isinstance(client, httpx.Client)
        assert client.headers["Authorization"] == "Bearer token"
        # A single-value httpx.Timeout applies the same 30s to every phase, including connect.
        assert client.timeout.connect == 30.0
        assert client.follow_redirects is True
        # Clean up
        client.close()

    @patch("core.mcp.utils.dify_config")
    def test_create_client_with_http_https_proxies(self, mock_config):
        """Test client creation with separate HTTP/HTTPS proxies."""
        mock_config.SSRF_PROXY_ALL_URL = None
        mock_config.SSRF_PROXY_HTTP_URL = "http://http-proxy.example.com:8080"
        mock_config.SSRF_PROXY_HTTPS_URL = "http://https-proxy.example.com:8443"
        mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = False
        client = create_ssrf_proxy_mcp_http_client()
        assert isinstance(client, httpx.Client)
        assert client.follow_redirects is True
        # Clean up
        client.close()

    @patch("core.mcp.utils.dify_config")
    def test_create_client_without_proxy(self, mock_config):
        """Test client creation without proxy configuration."""
        mock_config.SSRF_PROXY_ALL_URL = None
        mock_config.SSRF_PROXY_HTTP_URL = None
        mock_config.SSRF_PROXY_HTTPS_URL = None
        mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
        headers = {"X-Custom-Header": "value"}
        timeout = httpx.Timeout(timeout=30.0, connect=5.0, read=10.0, write=30.0)
        client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
        assert isinstance(client, httpx.Client)
        assert client.headers["X-Custom-Header"] == "value"
        # Per-phase overrides take precedence over the base 30s timeout.
        assert client.timeout.connect == 5.0
        assert client.timeout.read == 10.0
        assert client.follow_redirects is True
        # Clean up
        client.close()

    @patch("core.mcp.utils.dify_config")
    def test_create_client_default_params(self, mock_config):
        """Test client creation with default parameters."""
        mock_config.SSRF_PROXY_ALL_URL = None
        mock_config.SSRF_PROXY_HTTP_URL = None
        mock_config.SSRF_PROXY_HTTPS_URL = None
        mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True
        client = create_ssrf_proxy_mcp_http_client()
        assert isinstance(client, httpx.Client)
        # httpx.Client adds default headers, so we just check it's a Headers object
        assert isinstance(client.headers, httpx.Headers)
        # When no timeout is provided, httpx uses its default timeout
        assert client.timeout is not None
        # Clean up
        client.close()
class TestSSRFProxySSEConnect:
    """Test ssrf_proxy_sse_connect function."""

    @patch("core.mcp.utils.connect_sse")
    @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
    def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse):
        """Test SSE connection with pre-configured client."""
        # Setup mocks
        mock_client = Mock(spec=httpx.Client)
        mock_event_source = Mock(spec=httpx_sse.EventSource)
        mock_context = MagicMock()
        mock_context.__enter__.return_value = mock_event_source
        mock_connect_sse.return_value = mock_context
        # Call with provided client
        result = ssrf_proxy_sse_connect(
            "http://example.com/sse", client=mock_client, method="POST", headers={"Authorization": "Bearer token"}
        )
        # Verify client creation was not called (a caller-supplied client must be reused)
        mock_create_client.assert_not_called()
        # Verify connect_sse was called correctly
        mock_connect_sse.assert_called_once_with(
            mock_client, "POST", "http://example.com/sse", headers={"Authorization": "Bearer token"}
        )
        # Verify result
        assert result == mock_context

    @patch("core.mcp.utils.connect_sse")
    @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
    @patch("core.mcp.utils.dify_config")
    def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
        """Test SSE connection without pre-configured client."""
        # Setup config: these defaults should be turned into the httpx.Timeout below
        mock_config.SSRF_DEFAULT_TIME_OUT = 30.0
        mock_config.SSRF_DEFAULT_CONNECT_TIME_OUT = 10.0
        mock_config.SSRF_DEFAULT_READ_TIME_OUT = 60.0
        mock_config.SSRF_DEFAULT_WRITE_TIME_OUT = 30.0
        # Setup mocks
        mock_client = Mock(spec=httpx.Client)
        mock_create_client.return_value = mock_client
        mock_event_source = Mock(spec=httpx_sse.EventSource)
        mock_context = MagicMock()
        mock_context.__enter__.return_value = mock_event_source
        mock_connect_sse.return_value = mock_context
        # Call without client
        result = ssrf_proxy_sse_connect("http://example.com/sse", headers={"X-Custom": "value"})
        # Verify client was created
        mock_create_client.assert_called_once()
        call_args = mock_create_client.call_args
        assert call_args[1]["headers"] == {"X-Custom": "value"}
        timeout = call_args[1]["timeout"]
        # httpx.Timeout object has these attributes
        assert isinstance(timeout, httpx.Timeout)
        assert timeout.connect == 10.0
        assert timeout.read == 60.0
        assert timeout.write == 30.0
        # Verify connect_sse was called
        mock_connect_sse.assert_called_once_with(
            mock_client,
            "GET",  # Default method
            "http://example.com/sse",
        )
        # Verify result
        assert result == mock_context

    @patch("core.mcp.utils.connect_sse")
    @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
    def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse):
        """Test SSE connection with custom timeout."""
        # Setup mocks
        mock_client = Mock(spec=httpx.Client)
        mock_create_client.return_value = mock_client
        mock_event_source = Mock(spec=httpx_sse.EventSource)
        mock_context = MagicMock()
        mock_context.__enter__.return_value = mock_event_source
        mock_connect_sse.return_value = mock_context
        custom_timeout = httpx.Timeout(timeout=60.0, read=120.0)
        # Call with custom timeout
        result = ssrf_proxy_sse_connect("http://example.com/sse", timeout=custom_timeout)
        # Verify client was created with custom timeout (not the config defaults)
        mock_create_client.assert_called_once()
        call_args = mock_create_client.call_args
        assert call_args[1]["timeout"] == custom_timeout
        # Verify result
        assert result == mock_context

    @patch("core.mcp.utils.connect_sse")
    @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
    def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse):
        """Test SSE connection cleans up client on error."""
        # Setup mocks
        mock_client = Mock(spec=httpx.Client)
        mock_create_client.return_value = mock_client
        # Make connect_sse raise an exception
        mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
        # Call should raise the exception
        with pytest.raises(httpx.ConnectError):
            ssrf_proxy_sse_connect("http://example.com/sse")
        # Verify client was cleaned up (internally-created clients must not leak)
        mock_client.close.assert_called_once()

    @patch("core.mcp.utils.connect_sse")
    def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse):
        """Test SSE connection doesn't clean up provided client on error."""
        # Setup mocks
        mock_client = Mock(spec=httpx.Client)
        # Make connect_sse raise an exception
        mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")
        # Call should raise the exception
        with pytest.raises(httpx.ConnectError):
            ssrf_proxy_sse_connect("http://example.com/sse", client=mock_client)
        # Verify client was NOT cleaned up (because it was provided)
        mock_client.close.assert_not_called()
class TestCreateMCPErrorResponse:
    """Test create_mcp_error_response function."""

    @staticmethod
    def _next_json(generator):
        """Pull the next bytes chunk from *generator* and decode it as a JSON dict."""
        payload = next(generator)
        assert isinstance(payload, bytes)
        return json.loads(payload.decode("utf-8"))

    def test_create_error_response_basic(self):
        """A basic error response carries jsonrpc/id/error fields and yields exactly once."""
        generator = create_mcp_error_response(request_id="req-123", code=-32600, message="Invalid Request")
        # The function returns a generator of byte chunks.
        assert isinstance(generator, Generator)
        body = self._next_json(generator)
        assert body["jsonrpc"] == "2.0"
        assert body["id"] == "req-123"
        error = body["error"]
        assert error["code"] == -32600
        assert error["message"] == "Invalid Request"
        assert error["data"] is None
        # Only a single payload is produced.
        with pytest.raises(StopIteration):
            next(generator)

    def test_create_error_response_with_data(self):
        """Additional error data and numeric request IDs are passed through untouched."""
        error_data = {"field": "username", "reason": "required"}
        generator = create_mcp_error_response(
            request_id=456,  # Numeric ID
            code=-32602,
            message="Invalid params",
            data=error_data,
        )
        body = self._next_json(generator)
        assert body["id"] == 456
        assert body["error"]["code"] == -32602
        assert body["error"]["message"] == "Invalid params"
        assert body["error"]["data"] == error_data

    def test_create_error_response_without_request_id(self):
        """A missing request ID falls back to the default ID 1."""
        generator = create_mcp_error_response(request_id=None, code=-32700, message="Parse error")
        body = self._next_json(generator)
        assert body["id"] == 1
        assert body["error"]["code"] == -32700
        assert body["error"]["message"] == "Parse error"

    def test_create_error_response_with_complex_data(self):
        """Nested error structures survive serialization unchanged."""
        complex_data = {
            "errors": [{"field": "name", "message": "Too short"}, {"field": "email", "message": "Invalid format"}],
            "timestamp": "2024-01-01T00:00:00Z",
        }
        generator = create_mcp_error_response(
            request_id="complex-req", code=-32602, message="Validation failed", data=complex_data
        )
        body = self._next_json(generator)
        assert body["error"]["data"] == complex_data
        assert len(body["error"]["data"]["errors"]) == 2

    def test_create_error_response_encoding(self):
        """Non-ASCII message and data must round-trip through UTF-8."""
        generator = create_mcp_error_response(
            request_id="unicode-req",
            code=-32603,
            message="内部错误",  # Chinese characters
            data={"details": "エラー詳細"},  # Japanese characters
        )
        body = self._next_json(generator)
        assert body["error"]["message"] == "内部错误"
        assert body["error"]["data"]["details"] == "エラー詳細"

    def test_create_error_response_yields_once(self):
        """The generator yields exactly one chunk, then stays exhausted."""
        generator = create_mcp_error_response(request_id="test", code=-32600, message="Test")
        # First (and only) yield must succeed and decode cleanly.
        assert isinstance(self._next_json(generator), dict)
        # Every further pull raises StopIteration.
        with pytest.raises(StopIteration):
            next(generator)
        with pytest.raises(StopIteration):
            next(generator)

View File

@@ -0,0 +1,99 @@
from unittest.mock import MagicMock, patch
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
# Shorthand for the nested tool-call type used throughout the fixtures below.
ToolCall = AssistantPromptMessage.ToolCall

# CASE 1: Single tool call — the ID and name arrive in the first chunk, the
# JSON arguments arrive split across later chunks with empty IDs/names.
INPUTS_CASE_1 = [
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_1 = [
    ToolCall(
        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
    ),
]

# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
INPUTS_CASE_2 = [
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_2 = [
    ToolCall(
        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
    ),
    ToolCall(
        id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
    ),
]

# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
INPUTS_CASE_3 = [
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_3 = [
    ToolCall(
        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
    ),
    ToolCall(
        id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
    ),
]

# CASE 4: Tool call sequences with no IDs — the implementation must generate
# its own IDs; the placeholders below are pinned via a patched ID generator.
INPUTS_CASE_4 = [
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_4 = [
    ToolCall(
        id="RANDOM_ID_1",
        type="function",
        function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
    ),
    ToolCall(
        id="RANDOM_ID_2",
        type="function",
        function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
    ),
]
def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
    """Feed *inputs* through _increase_tool_call and compare the merged result with *expected*."""
    merged: list[ToolCall] = []
    _increase_tool_call(inputs, merged)
    assert merged == expected
def test__increase_tool_call():
    """Chunked tool-call deltas must merge correctly under every ID scheme."""
    # Cases 1-3 carry their own IDs, so no ID generation is involved.
    for inputs, expected in (
        (INPUTS_CASE_1, EXPECTED_CASE_1),
        (INPUTS_CASE_2, EXPECTED_CASE_2),
        (INPUTS_CASE_3, EXPECTED_CASE_3),
    ):
        _run_case(inputs, expected)

    # Case 4 has no IDs at all; pin the generator so the generated IDs are
    # deterministic and match the expected fixtures.
    fake_id_generator = MagicMock(side_effect=[expected_call.id for expected_call in EXPECTED_CASE_4])
    with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", fake_id_generator):
        _run_case(INPUTS_CASE_4, EXPECTED_CASE_4)

View File

@@ -0,0 +1,148 @@
"""Tests for LLMUsage entity."""
from decimal import Decimal
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
class TestLLMUsage:
    """Test cases for LLMUsage class."""

    def test_from_metadata_with_all_tokens(self):
        """Test from_metadata when all token types are provided."""
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            "prompt_unit_price": 0.001,
            "completion_unit_price": 0.002,
            "total_price": 0.2,
            "currency": "USD",
            "latency": 1.5,
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150
        # Prices come back as Decimal regardless of the numeric input type.
        assert usage.prompt_unit_price == Decimal("0.001")
        assert usage.completion_unit_price == Decimal("0.002")
        assert usage.total_price == Decimal("0.2")
        assert usage.currency == "USD"
        assert usage.latency == 1.5

    def test_from_metadata_with_prompt_tokens_only(self):
        """Test from_metadata when only prompt_tokens is provided."""
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "total_tokens": 100,
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 100

    def test_from_metadata_with_completion_tokens_only(self):
        """Test from_metadata when only completion_tokens is provided."""
        metadata: LLMUsageMetadata = {
            "completion_tokens": 50,
            "total_tokens": 50,
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 50

    def test_from_metadata_calculates_total_when_missing(self):
        """Test from_metadata calculates total_tokens when not provided."""
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 100
        assert usage.completion_tokens == 50
        assert usage.total_tokens == 150  # Should be calculated

    def test_from_metadata_with_total_but_no_completion(self):
        """
        Test from_metadata when total_tokens is provided but completion_tokens is 0.
        This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens.
        """
        # Note: total (521) deliberately differs from prompt (479) here; the
        # provided total must be preserved as-is, not recomputed.
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 479,
            "completion_tokens": 0,
            "total_tokens": 521,
        }
        usage = LLMUsage.from_metadata(metadata)
        # This is the key fix - prompt tokens should remain as prompt tokens
        assert usage.prompt_tokens == 479
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 521

    def test_from_metadata_with_empty_metadata(self):
        """Test from_metadata with empty metadata."""
        metadata: LLMUsageMetadata = {}
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 0
        assert usage.currency == "USD"
        assert usage.latency == 0.0

    def test_from_metadata_preserves_zero_completion_tokens(self):
        """
        Test that zero completion_tokens are preserved when explicitly set.
        This is important for agent nodes that only use prompt tokens.
        """
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 1000,
            "completion_tokens": 0,
            "total_tokens": 1000,
            "prompt_unit_price": 0.15,
            "completion_unit_price": 0.60,
            "prompt_price": 0.00015,
            "completion_price": 0,
            "total_price": 0.00015,
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_tokens == 1000
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 1000
        assert usage.prompt_price == Decimal("0.00015")
        assert usage.completion_price == Decimal(0)
        assert usage.total_price == Decimal("0.00015")

    def test_from_metadata_with_decimal_values(self):
        """Test from_metadata handles decimal values correctly."""
        # Prices supplied as strings must still convert to exact Decimals.
        metadata: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            "prompt_unit_price": "0.001",
            "completion_unit_price": "0.002",
            "prompt_price": "0.1",
            "completion_price": "0.1",
            "total_price": "0.2",
        }
        usage = LLMUsage.from_metadata(metadata)
        assert usage.prompt_unit_price == Decimal("0.001")
        assert usage.completion_unit_price == Decimal("0.002")
        assert usage.prompt_price == Decimal("0.1")
        assert usage.completion_price == Decimal("0.1")
        assert usage.total_price == Decimal("0.2")

View File

@@ -0,0 +1 @@
# Unit tests for core ops module

View File

@@ -0,0 +1,416 @@
import pytest
from pydantic import ValidationError
from core.ops.entities.config_entity import (
AliyunConfig,
ArizeConfig,
LangfuseConfig,
LangSmithConfig,
OpikConfig,
PhoenixConfig,
TracingProviderEnum,
WeaveConfig,
)
class TestTracingProviderEnum:
    """Test cases for TracingProviderEnum"""

    def test_enum_values(self):
        """Every supported tracing provider maps to its expected string value."""
        expected_values = {
            TracingProviderEnum.ARIZE: "arize",
            TracingProviderEnum.PHOENIX: "phoenix",
            TracingProviderEnum.LANGFUSE: "langfuse",
            TracingProviderEnum.LANGSMITH: "langsmith",
            TracingProviderEnum.OPIK: "opik",
            TracingProviderEnum.WEAVE: "weave",
            TracingProviderEnum.ALIYUN: "aliyun",
        }
        for member, value in expected_values.items():
            assert member == value
class TestArizeConfig:
    """Test cases for ArizeConfig"""

    def test_valid_config(self):
        """Test valid Arize configuration"""
        config = ArizeConfig(
            api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
        )
        assert config.api_key == "test_key"
        assert config.space_id == "test_space"
        assert config.project == "test_project"
        assert config.endpoint == "https://custom.arize.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = ArizeConfig()
        assert config.api_key is None
        assert config.space_id is None
        assert config.project is None
        assert config.endpoint == "https://otlp.arize.com"

    def test_project_validation_empty(self):
        """Test project validation with empty value"""
        # Empty project names are replaced by the validator's default.
        config = ArizeConfig(project="")
        assert config.project == "default"

    def test_project_validation_none(self):
        """Test project validation with None value"""
        config = ArizeConfig(project=None)
        assert config.project == "default"

    def test_endpoint_validation_empty(self):
        """Test endpoint validation with empty value"""
        config = ArizeConfig(endpoint="")
        assert config.endpoint == "https://otlp.arize.com"

    def test_endpoint_validation_with_path(self):
        """Test endpoint validation normalizes URL by removing path"""
        config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
        assert config.endpoint == "https://custom.arize.com"

    def test_endpoint_validation_invalid_scheme(self):
        """Test endpoint validation rejects invalid schemes"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            ArizeConfig(endpoint="ftp://invalid.com")

    def test_endpoint_validation_no_scheme(self):
        """Test endpoint validation rejects URLs without scheme"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            ArizeConfig(endpoint="invalid.com")
class TestPhoenixConfig:
    """Test cases for PhoenixConfig"""

    def test_valid_config(self):
        """Test valid Phoenix configuration"""
        config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
        assert config.api_key == "test_key"
        assert config.project == "test_project"
        assert config.endpoint == "https://custom.phoenix.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = PhoenixConfig()
        assert config.api_key is None
        assert config.project is None
        assert config.endpoint == "https://app.phoenix.arize.com"

    def test_project_validation_empty(self):
        """Test project validation with empty value"""
        config = PhoenixConfig(project="")
        assert config.project == "default"

    def test_endpoint_validation_with_path(self):
        """Test endpoint validation with path"""
        # Unlike ArizeConfig, Phoenix endpoints keep their path component.
        config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
        assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"

    def test_endpoint_validation_without_path(self):
        """Test endpoint validation without path"""
        config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
        assert config.endpoint == "https://app.phoenix.arize.com"
class TestLangfuseConfig:
    """Test cases for LangfuseConfig"""

    def test_valid_config(self):
        """A fully specified Langfuse config keeps all three fields."""
        config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
        assert config.public_key == "public_key"
        assert config.secret_key == "secret_key"
        assert config.host == "https://custom.langfuse.com"

    def test_valid_config_with_path(self):
        """A host URL with a path component is preserved verbatim."""
        host = "https://custom.langfuse.com/api/v1"
        config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
        assert config.public_key == "public_key"
        assert config.secret_key == "secret_key"
        assert config.host == host

    def test_default_values(self):
        """The host defaults to the public Langfuse API endpoint."""
        config = LangfuseConfig(public_key="public", secret_key="secret")
        assert config.host == "https://api.langfuse.com"

    def test_missing_required_fields(self):
        """Both public_key and secret_key are mandatory."""
        for kwargs in ({}, {"public_key": "public"}, {"secret_key": "secret"}):
            with pytest.raises(ValidationError):
                LangfuseConfig(**kwargs)

    def test_host_validation_empty(self):
        """An empty host falls back to the default endpoint."""
        config = LangfuseConfig(public_key="public", secret_key="secret", host="")
        assert config.host == "https://api.langfuse.com"
class TestLangSmithConfig:
    """Test cases for LangSmithConfig"""

    def test_valid_config(self):
        """Test valid LangSmith configuration"""
        config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
        assert config.api_key == "test_key"
        assert config.project == "test_project"
        assert config.endpoint == "https://custom.smith.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = LangSmithConfig(api_key="key", project="project")
        assert config.endpoint == "https://api.smith.langchain.com"

    def test_missing_required_fields(self):
        """Test that required fields are enforced"""
        with pytest.raises(ValidationError):
            LangSmithConfig()
        with pytest.raises(ValidationError):
            LangSmithConfig(api_key="key")
        with pytest.raises(ValidationError):
            LangSmithConfig(project="project")

    def test_endpoint_validation_https_only(self):
        """Test endpoint validation only allows HTTPS"""
        # http:// is rejected for LangSmith endpoints, unlike some other configs.
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
class TestOpikConfig:
    """Test cases for OpikConfig"""

    def test_valid_config(self):
        """Test valid Opik configuration"""
        config = OpikConfig(
            api_key="test_key",
            project="test_project",
            workspace="test_workspace",
            url="https://custom.comet.com/opik/api/",
        )
        assert config.api_key == "test_key"
        assert config.project == "test_project"
        assert config.workspace == "test_workspace"
        assert config.url == "https://custom.comet.com/opik/api/"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = OpikConfig()
        assert config.api_key is None
        assert config.project is None
        assert config.workspace is None
        assert config.url == "https://www.comet.com/opik/api/"

    def test_project_validation_empty(self):
        """Test project validation with empty value"""
        config = OpikConfig(project="")
        assert config.project == "Default Project"

    def test_url_validation_empty(self):
        """Test URL validation with empty value"""
        config = OpikConfig(url="")
        assert config.url == "https://www.comet.com/opik/api/"

    def test_url_validation_missing_suffix(self):
        """Test URL validation requires /api/ suffix"""
        with pytest.raises(ValidationError, match="URL should end with /api/"):
            OpikConfig(url="https://custom.comet.com/opik/")

    def test_url_validation_invalid_scheme(self):
        """Test URL validation rejects invalid schemes"""
        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
            OpikConfig(url="ftp://custom.comet.com/opik/api/")
class TestWeaveConfig:
    """Test cases for WeaveConfig"""

    def test_valid_config(self):
        """Test valid Weave configuration"""
        config = WeaveConfig(
            api_key="test_key",
            entity="test_entity",
            project="test_project",
            endpoint="https://custom.wandb.ai",
            host="https://custom.host.com",
        )
        assert config.api_key == "test_key"
        assert config.entity == "test_entity"
        assert config.project == "test_project"
        assert config.endpoint == "https://custom.wandb.ai"
        assert config.host == "https://custom.host.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = WeaveConfig(api_key="key", project="project")
        assert config.entity is None
        assert config.endpoint == "https://trace.wandb.ai"
        assert config.host is None

    def test_missing_required_fields(self):
        """Test that required fields are enforced"""
        with pytest.raises(ValidationError):
            WeaveConfig()
        with pytest.raises(ValidationError):
            WeaveConfig(api_key="key")
        with pytest.raises(ValidationError):
            WeaveConfig(project="project")

    def test_endpoint_validation_https_only(self):
        """Test endpoint validation only allows HTTPS"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")

    def test_host_validation_optional(self):
        """Test host validation is optional but validates when provided"""
        config = WeaveConfig(api_key="key", project="project", host=None)
        assert config.host is None
        # An empty host string is passed through unchanged (no default substituted).
        config = WeaveConfig(api_key="key", project="project", host="")
        assert config.host == ""
        config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
        assert config.host == "https://valid.host.com"

    def test_host_validation_invalid_scheme(self):
        """Test host validation rejects invalid schemes when provided"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
class TestAliyunConfig:
    """Test cases for AliyunConfig"""

    def test_valid_config(self):
        """Test valid Aliyun configuration"""
        config = AliyunConfig(
            app_name="test_app",
            license_key="test_license_key",
            endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
        )
        assert config.app_name == "test_app"
        assert config.license_key == "test_license_key"
        assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
        assert config.app_name == "dify_app"

    def test_missing_required_fields(self):
        """Test that required fields are enforced"""
        with pytest.raises(ValidationError):
            AliyunConfig()
        with pytest.raises(ValidationError):
            AliyunConfig(license_key="test_license")
        with pytest.raises(ValidationError):
            AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")

    def test_app_name_validation_empty(self):
        """Test app_name validation with empty value"""
        config = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
        )
        assert config.app_name == "dify_app"

    def test_endpoint_validation_empty(self):
        """Test endpoint validation with empty value"""
        config = AliyunConfig(license_key="test_license", endpoint="")
        assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"

    def test_endpoint_validation_with_path(self):
        """Test endpoint validation preserves path for Aliyun endpoints"""
        config = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
        )
        assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"

    def test_endpoint_validation_invalid_scheme(self):
        """Test endpoint validation rejects invalid schemes"""
        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
            AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")

    def test_endpoint_validation_no_scheme(self):
        """Test endpoint validation rejects URLs without scheme"""
        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
            AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")

    def test_license_key_required(self):
        """Test that license_key is required and cannot be empty"""
        with pytest.raises(ValidationError):
            AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")

    def test_valid_endpoint_format_examples(self):
        """Test valid endpoint format examples from comments"""
        valid_endpoints = [
            # cms2.0 public endpoint
            "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
            # cms2.0 intranet endpoint
            "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
            # xtrace public endpoint
            "http://tracing-cn-heyuan.arms.aliyuncs.com",
            # xtrace intranet endpoint
            "http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
        ]
        # Each format must pass validation and be preserved verbatim.
        for endpoint in valid_endpoints:
            config = AliyunConfig(license_key="test_license", endpoint=endpoint)
            assert config.endpoint == endpoint
class TestConfigIntegration:
    """Cross-cutting checks that span all tracing-provider config classes."""

    def test_all_configs_can_be_instantiated(self):
        """Every config class accepts a minimal set of valid arguments."""
        instances = [
            ArizeConfig(api_key="key"),
            PhoenixConfig(api_key="key"),
            LangfuseConfig(public_key="public", secret_key="secret"),
            LangSmithConfig(api_key="key", project="project"),
            OpikConfig(api_key="key"),
            WeaveConfig(api_key="key", project="project"),
            AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com"),
        ]
        assert all(instance is not None for instance in instances)

    def test_url_normalization_consistency(self):
        """Path handling differs by provider: Arize strips paths, Phoenix and Aliyun keep them."""
        assert ArizeConfig(endpoint="https://arize.com/api/v1/test").endpoint == "https://arize.com"
        assert (
            PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration").endpoint
            == "https://app.phoenix.arize.com/s/dify-integration"
        )
        assert PhoenixConfig(endpoint="https://app.phoenix.arize.com").endpoint == "https://app.phoenix.arize.com"
        assert (
            AliyunConfig(
                license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
            ).endpoint
            == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
        )

    def test_project_default_values(self):
        """Empty project/app names are replaced by each provider's own default."""
        assert ArizeConfig(project="").project == "default"
        assert PhoenixConfig(project="").project == "default"
        assert OpikConfig(project="").project == "Default Project"
        assert (
            AliyunConfig(
                license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
            ).app_name
            == "dify_app"
        )

View File

@@ -0,0 +1,138 @@
import pytest
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
class TestValidateUrl:
    """Tests for validate_url: scheme checks plus normalization to scheme://host."""

    def test_valid_https_url(self):
        """An https URL passes through unchanged."""
        assert validate_url("https://example.com", "https://default.com") == "https://example.com"

    def test_valid_http_url(self):
        """An http URL passes through unchanged."""
        assert validate_url("http://example.com", "https://default.com") == "http://example.com"

    def test_url_with_path_removed(self):
        """Any path component is stripped during normalization."""
        assert validate_url("https://example.com/api/v1/test", "https://default.com") == "https://example.com"

    def test_url_with_query_removed(self):
        """Query strings are stripped during normalization."""
        assert validate_url("https://example.com?param=value", "https://default.com") == "https://example.com"

    def test_url_with_fragment_removed(self):
        """Fragments are stripped during normalization."""
        assert validate_url("https://example.com#section", "https://default.com") == "https://example.com"

    def test_empty_url_returns_default(self):
        """Empty input falls back to the supplied default."""
        assert validate_url("", "https://default.com") == "https://default.com"

    def test_none_url_returns_default(self):
        """None input falls back to the supplied default."""
        assert validate_url(None, "https://default.com") == "https://default.com"

    def test_whitespace_url_returns_default(self):
        """Whitespace-only input falls back to the supplied default."""
        assert validate_url(" ", "https://default.com") == "https://default.com"

    def test_invalid_scheme_raises_error(self):
        """A non-allowed scheme (ftp) raises ValueError."""
        with pytest.raises(ValueError, match="URL scheme must be one of"):
            validate_url("ftp://example.com", "https://default.com")

    def test_no_scheme_raises_error(self):
        """A bare hostname without any scheme raises ValueError."""
        with pytest.raises(ValueError, match="URL scheme must be one of"):
            validate_url("example.com", "https://default.com")

    def test_custom_allowed_schemes(self):
        """allowed_schemes restricts which schemes are accepted."""
        accepted = validate_url("https://example.com", "https://default.com", allowed_schemes=("https",))
        assert accepted == "https://example.com"
        with pytest.raises(ValueError, match="URL scheme must be one of"):
            validate_url("http://example.com", "https://default.com", allowed_schemes=("https",))
class TestValidateUrlWithPath:
    """Tests for validate_url_with_path: paths are preserved and suffixes enforced."""

    def test_valid_url_with_path(self):
        """A URL with a path is returned unchanged."""
        assert validate_url_with_path("https://example.com/api/v1", "https://default.com") == "https://example.com/api/v1"

    def test_valid_url_with_required_suffix(self):
        """A URL ending with the required suffix is accepted."""
        accepted = validate_url_with_path("https://example.com/api/", "https://default.com", required_suffix="/api/")
        assert accepted == "https://example.com/api/"

    def test_url_without_required_suffix_raises_error(self):
        """Missing the required suffix raises ValueError."""
        with pytest.raises(ValueError, match="URL should end with /api/"):
            validate_url_with_path("https://example.com/api", "https://default.com", required_suffix="/api/")

    def test_empty_url_returns_default(self):
        """Empty input falls back to the supplied default."""
        assert validate_url_with_path("", "https://default.com") == "https://default.com"

    def test_none_url_returns_default(self):
        """None input falls back to the supplied default."""
        assert validate_url_with_path(None, "https://default.com") == "https://default.com"

    def test_invalid_scheme_raises_error(self):
        """A non-http(s) scheme raises ValueError."""
        with pytest.raises(ValueError, match="URL must start with https:// or http://"):
            validate_url_with_path("ftp://example.com", "https://default.com")

    def test_no_scheme_raises_error(self):
        """A URL with no scheme at all raises ValueError."""
        with pytest.raises(ValueError, match="URL must start with https:// or http://"):
            validate_url_with_path("example.com", "https://default.com")
class TestValidateProjectName:
    """Tests for validate_project_name: trimming and default fallback."""

    def test_valid_project_name(self):
        """A normal name is returned unchanged."""
        assert validate_project_name("my-project", "default") == "my-project"

    def test_empty_project_name_returns_default(self):
        """An empty name falls back to the default."""
        assert validate_project_name("", "default") == "default"

    def test_none_project_name_returns_default(self):
        """None falls back to the default."""
        assert validate_project_name(None, "default") == "default"

    def test_whitespace_project_name_returns_default(self):
        """A whitespace-only name falls back to the default."""
        assert validate_project_name(" ", "default") == "default"

    def test_project_name_with_whitespace_trimmed(self):
        """Surrounding whitespace is stripped from an otherwise valid name."""
        assert validate_project_name(" my-project ", "default") == "my-project"

    def test_custom_default_name(self):
        """The caller-supplied default is honored for empty input."""
        assert validate_project_name("", "Custom Default") == "Custom Default"

View File

@@ -0,0 +1,460 @@
from collections.abc import Generator
import pytest
from core.agent.entities import AgentInvokeMessage
from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks
from core.tools.entities.tool_entities import ToolInvokeMessage
class TestChunkMerger:
    """Tests for merge_blob_chunks and the FileChunk buffer it builds on."""

    @staticmethod
    def _chunk(file_id, sequence, total_length, blob, end, meta=None):
        """Build a BLOB_CHUNK ToolInvokeMessage for one fragment of a file."""
        kwargs = {
            "type": ToolInvokeMessage.MessageType.BLOB_CHUNK,
            "message": ToolInvokeMessage.BlobChunkMessage(
                id=file_id, sequence=sequence, total_length=total_length, blob=blob, end=end
            ),
        }
        # Only forward meta when the caller supplied one, so the model's own
        # default is used otherwise.
        if meta is not None:
            kwargs["meta"] = meta
        return ToolInvokeMessage(**kwargs)

    def test_file_chunk_initialization(self):
        """A new FileChunk starts empty with a pre-sized buffer."""
        buf = FileChunk(1024)
        assert buf.bytes_written == 0
        assert buf.total_length == 1024
        assert len(buf.data) == 1024

    def test_merge_blob_chunks_with_single_complete_chunk(self):
        """Two sequential chunks of one file merge into a single BLOB message."""

        def stream():
            yield self._chunk("file1", 0, 10, b"Hello", end=False)
            yield self._chunk("file1", 1, 10, b"World", end=True)

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 1
        assert merged[0].type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(merged[0].message, ToolInvokeMessage.BlobMessage)
        # The buffer should contain the complete reassembled data.
        assert merged[0].message.blob[:10] == b"HelloWorld"

    def test_merge_blob_chunks_with_multiple_files(self):
        """Interleaved chunks belonging to different ids merge independently."""

        def stream():
            yield self._chunk("file1", 0, 4, b"AB", end=False)
            yield self._chunk("file2", 0, 4, b"12", end=False)
            yield self._chunk("file1", 1, 4, b"CD", end=True)
            yield self._chunk("file2", 1, 4, b"34", end=True)

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 2
        assert all(m.type == ToolInvokeMessage.MessageType.BLOB for m in merged)

    def test_merge_blob_chunks_passes_through_non_blob_messages(self):
        """Messages that are not blob chunks are forwarded unchanged."""

        def stream():
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.TEXT,
                message=ToolInvokeMessage.TextMessage(text="Hello"),
            )
            yield self._chunk("file1", 0, 5, b"Test", end=True)
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.TEXT,
                message=ToolInvokeMessage.TextMessage(text="World"),
            )

        merged = list(merge_blob_chunks(stream()))
        assert [m.type for m in merged] == [
            ToolInvokeMessage.MessageType.TEXT,
            ToolInvokeMessage.MessageType.BLOB,
            ToolInvokeMessage.MessageType.TEXT,
        ]
        assert isinstance(merged[0].message, ToolInvokeMessage.TextMessage)
        assert merged[0].message.text == "Hello"
        assert isinstance(merged[2].message, ToolInvokeMessage.TextMessage)
        assert merged[2].message.text == "World"

    def test_merge_blob_chunks_file_too_large(self):
        """Accumulating more bytes than max_file_size raises ValueError."""

        def stream():
            yield self._chunk("file1", 0, 100, b"x" * 1024, end=False)

        with pytest.raises(ValueError) as exc_info:
            list(merge_blob_chunks(stream(), max_file_size=1000))
        assert "File is too large" in str(exc_info.value)

    def test_merge_blob_chunks_chunk_too_large(self):
        """A single chunk bigger than max_chunk_size raises ValueError."""

        def stream():
            yield self._chunk("file1", 0, 10000, b"x" * 9000, end=False)

        with pytest.raises(ValueError) as exc_info:
            list(merge_blob_chunks(stream(), max_chunk_size=8192))
        assert "File chunk is too large" in str(exc_info.value)

    def test_merge_blob_chunks_with_agent_invoke_message(self):
        """merge_blob_chunks preserves the AgentInvokeMessage message class."""

        def stream():
            for seq, (part, is_end) in enumerate(((b"Agent", False), (b"Data", True))):
                yield AgentInvokeMessage(
                    type=AgentInvokeMessage.MessageType.BLOB_CHUNK,
                    message=AgentInvokeMessage.BlobChunkMessage(
                        id="agent_file", sequence=seq, total_length=8, blob=part, end=is_end
                    ),
                )

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 1
        assert isinstance(merged[0], AgentInvokeMessage)
        assert merged[0].type == AgentInvokeMessage.MessageType.BLOB

    def test_merge_blob_chunks_preserves_meta(self):
        """The meta dict of the final chunk survives the merge."""

        def stream():
            yield self._chunk("file1", 0, 4, b"Test", end=True, meta={"key": "value"})

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 1
        assert merged[0].meta == {"key": "value"}

    def test_merge_blob_chunks_custom_limits(self):
        """Caller-supplied size limits are honored in both directions."""

        def within_limits():
            yield self._chunk("file1", 0, 500, b"x" * 400, end=False)
            yield self._chunk("file1", 1, 500, b"y" * 100, end=True)

        assert len(list(merge_blob_chunks(within_limits(), max_file_size=1000, max_chunk_size=500))) == 1

        def over_limit():
            yield self._chunk("file1", 0, 500, b"x" * 400, end=False)

        with pytest.raises(ValueError):
            list(merge_blob_chunks(over_limit(), max_file_size=300))

    def test_merge_blob_chunks_data_integrity(self):
        """Reassembled output is byte-for-byte identical to the original."""
        original = b"This is a test message that will be split into chunks for testing purposes."
        step = 20

        def stream():
            for seq, start in enumerate(range(0, len(original), step)):
                yield self._chunk(
                    "test_file",
                    seq,
                    len(original),
                    original[start : start + step],
                    end=start + step >= len(original),
                )

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 1
        assert merged[0].type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(merged[0].message, ToolInvokeMessage.BlobMessage)
        assert merged[0].message.blob == original

    def test_merge_blob_chunks_empty_chunk(self):
        """Zero-length chunks in the middle of a file are harmless."""

        def stream():
            yield self._chunk("file1", 0, 10, b"Hello", end=False)
            yield self._chunk("file1", 1, 10, b"", end=False)
            yield self._chunk("file1", 2, 10, b"World", end=True)

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 1
        assert merged[0].type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(merged[0].message, ToolInvokeMessage.BlobMessage)
        assert merged[0].message.blob[:10] == b"HelloWorld"

    def test_merge_blob_chunks_single_chunk_file(self):
        """A file delivered in one end=True chunk yields one BLOB directly."""

        def stream():
            yield self._chunk("single_chunk_file", 0, 11, b"Single Data", end=True)

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 1
        assert merged[0].type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(merged[0].message, ToolInvokeMessage.BlobMessage)
        assert merged[0].message.blob == b"Single Data"

    def test_merge_blob_chunks_concurrent_files(self):
        """Chunks from three interleaved files end up in the right blobs."""
        payloads = {
            "file1": b"First file content",
            "file2": b"Second file data",
            "file3": b"Third file",
        }

        def stream():
            for fid, data in payloads.items():
                yield self._chunk(fid, 0, len(data), data[:6], end=False)
            for fid, data in payloads.items():
                yield self._chunk(fid, 1, len(data), data[6:], end=True)

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 3
        collected = set()
        for m in merged:
            assert isinstance(m.message, ToolInvokeMessage.BlobMessage)
            collected.add(m.message.blob)
        assert collected == set(payloads.values())

    def test_merge_blob_chunks_exact_buffer_size(self):
        """Data that exactly fills the declared total_length merges cleanly."""
        payload = b"X" * 100

        def stream():
            yield self._chunk("exact_file", 0, 100, payload[:50], end=False)
            yield self._chunk("exact_file", 1, 100, payload[50:], end=True)

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 1
        assert isinstance(merged[0].message, ToolInvokeMessage.BlobMessage)
        assert len(merged[0].message.blob) == 100
        assert merged[0].message.blob == payload

    def test_merge_blob_chunks_large_file_simulation(self):
        """A 1MB file in 128 x 8KB chunks is reassembled in order."""
        chunk_size, num_chunks = 8192, 128
        total = chunk_size * num_chunks

        def stream():
            for i in range(num_chunks):
                # Each chunk carries a distinctive byte so ordering is checkable.
                yield self._chunk(
                    "large_file", i, total, bytes([i % 256]) * chunk_size, end=i == num_chunks - 1
                )

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 1
        assert isinstance(merged[0].message, ToolInvokeMessage.BlobMessage)
        data = merged[0].message.blob
        assert len(data) == 1024 * 1024
        for i in range(num_chunks):
            segment = data[i * chunk_size : (i + 1) * chunk_size]
            assert segment == bytes([i % 256]) * chunk_size, f"Chunk {i} has incorrect data"

    def test_merge_blob_chunks_sequential_order_required(self):
        """
        Documents the expected behavior with sequential chunks.

        NOTE: the implementation assumes chunks arrive in sequence order;
        out-of-order delivery would need extra handling.
        """
        parts = [b"First", b"Second", b"Third"]
        total = sum(len(p) for p in parts)

        def stream():
            for i, part in enumerate(parts):
                yield self._chunk("ordered_file", i, total, part, end=i == len(parts) - 1)

        merged = list(merge_blob_chunks(stream()))
        assert len(merged) == 1
        assert isinstance(merged[0].message, ToolInvokeMessage.BlobMessage)
        assert merged[0].message.blob == b"FirstSecondThird"

View File

@@ -0,0 +1,655 @@
import pytest
from flask import Request, Response
from core.plugin.utils.http_parser import (
deserialize_request,
deserialize_response,
serialize_request,
serialize_response,
)
class TestSerializeRequest:
    """serialize_request should emit a well-formed raw HTTP/1.1 request."""

    @staticmethod
    def _environ(method="GET", path="/", query="", body=None, **extra):
        """Build a minimal WSGI environ dict; extra keys extend/override defaults."""
        env = {
            "REQUEST_METHOD": method,
            "PATH_INFO": path,
            "QUERY_STRING": query,
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": body,
            "wsgi.url_scheme": "http",
        }
        env.update(extra)
        return env

    def test_serialize_simple_get_request(self):
        """A bare GET has a request line and the blank header/body separator."""
        raw = serialize_request(Request(self._environ(path="/api/test")))
        assert raw.startswith(b"GET /api/test HTTP/1.1\r\n")
        # Blank line separates headers from the (empty) body.
        assert b"\r\n\r\n" in raw

    def test_serialize_request_with_query_params(self):
        """The query string is appended to the path in the request line."""
        raw = serialize_request(Request(self._environ(path="/api/search", query="q=test&limit=10")))
        assert raw.startswith(b"GET /api/search?q=test&limit=10 HTTP/1.1\r\n")

    def test_serialize_post_request_with_body(self):
        """POST bodies are emitted verbatim after the headers."""
        from io import BytesIO

        payload = b'{"name": "test", "value": 123}'
        env = self._environ(
            method="POST",
            path="/api/data",
            body=BytesIO(payload),
            CONTENT_LENGTH=str(len(payload)),
            CONTENT_TYPE="application/json",
            HTTP_CONTENT_TYPE="application/json",
        )
        raw = serialize_request(Request(env))
        assert b"POST /api/data HTTP/1.1\r\n" in raw
        assert b"Content-Type: application/json" in raw
        assert raw.endswith(payload)

    def test_serialize_request_with_custom_headers(self):
        """HTTP_* environ keys become properly cased header lines."""
        env = self._environ(
            path="/api/test",
            HTTP_AUTHORIZATION="Bearer token123",
            HTTP_X_CUSTOM_HEADER="custom-value",
        )
        raw = serialize_request(Request(env))
        assert b"Authorization: Bearer token123" in raw
        assert b"X-Custom-Header: custom-value" in raw
class TestDeserializeRequest:
    """deserialize_request should rebuild a Flask Request from raw bytes."""

    def test_deserialize_simple_get_request(self):
        """Method, path and Host header are recovered from a minimal GET."""
        req = deserialize_request(b"GET /api/test HTTP/1.1\r\nHost: localhost:8000\r\n\r\n")
        assert (req.method, req.path) == ("GET", "/api/test")
        assert req.headers.get("Host") == "localhost:8000"

    def test_deserialize_request_with_query_params(self):
        """The query string is split off the target and parsed into args."""
        req = deserialize_request(b"GET /api/search?q=test&limit=10 HTTP/1.1\r\nHost: example.com\r\n\r\n")
        assert req.method == "GET"
        assert req.path == "/api/search"
        assert req.query_string == b"q=test&limit=10"
        assert (req.args.get("q"), req.args.get("limit")) == ("test", "10")

    def test_deserialize_post_request_with_body(self):
        """Content type and body bytes are recovered from a POST."""
        payload = b'{"name": "test", "value": 123}'
        raw = (
            b"POST /api/data HTTP/1.1\r\n"
            b"Host: localhost\r\n"
            b"Content-Type: application/json\r\n"
            b"Content-Length: " + str(len(payload)).encode() + b"\r\n"
            b"\r\n" + payload
        )
        req = deserialize_request(raw)
        assert req.method == "POST"
        assert req.path == "/api/data"
        assert req.content_type == "application/json"
        assert req.get_data() == payload

    def test_deserialize_request_with_custom_headers(self):
        """Arbitrary header lines survive the round into a Request object."""
        raw = (
            b"GET /api/protected HTTP/1.1\r\n"
            b"Host: api.example.com\r\n"
            b"Authorization: Bearer token123\r\n"
            b"X-Custom-Header: custom-value\r\n"
            b"User-Agent: TestClient/1.0\r\n"
            b"\r\n"
        )
        req = deserialize_request(raw)
        assert req.method == "GET"
        for header, expected in (
            ("Authorization", "Bearer token123"),
            ("X-Custom-Header", "custom-value"),
            ("User-Agent", "TestClient/1.0"),
        ):
            assert req.headers.get(header) == expected

    def test_deserialize_request_with_multiline_body(self):
        """CRLFs inside the body are not confused with header separators."""
        payload = b"line1\r\nline2\r\nline3"
        req = deserialize_request(
            b"PUT /api/text HTTP/1.1\r\nHost: localhost\r\nContent-Type: text/plain\r\n\r\n" + payload
        )
        assert req.method == "PUT"
        assert req.get_data() == payload

    def test_deserialize_invalid_request_line(self):
        """A one-token request line is rejected with ValueError."""
        with pytest.raises(ValueError, match="Invalid request line"):
            deserialize_request(b"INVALID\r\n\r\n")

    def test_roundtrip_request(self):
        """serialize then deserialize preserves method, target, body and headers."""
        from io import BytesIO

        payload = b"test body content"
        env = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/echo",
            "QUERY_STRING": "format=json",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8080",
            "wsgi.input": BytesIO(payload),
            "wsgi.url_scheme": "http",
            "CONTENT_LENGTH": str(len(payload)),
            "CONTENT_TYPE": "text/plain",
            "HTTP_CONTENT_TYPE": "text/plain",
            "HTTP_X_REQUEST_ID": "req-123",
        }
        source = Request(env)
        restored = deserialize_request(serialize_request(source))
        assert restored.method == source.method
        assert restored.path == source.path
        assert restored.query_string == source.query_string
        assert restored.get_data() == payload
        assert restored.headers.get("X-Request-Id") == "req-123"
class TestSerializeResponse:
    """serialize_response should emit a raw HTTP/1.1 response."""

    def test_serialize_simple_response(self):
        """Status line, separator and body are emitted for a plain 200."""
        raw = serialize_response(Response("Hello, World!", status=200))
        assert raw.startswith(b"HTTP/1.1 200 OK\r\n")
        assert b"\r\n\r\n" in raw
        assert raw.endswith(b"Hello, World!")

    def test_serialize_response_with_headers(self):
        """Custom headers appear as header lines before the body."""
        resp = Response(
            '{"status": "success"}',
            status=201,
            headers={"Content-Type": "application/json", "X-Request-Id": "req-456"},
        )
        raw = serialize_response(resp)
        for expected in (
            b"HTTP/1.1 201 CREATED\r\n",
            b"Content-Type: application/json",
            b"X-Request-Id: req-456",
        ):
            assert expected in raw
        assert raw.endswith(b'{"status": "success"}')

    def test_serialize_error_response(self):
        """Non-2xx statuses serialize with their reason phrase."""
        raw = serialize_response(Response("Not Found", status=404, headers={"Content-Type": "text/plain"}))
        assert b"HTTP/1.1 404 NOT FOUND\r\n" in raw
        assert b"Content-Type: text/plain" in raw
        assert raw.endswith(b"Not Found")

    def test_serialize_response_without_body(self):
        """A bodiless 204 ends right at the header/body separator."""
        raw = serialize_response(Response(status=204))
        assert b"HTTP/1.1 204 NO CONTENT\r\n" in raw
        assert raw.endswith(b"\r\n\r\n")

    def test_serialize_response_with_binary_body(self):
        """Binary payloads pass through without text coercion."""
        payload = b"\x00\x01\x02\x03\x04\x05"
        raw = serialize_response(
            Response(payload, status=200, headers={"Content-Type": "application/octet-stream"})
        )
        assert b"HTTP/1.1 200 OK\r\n" in raw
        assert b"Content-Type: application/octet-stream" in raw
        assert raw.endswith(payload)
class TestDeserializeResponse:
def test_deserialize_simple_response(self):
raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nHello, World!"
response = deserialize_response(raw_data)
assert response.status_code == 200
assert response.get_data() == b"Hello, World!"
assert response.headers.get("Content-Type") == "text/plain"
def test_deserialize_response_with_json(self):
body = b'{"result": "success", "data": [1, 2, 3]}'
raw_data = (
b"HTTP/1.1 201 Created\r\n"
b"Content-Type: application/json\r\n"
b"Content-Length: " + str(len(body)).encode() + b"\r\n"
b"X-Custom-Header: test-value\r\n"
b"\r\n" + body
)
response = deserialize_response(raw_data)
assert response.status_code == 201
assert response.get_data() == body
assert response.headers.get("Content-Type") == "application/json"
assert response.headers.get("X-Custom-Header") == "test-value"
def test_deserialize_error_response(self):
raw_data = b"HTTP/1.1 404 Not Found\r\nContent-Type: text/html\r\n\r\n<html><body>Page not found</body></html>"
response = deserialize_response(raw_data)
assert response.status_code == 404
assert response.get_data() == b"<html><body>Page not found</body></html>"
def test_deserialize_response_without_body(self):
raw_data = b"HTTP/1.1 204 No Content\r\n\r\n"
response = deserialize_response(raw_data)
assert response.status_code == 204
assert response.get_data() == b""
def test_deserialize_response_with_multiline_body(self):
body = b"Line 1\r\nLine 2\r\nLine 3"
raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n" + body
response = deserialize_response(raw_data)
assert response.status_code == 200
assert response.get_data() == body
def test_deserialize_response_minimal_status_line(self):
# Test with minimal status line (no status text)
raw_data = b"HTTP/1.1 200\r\n\r\nOK"
response = deserialize_response(raw_data)
assert response.status_code == 200
assert response.get_data() == b"OK"
def test_deserialize_invalid_status_line(self):
raw_data = b"INVALID\r\n\r\n"
with pytest.raises(ValueError, match="Invalid status line"):
deserialize_response(raw_data)
def test_roundtrip_response(self):
    """serialize_response -> deserialize_response preserves status, body and headers."""
    expected_headers = {
        "Content-Type": "application/json",
        "X-Request-Id": "abc-123",
        "Cache-Control": "no-cache",
    }
    original_response = Response('{"message": "test"}', status=200, headers=expected_headers)

    # Round-trip through the wire format.
    restored_response = deserialize_response(serialize_response(original_response))

    assert restored_response.status_code == original_response.status_code
    assert restored_response.get_data() == original_response.get_data()
    for header_name, header_value in expected_headers.items():
        assert restored_response.headers.get(header_name) == header_value
class TestEdgeCases:
    """Boundary cases for HTTP (de)serialization: no headers, encoded paths, raw bytes."""

    def test_request_with_empty_headers(self):
        parsed = deserialize_request(b"GET / HTTP/1.1\r\n\r\n")
        assert parsed.method == "GET"
        assert parsed.path == "/"

    def test_response_with_empty_headers(self):
        parsed = deserialize_response(b"HTTP/1.1 200 OK\r\n\r\nSuccess")
        assert parsed.status_code == 200
        assert parsed.get_data() == b"Success"

    def test_request_with_special_characters_in_path(self):
        # Percent-encoded characters must survive into the parsed path.
        parsed = deserialize_request(b"GET /api/test%20path?key=%26value HTTP/1.1\r\n\r\n")
        assert parsed.method == "GET"
        assert "/api/test%20path" in parsed.full_path

    def test_response_with_binary_content(self):
        # Every possible byte value must pass through unmodified.
        payload = bytes(range(256))
        parsed = deserialize_response(
            b"HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\n\r\n" + payload
        )
        assert parsed.status_code == 200
        assert parsed.get_data() == payload
class TestFileUploads:
    """Serialization round-trips for multipart/form-data requests carrying file uploads."""

    def test_serialize_request_with_text_file_upload(self):
        """Serializing a multipart request keeps boundary, disposition headers and text content."""
        # Test multipart/form-data request with text file
        from io import BytesIO

        boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
        text_content = "Hello, this is a test file content!\nWith multiple lines."
        # Hand-built multipart body: one file part ("file") and one plain field ("description").
        body = (
            f"------{boundary}\r\n"
            f'Content-Disposition: form-data; name="file"; filename="test.txt"\r\n'
            f"Content-Type: text/plain\r\n"
            f"\r\n"
            f"{text_content}\r\n"
            f"------{boundary}\r\n"
            f'Content-Disposition: form-data; name="description"\r\n'
            f"\r\n"
            f"Test file upload\r\n"
            f"------{boundary}--\r\n"
        ).encode()
        # Minimal WSGI environ so Request() can be constructed without a real server.
        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/upload",
            "QUERY_STRING": "",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "http",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
        }
        request = Request(environ)
        raw_data = serialize_request(request)
        assert b"POST /api/upload HTTP/1.1\r\n" in raw_data
        assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
        assert b'Content-Disposition: form-data; name="file"; filename="test.txt"' in raw_data
        assert text_content.encode() in raw_data

    def test_deserialize_request_with_text_file_upload(self):
        """Deserializing a raw multipart POST preserves the file name and its text content."""
        # Test deserializing multipart/form-data request with text file
        boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
        text_content = "Sample text file content\nLine 2\nLine 3"
        body = (
            f"------{boundary}\r\n"
            f'Content-Disposition: form-data; name="document"; filename="document.txt"\r\n'
            f"Content-Type: text/plain\r\n"
            f"\r\n"
            f"{text_content}\r\n"
            f"------{boundary}\r\n"
            f'Content-Disposition: form-data; name="title"\r\n'
            f"\r\n"
            f"My Document\r\n"
            f"------{boundary}--\r\n"
        ).encode()
        raw_data = (
            b"POST /api/documents HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n"
            b"Content-Length: " + str(len(body)).encode() + b"\r\n"
            b"\r\n" + body
        )
        request = deserialize_request(raw_data)
        assert request.method == "POST"
        assert request.path == "/api/documents"
        assert "multipart/form-data" in request.content_type
        # The body should contain the multipart data
        request_body = request.get_data()
        assert b"document.txt" in request_body
        assert text_content.encode() in request_body

    def test_serialize_request_with_binary_file_upload(self):
        """Binary file parts (PNG bytes) must be embedded in the serialized request unchanged."""
        # Test multipart/form-data request with binary file (e.g., image)
        from io import BytesIO

        boundary = "----BoundaryString123"
        # Simulate a small PNG file header
        binary_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10"
        # Build multipart body line-by-line; joined with CRLF below.
        body_parts = []
        body_parts.append(f"------{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="image"; filename="test.png"')
        body_parts.append(b"Content-Type: image/png")
        body_parts.append(b"")
        body_parts.append(binary_content)
        body_parts.append(f"------{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="caption"')
        body_parts.append(b"")
        body_parts.append(b"Test image")
        body_parts.append(f"------{boundary}--".encode())
        body = b"\r\n".join(body_parts)
        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/images",
            "QUERY_STRING": "",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "http",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
        }
        request = Request(environ)
        raw_data = serialize_request(request)
        assert b"POST /api/images HTTP/1.1\r\n" in raw_data
        assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
        assert b'filename="test.png"' in raw_data
        assert b"Content-Type: image/png" in raw_data
        assert binary_content in raw_data

    def test_deserialize_request_with_binary_file_upload(self):
        """Deserializing a raw multipart POST keeps binary (JPEG) bytes intact."""
        # Test deserializing multipart/form-data request with binary file
        boundary = "----BoundaryABC123"
        # Simulate a small JPEG file header
        binary_content = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00"
        body_parts = []
        body_parts.append(f"------{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="photo"; filename="photo.jpg"')
        body_parts.append(b"Content-Type: image/jpeg")
        body_parts.append(b"")
        body_parts.append(binary_content)
        body_parts.append(f"------{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="album"')
        body_parts.append(b"")
        body_parts.append(b"Vacation 2024")
        body_parts.append(f"------{boundary}--".encode())
        body = b"\r\n".join(body_parts)
        raw_data = (
            b"POST /api/photos HTTP/1.1\r\n"
            b"Host: api.example.com\r\n"
            b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n"
            b"Content-Length: " + str(len(body)).encode() + b"\r\n"
            b"Accept: application/json\r\n"
            b"\r\n" + body
        )
        request = deserialize_request(raw_data)
        assert request.method == "POST"
        assert request.path == "/api/photos"
        assert "multipart/form-data" in request.content_type
        assert request.headers.get("Accept") == "application/json"
        # Verify the binary content is preserved
        request_body = request.get_data()
        assert b"photo.jpg" in request_body
        assert b"image/jpeg" in request_body
        assert binary_content in request_body
        assert b"Vacation 2024" in request_body

    def test_serialize_request_with_multiple_files(self):
        """Two file parts sharing the field name "files" plus a plain field all serialize."""
        # Test request with multiple file uploads
        from io import BytesIO

        boundary = "----MultiFilesBoundary"
        text_file = b"Text file contents"
        binary_file = b"\x00\x01\x02\x03\x04\x05"
        body_parts = []
        # First file (text)
        body_parts.append(f"------{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="files"; filename="doc.txt"')
        body_parts.append(b"Content-Type: text/plain")
        body_parts.append(b"")
        body_parts.append(text_file)
        # Second file (binary)
        body_parts.append(f"------{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="files"; filename="data.bin"')
        body_parts.append(b"Content-Type: application/octet-stream")
        body_parts.append(b"")
        body_parts.append(binary_file)
        # Additional form field
        body_parts.append(f"------{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="folder"')
        body_parts.append(b"")
        body_parts.append(b"uploads/2024")
        body_parts.append(f"------{boundary}--".encode())
        body = b"\r\n".join(body_parts)
        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/batch-upload",
            "QUERY_STRING": "",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "https",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_X_FORWARDED_PROTO": "https",
        }
        request = Request(environ)
        raw_data = serialize_request(request)
        assert b"POST /api/batch-upload HTTP/1.1\r\n" in raw_data
        assert b"doc.txt" in raw_data
        assert b"data.bin" in raw_data
        assert text_file in raw_data
        assert binary_file in raw_data
        assert b"uploads/2024" in raw_data

    def test_roundtrip_file_upload_request(self):
        """A full serialize -> deserialize cycle preserves method, path, query,
        content type (incl. boundary) and the raw multipart body with UTF-8 bytes."""
        # Test that file upload request survives serialize -> deserialize
        from io import BytesIO

        boundary = "----RoundTripBoundary"
        # Contains a 4-byte UTF-8 emoji to exercise non-ASCII byte preservation.
        file_content = b"This is my file content with special chars: \xf0\x9f\x98\x80"
        body_parts = []
        body_parts.append(f"------{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="upload"; filename="emoji.txt"')
        body_parts.append(b"Content-Type: text/plain; charset=utf-8")
        body_parts.append(b"")
        body_parts.append(file_content)
        body_parts.append(f"------{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="metadata"')
        body_parts.append(b"")
        body_parts.append(b'{"encoding": "utf-8", "size": 42}')
        body_parts.append(f"------{boundary}--".encode())
        body = b"\r\n".join(body_parts)
        environ = {
            "REQUEST_METHOD": "PUT",
            "PATH_INFO": "/api/files/123",
            "QUERY_STRING": "version=2",
            "SERVER_NAME": "storage.example.com",
            "SERVER_PORT": "443",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "https",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_AUTHORIZATION": "Bearer token123",
            "HTTP_X_FORWARDED_PROTO": "https",
        }
        original_request = Request(environ)
        # Serialize and deserialize
        raw_data = serialize_request(original_request)
        restored_request = deserialize_request(raw_data)
        # Verify the request is preserved
        assert restored_request.method == "PUT"
        assert restored_request.path == "/api/files/123"
        assert restored_request.query_string == b"version=2"
        assert "multipart/form-data" in restored_request.content_type
        assert boundary in restored_request.content_type
        # Verify file content is preserved
        restored_body = restored_request.get_data()
        assert b"emoji.txt" in restored_body
        assert file_content in restored_body
        assert b'{"encoding": "utf-8", "size": 42}' in restored_body

View File

@@ -0,0 +1,190 @@
from unittest.mock import MagicMock, patch
import pytest
from configs import dify_config
from core.app.app_config.entities import ModelConfigEntity
from core.file import File, FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessageRole,
UserPromptMessage,
)
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import Conversation
def test__get_completion_model_prompt_messages():
    """A completion-model prompt folds context, role-prefixed history and user
    inputs into a single prompt message rendered from the template."""
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    model_config_mock.provider = "openai"
    model_config_mock.model = "gpt-3.5-turbo-instruct"

    prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
    prompt_template_config = CompletionModelPromptTemplate(text=prompt_template)

    # History lines are rendered with these role prefixes; windowing disabled.
    memory_config = MemoryConfig(
        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
        window=MemoryConfig.WindowConfig(enabled=False),
    )

    inputs = {"name": "John"}
    files = []
    context = "I am superman."

    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)

    # Stub the memory so no database access happens.
    history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
    memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)

    prompt_transform = AdvancedPromptTransform()
    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
    prompt_messages = prompt_transform._get_completion_model_prompt_messages(
        prompt_template=prompt_template_config,
        inputs=inputs,
        query=None,
        files=files,
        context=context,
        memory_config=memory_config,
        memory=memory,
        model_config=model_config_mock,
    )

    # Everything collapses into one completion prompt whose #histories# slot is
    # filled with "Human: ..." / "Assistant: ..." lines per the role prefixes.
    assert len(prompt_messages) == 1
    assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format(
        {
            "#context#": context,
            "#histories#": "\n".join(
                [
                    f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: {prompt.content}"
                    for prompt in history_prompt_messages
                ]
            ),
            **inputs,
        }
    )
def test__get_chat_model_prompt_messages(get_chat_model_args):
    """With memory enabled, history turns are inserted before the final user query."""
    model_config_mock, memory_config, messages, inputs, context = get_chat_model_args

    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
    history = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")]
    memory.get_history_prompt_messages = MagicMock(return_value=history)

    transform = AdvancedPromptTransform()
    transform._calculate_rest_token = MagicMock(return_value=2000)
    query = "Hi2."
    prompt_messages = transform._get_chat_model_prompt_messages(
        prompt_template=messages,
        inputs=inputs,
        query=query,
        files=[],
        context=context,
        memory_config=memory_config,
        memory=memory,
        model_config=model_config_mock,
    )

    # 3 template turns + 2 history turns + the query.
    assert len(prompt_messages) == 6
    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
    expected_system = PromptTemplateParser(template=messages[0].text).format({**inputs, "#context#": context})
    assert prompt_messages[0].content == expected_system
    assert prompt_messages[5].content == query
def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
    """Without memory, only the template-defined turns are produced."""
    model_config_mock, _, messages, inputs, context = get_chat_model_args

    transform = AdvancedPromptTransform()
    transform._calculate_rest_token = MagicMock(return_value=2000)
    prompt_messages = transform._get_chat_model_prompt_messages(
        prompt_template=messages,
        inputs=inputs,
        query=None,
        files=[],
        context=context,
        memory_config=None,
        memory=None,
        model_config=model_config_mock,
    )

    # Just the three template messages (system/user/assistant).
    assert len(prompt_messages) == 3
    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
    expected_system = PromptTemplateParser(template=messages[0].text).format({**inputs, "#context#": context})
    assert prompt_messages[0].content == expected_system
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
    """With an image file attached and no memory, the last user message becomes a
    two-part content list (text + image)."""
    model_config_mock, _, messages, inputs, context = get_chat_model_args

    # Send images by URL rather than base64.
    dify_config.MULTIMODAL_SEND_FORMAT = "url"
    files = [
        File(
            id="file1",
            tenant_id="tenant1",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/image1.jpg",
            storage_key="",
        )
    ]

    prompt_transform = AdvancedPromptTransform()
    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
    # Stub file -> prompt-content conversion so no storage/network access happens.
    with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
        mock_get_encoded_string.return_value = ImagePromptMessageContent(
            url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
        )
        prompt_messages = prompt_transform._get_chat_model_prompt_messages(
            prompt_template=messages,
            inputs=inputs,
            query=None,
            files=files,
            context=context,
            memory_config=None,
            memory=None,
            model_config=model_config_mock,
        )

    # 3 template turns + an extra user message carrying the file content.
    assert len(prompt_messages) == 4
    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
    assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
        {**inputs, "#context#": context}
    )
    assert isinstance(prompt_messages[3].content, list)
    assert len(prompt_messages[3].content) == 2
    assert prompt_messages[3].content[0].data == files[0].remote_url
@pytest.fixture
def get_chat_model_args():
    """Shared args: (model config mock, memory config, chat messages, inputs, context)."""
    config = MagicMock(spec=ModelConfigEntity)
    config.provider = "openai"
    config.model = "gpt-4"
    return (
        config,
        MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)),
        [
            ChatModelMessage(
                text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
                role=PromptMessageRole.SYSTEM,
            ),
            ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
            ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT),
        ],
        {"name": "John"},
        "I am superman.",
    )

View File

@@ -0,0 +1,73 @@
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import (
ModelConfigWithCredentialsEntity,
)
from core.entities.provider_configuration import ProviderModelBundle
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from models.model import Conversation
def test_get_prompt():
    """Agent history is trimmed to fit the token budget; a larger budget keeps more
    history messages. Token counting is stubbed to "number of messages"."""
    prompt_messages = [
        SystemPromptMessage(content="System Template"),
        UserPromptMessage(content="User Query"),
    ]
    history_messages = [
        SystemPromptMessage(content="System Prompt 1"),
        UserPromptMessage(content="User Prompt 1"),
        AssistantPromptMessage(content="Assistant Thought 1"),
        ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"),
        ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"),
        SystemPromptMessage(content="System Prompt 2"),
        UserPromptMessage(content="User Prompt 2"),
        AssistantPromptMessage(content="Assistant Thought 2"),
        ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"),
        ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"),
        UserPromptMessage(content="User Prompt 3"),
        AssistantPromptMessage(content="Assistant Thought 3"),
    ]

    # use message number instead of token for testing
    def side_effect_get_num_tokens(*args):
        # args[2] is the message list handed to get_num_tokens.
        return len(args[2])

    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
    large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)

    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
    provider_model_bundle_mock.model_type_instance = large_language_model_mock

    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config_mock.model = "openai"
    model_config_mock.credentials = {}
    model_config_mock.provider_model_bundle = provider_model_bundle_mock

    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)

    transform = AgentHistoryPromptTransform(
        model_config=model_config_mock,
        prompt_messages=prompt_messages,
        history_messages=history_messages,
        memory=memory,
    )

    # Tight budget (5 "tokens" == 5 messages): only 4 messages survive.
    max_token_limit = 5
    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
    result = transform.get_prompt()
    assert len(result) == 4

    # Generous budget (20): all 12 history messages are kept.
    max_token_limit = 20
    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
    result = transform.get_prompt()
    assert len(result) == 12

View File

@@ -0,0 +1,91 @@
from uuid import uuid4
from constants import UUID_NIL
from core.prompt.utils.extract_thread_messages import extract_thread_messages
class MockMessage:
    """Minimal stand-in for a message row: exposes ``id`` and ``parent_message_id``
    both as attributes and via dict-style subscripting."""

    def __init__(self, id, parent_message_id):
        self.id = id
        self.parent_message_id = parent_message_id

    def __getitem__(self, item):
        # Delegate msg["id"]-style access to the corresponding attribute.
        return getattr(self, item)
def test_extract_thread_messages_single_message():
    """A lone root message (parent == UUID_NIL) is returned unchanged."""
    only_message = MockMessage(str(uuid4()), UUID_NIL)
    result = extract_thread_messages([only_message])
    assert len(result) == 1
    assert result[0] == only_message
def test_extract_thread_messages_linear_thread():
    """A straight parent chain is returned newest-first and fully intact."""
    id1, id2, id3, id4, id5 = (str(uuid4()) for _ in range(5))
    # Newest first, each pointing at its predecessor, root at UUID_NIL.
    chain = [(id5, id4), (id4, id3), (id3, id2), (id2, id1), (id1, UUID_NIL)]
    messages = [MockMessage(msg_id, parent_id) for msg_id, parent_id in chain]
    result = extract_thread_messages(messages)
    assert len(result) == 5
    assert [msg["id"] for msg in result] == [id5, id4, id3, id2, id1]
def test_extract_thread_messages_branched_thread():
    """When two messages share a parent, only the newest branch is followed."""
    id1, id2, id3, id4 = (str(uuid4()) for _ in range(4))
    messages = [
        MockMessage(id4, id2),
        MockMessage(id3, id2),  # sibling branch — must be excluded
        MockMessage(id2, id1),
        MockMessage(id1, UUID_NIL),
    ]
    result = extract_thread_messages(messages)
    assert len(result) == 3
    assert [msg["id"] for msg in result] == [id4, id2, id1]
def test_extract_thread_messages_empty_list():
    """No input messages yields an empty thread."""
    extracted = extract_thread_messages([])
    assert len(extracted) == 0
def test_extract_thread_messages_partially_loaded():
    """A chain whose root parent lies outside the loaded window still returns
    every visible message."""
    id0, id1, id2, id3 = (str(uuid4()) for _ in range(4))
    # id1's parent (id0) is not present in the list.
    window = [MockMessage(id3, id2), MockMessage(id2, id1), MockMessage(id1, id0)]
    result = extract_thread_messages(window)
    assert len(result) == 3
    assert [msg["id"] for msg in result] == [id3, id2, id1]
def test_extract_thread_messages_legacy_messages():
    """Legacy rows all carry UUID_NIL parents; the whole list is kept in order."""
    id1, id2, id3 = (str(uuid4()) for _ in range(3))
    legacy = [MockMessage(msg_id, UUID_NIL) for msg_id in (id3, id2, id1)]
    result = extract_thread_messages(legacy)
    assert len(result) == 3
    assert [msg["id"] for msg in result] == [id3, id2, id1]
def test_extract_thread_messages_mixed_with_legacy_messages():
    """Threaded rows are followed through their parents into the legacy tail,
    skipping abandoned sibling branches."""
    id1, id2, id3, id4, id5 = (str(uuid4()) for _ in range(5))
    messages = [
        MockMessage(id5, id4),
        MockMessage(id4, id2),
        MockMessage(id3, id2),  # abandoned branch
        MockMessage(id2, UUID_NIL),
        MockMessage(id1, UUID_NIL),
    ]
    result = extract_thread_messages(messages)
    assert len(result) == 4
    assert [msg["id"] for msg in result] == [id5, id4, id2, id1]

View File

@@ -0,0 +1,27 @@
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
def test_build_prompt_message_with_prompt_message_contents():
    """A user message built from content parts keeps them as a typed list."""
    message = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")])
    assert isinstance(message.content, list)
    first_part = message.content[0]
    assert isinstance(first_part, TextPromptMessageContent)
    assert first_part.data == "Hello, World!"
def test_dump_prompt_message():
    """model_dump() exposes the image URL under content[0]['url']."""
    example_url = "https://example.com/image.jpg"
    image_part = ImagePromptMessageContent(
        url=example_url,
        format="jpeg",
        mime_type="image/jpeg",
    )
    dumped = UserPromptMessage(content=[image_part]).model_dump()
    assert dumped["content"][0].get("url") == example_url

View File

@@ -0,0 +1,52 @@
# from unittest.mock import MagicMock
# from core.app.app_config.entities import ModelConfigEntity
# from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
# from core.model_runtime.entities.message_entities import UserPromptMessage
# from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
# from core.model_runtime.entities.provider_entities import ProviderEntity
# from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
# from core.prompt.prompt_transform import PromptTransform
# def test__calculate_rest_token():
# model_schema_mock = MagicMock(spec=AIModelEntity)
# parameter_rule_mock = MagicMock(spec=ParameterRule)
# parameter_rule_mock.name = "max_tokens"
# model_schema_mock.parameter_rules = [parameter_rule_mock]
# model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
# large_language_model_mock = MagicMock(spec=LargeLanguageModel)
# large_language_model_mock.get_num_tokens.return_value = 6
# provider_mock = MagicMock(spec=ProviderEntity)
# provider_mock.provider = "openai"
# provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
# provider_configuration_mock.provider = provider_mock
# provider_configuration_mock.model_settings = None
# provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
# provider_model_bundle_mock.model_type_instance = large_language_model_mock
# provider_model_bundle_mock.configuration = provider_configuration_mock
# model_config_mock = MagicMock(spec=ModelConfigEntity)
# model_config_mock.model = "gpt-4"
# model_config_mock.credentials = {}
# model_config_mock.parameters = {"max_tokens": 50}
# model_config_mock.model_schema = model_schema_mock
# model_config_mock.provider_model_bundle = provider_model_bundle_mock
# prompt_transform = PromptTransform()
# prompt_messages = [UserPromptMessage(content="Hello, how are you?")]
# rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)
# # Validate based on the mock configuration and expected logic
# expected_rest_tokens = (
# model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
# - model_config_mock.parameters["max_tokens"]
# - large_language_model_mock.get_num_tokens.return_value
# )
# assert rest_tokens == expected_rest_tokens
# assert rest_tokens == 6

View File

@@ -0,0 +1,246 @@
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.prompt.simple_prompt_transform import SimplePromptTransform
from models.model import AppMode, Conversation
def test_get_common_chat_app_prompt_template_with_pcqm():
    """OpenAI chat template with pre-prompt, context, query and memory (pcqm)."""
    transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=True,
    )
    rules = template["prompt_rules"]
    expected = rules["context_prompt"] + pre_prompt + "\n" + rules["histories_prompt"] + rules["query_prompt"]
    assert template["prompt_template"].template == expected
    assert template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
def test_get_baichuan_chat_app_prompt_template_with_pcqm():
    """Baichuan chat template with pre-prompt, context, query and memory (pcqm)."""
    transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="baichuan",
        model="Baichuan2-53B",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=True,
    )
    rules = template["prompt_rules"]
    expected = rules["context_prompt"] + pre_prompt + "\n" + rules["histories_prompt"] + rules["query_prompt"]
    assert template["prompt_template"].template == expected
    assert template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
def test_get_common_completion_app_prompt_template_with_pcq():
    """OpenAI workflow template: context + pre-prompt + query, no memory."""
    transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    template = transform.get_prompt_template(
        app_mode=AppMode.WORKFLOW,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    rules = template["prompt_rules"]
    assert template["prompt_template"].template == rules["context_prompt"] + pre_prompt + "\n" + rules["query_prompt"]
    assert template["special_variable_keys"] == ["#context#", "#query#"]
def test_get_baichuan_completion_app_prompt_template_with_pcq():
    """Baichuan workflow template: context + pre-prompt + query, no memory."""
    transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    template = transform.get_prompt_template(
        app_mode=AppMode.WORKFLOW,
        provider="baichuan",
        model="Baichuan2-53B",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    rules = template["prompt_rules"]
    assert template["prompt_template"].template == rules["context_prompt"] + pre_prompt + "\n" + rules["query_prompt"]
    assert template["special_variable_keys"] == ["#context#", "#query#"]
def test_get_common_chat_app_prompt_template_with_q():
    """With no pre-prompt and no context, the template is just the query prompt."""
    transform = SimplePromptTransform()
    pre_prompt = ""
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=False,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    rules = template["prompt_rules"]
    assert template["prompt_template"].template == rules["query_prompt"]
    assert template["special_variable_keys"] == ["#query#"]
def test_get_common_chat_app_prompt_template_with_cq():
    """With context but no pre-prompt, the template is context + query prompts."""
    transform = SimplePromptTransform()
    pre_prompt = ""
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    rules = template["prompt_rules"]
    assert template["prompt_template"].template == rules["context_prompt"] + rules["query_prompt"]
    assert template["special_variable_keys"] == ["#context#", "#query#"]
def test_get_common_chat_app_prompt_template_with_p():
    """Pre-prompt only: template is the pre-prompt plus a trailing newline, and
    its {{name}} placeholder is surfaced as a custom variable."""
    transform = SimplePromptTransform()
    pre_prompt = "you are {{name}}"
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=False,
        query_in_prompt=False,
        with_memory_prompt=False,
    )
    assert template["prompt_template"].template == pre_prompt + "\n"
    assert template["custom_variable_keys"] == ["name"]
    assert template["special_variable_keys"] == []
def test__get_chat_model_prompt_messages():
    """Chat prompt assembly: rendered system prompt, then history turns, then query."""
    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config_mock.provider = "openai"
    model_config_mock.model = "gpt-4"

    # Memory is fully mocked — no conversation storage is touched.
    memory_mock = MagicMock(spec=TokenBufferMemory)
    history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
    memory_mock.get_history_prompt_messages.return_value = history_prompt_messages

    prompt_transform = SimplePromptTransform()
    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)

    pre_prompt = "You are a helpful assistant {{name}}."
    inputs = {"name": "John"}
    context = "yes or no."
    query = "How are you?"
    prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
        app_mode=AppMode.CHAT,
        pre_prompt=pre_prompt,
        inputs=inputs,
        query=query,
        files=[],
        context=context,
        memory=memory_mock,
        model_config=model_config_mock,
    )

    # Rebuild the expected system prompt from the same template rules.
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider=model_config_mock.provider,
        model=model_config_mock.model,
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=False,
        with_memory_prompt=False,
    )
    full_inputs = {**inputs, "#context#": context}
    real_system_prompt = prompt_template["prompt_template"].format(full_inputs)

    # system + 2 history turns + query
    assert len(prompt_messages) == 4
    assert prompt_messages[0].content == real_system_prompt
    assert prompt_messages[1].content == history_prompt_messages[0].content
    assert prompt_messages[2].content == history_prompt_messages[1].content
    assert prompt_messages[3].content == query
def test__get_completion_model_prompt_messages():
    """Completion-model path: context, chat history, and the query are folded
    into a single prompt message, and the rule-defined stop words are returned."""
    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config_mock.provider = "openai"
    model_config_mock.model = "gpt-3.5-turbo-instruct"
    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
    history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
    memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
    prompt_transform = SimplePromptTransform()
    # Stub out token counting; a generous budget keeps the full history.
    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
    pre_prompt = "You are a helpful assistant {{name}}."
    inputs = {"name": "John"}
    context = "yes or no."
    query = "How are you?"
    prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
        app_mode=AppMode.CHAT,
        pre_prompt=pre_prompt,
        inputs=inputs,
        query=query,
        files=[],
        context=context,
        memory=memory,
        model_config=model_config_mock,
    )
    # Rebuild the expected single prompt using the same template and prompt rules.
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider=model_config_mock.provider,
        model=model_config_mock.model,
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=True,
    )
    prompt_rules = prompt_template["prompt_rules"]
    full_inputs = {
        **inputs,
        "#context#": context,
        "#query#": query,
        "#histories#": memory.get_history_prompt_text(
            max_token_limit=2000,
            human_prefix=prompt_rules.get("human_prefix", "Human"),
            ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
        ),
    }
    real_prompt = prompt_template["prompt_template"].format(full_inputs)
    assert len(prompt_messages) == 1
    assert stops == prompt_rules.get("stops")
    assert prompt_messages[0].content == real_prompt

View File

@@ -0,0 +1,733 @@
import json
import unittest
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVector,
AlibabaCloudMySQLVectorConfig,
)
from core.rag.models.document import Document
try:
    from mysql.connector import Error as MySQLError
except ImportError:
    # mysql-connector-python may be absent in test environments; provide a
    # minimal stand-in exposing the same errno/msg attributes.
    class MySQLError(Exception):
        def __init__(self, errno, msg):
            super().__init__(msg)
            self.errno = errno
            self.msg = msg
class TestAlibabaCloudMySQLVector(unittest.TestCase):
    """Unit tests for AlibabaCloudMySQLVector, exercised entirely against mocked
    MySQL connection pools and cursors — no real database is required.

    Convention used throughout: the constructor runs a version query followed by
    a vector-support probe, so each test primes ``mock_cursor.fetchone`` with
    ``[{"VERSION()": ...}, {"vector_support": ...}]`` (plus any extra values its
    own scenario consumes afterwards).
    """
    def setUp(self):
        """Build a shared config, collection name, and sample documents/embeddings."""
        self.config = AlibabaCloudMySQLVectorConfig(
            host="localhost",
            port=3306,
            user="test_user",
            password="test_password",
            database="test_db",
            max_connection=5,
            charset="utf8mb4",
        )
        self.collection_name = "test_collection"
        # Sample documents for testing
        self.sample_documents = [
            Document(
                page_content="This is a test document about AI.",
                metadata={"doc_id": "doc1", "document_id": "dataset1", "source": "test"},
            ),
            Document(
                page_content="Another document about machine learning.",
                metadata={"doc_id": "doc2", "document_id": "dataset1", "source": "test"},
            ),
        ]
        # Sample embeddings
        self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_init(self, mock_pool_class):
        """Test AlibabaCloudMySQLVector initialization."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        # Mock connection and cursor for vector support check
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [
            {"VERSION()": "8.0.36"},  # Version check
            {"vector_support": True},  # Vector support check
        ]
        alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
        assert alibabacloud_mysql_vector.collection_name == self.collection_name
        assert alibabacloud_mysql_vector.table_name == self.collection_name.lower()
        assert alibabacloud_mysql_vector.get_type() == "alibabacloud_mysql"
        assert alibabacloud_mysql_vector.distance_function == "cosine"
        assert alibabacloud_mysql_vector.pool is not None
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
    def test_create_collection(self, mock_redis, mock_pool_class):
        """Test collection creation."""
        # Mock Redis operations
        mock_redis.lock.return_value.__enter__ = MagicMock()
        mock_redis.lock.return_value.__exit__ = MagicMock()
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [
            {"VERSION()": "8.0.36"},  # Version check
            {"vector_support": True},  # Vector support check
        ]
        alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
        alibabacloud_mysql_vector._create_collection(768)
        # Verify SQL execution calls - should include table creation and index creation
        assert mock_cursor.execute.called
        assert mock_cursor.execute.call_count >= 3  # CREATE TABLE + 2 indexes
        mock_redis.set.assert_called_once()
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_vector_support_check_success(self, mock_pool_class):
        """Test successful vector support check."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        # Should not raise an exception
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        assert vector_store is not None
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_vector_support_check_failure(self, mock_pool_class):
        """Test vector support check failure."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.35"}, {"vector_support": False}]
        with pytest.raises(ValueError) as context:
            AlibabaCloudMySQLVector(self.collection_name, self.config)
        assert "RDS MySQL Vector functions are not available" in str(context.value)
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_vector_support_check_function_error(self, mock_pool_class):
        """Test vector support check with function not found error."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.return_value = {"VERSION()": "8.0.36"}
        # First execute (version check) succeeds; the vector probe raises 1305.
        mock_cursor.execute.side_effect = [None, MySQLError(errno=1305, msg="FUNCTION VEC_FromText does not exist")]
        with pytest.raises(ValueError) as context:
            AlibabaCloudMySQLVector(self.collection_name, self.config)
        assert "RDS MySQL Vector functions are not available" in str(context.value)
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
    def test_create_documents(self, mock_redis, mock_pool_class):
        """Test creating documents with embeddings."""
        # Setup mocks
        self._setup_mocks(mock_redis, mock_pool_class)
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        result = vector_store.create(self.sample_documents, self.sample_embeddings)
        assert len(result) == 2
        assert "doc1" in result
        assert "doc2" in result
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_add_texts(self, mock_pool_class):
        """Test adding texts to the vector store."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        result = vector_store.add_texts(self.sample_documents, self.sample_embeddings)
        assert len(result) == 2
        # Documents are inserted in one batch via executemany.
        mock_cursor.executemany.assert_called_once()
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_text_exists(self, mock_pool_class):
        """Test checking if text exists."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [
            {"VERSION()": "8.0.36"},
            {"vector_support": True},
            {"id": "doc1"},  # Text exists
        ]
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        exists = vector_store.text_exists("doc1")
        assert exists
        # Check that the correct SQL was executed (last call after init)
        execute_calls = mock_cursor.execute.call_args_list
        last_call = execute_calls[-1]
        assert "SELECT id FROM" in last_call[0][0]
        assert last_call[0][1] == ("doc1",)
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_text_not_exists(self, mock_pool_class):
        """Test checking if text does not exist."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [
            {"VERSION()": "8.0.36"},
            {"vector_support": True},
            None,  # Text does not exist
        ]
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        exists = vector_store.text_exists("nonexistent")
        assert not exists
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_get_by_ids(self, mock_pool_class):
        """Test getting documents by IDs."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        # Iterating the cursor yields the stored rows (dict per row).
        mock_cursor.__iter__ = lambda self: iter(
            [
                {"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1"},
                {"meta": json.dumps({"doc_id": "doc2", "source": "test"}), "text": "Test document 2"},
            ]
        )
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        docs = vector_store.get_by_ids(["doc1", "doc2"])
        assert len(docs) == 2
        assert docs[0].page_content == "Test document 1"
        assert docs[1].page_content == "Test document 2"
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_get_by_ids_empty_list(self, mock_pool_class):
        """Test getting documents with empty ID list."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        docs = vector_store.get_by_ids([])
        assert len(docs) == 0
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_delete_by_ids(self, mock_pool_class):
        """Test deleting documents by IDs."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        vector_store.delete_by_ids(["doc1", "doc2"])
        # Check that delete SQL was executed
        execute_calls = mock_cursor.execute.call_args_list
        delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
        assert len(delete_calls) == 1
        delete_call = delete_calls[0]
        assert "DELETE FROM" in delete_call[0][0]
        assert delete_call[0][1] == ["doc1", "doc2"]
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_delete_by_ids_empty_list(self, mock_pool_class):
        """Test deleting with empty ID list."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        vector_store.delete_by_ids([])  # Should not raise an exception
        # Verify no delete SQL was executed
        execute_calls = mock_cursor.execute.call_args_list
        delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
        assert len(delete_calls) == 0
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_delete_by_ids_table_not_exists(self, mock_pool_class):
        """Test deleting when table doesn't exist."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        # Simulate table doesn't exist error on delete
        def execute_side_effect(*args, **kwargs):
            if "DELETE" in args[0]:
                raise MySQLError(errno=1146, msg="Table doesn't exist")
        mock_cursor.execute.side_effect = execute_side_effect
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        # Should not raise an exception
        vector_store.delete_by_ids(["doc1"])
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_delete_by_metadata_field(self, mock_pool_class):
        """Test deleting documents by metadata field."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        vector_store.delete_by_metadata_field("document_id", "dataset1")
        # Check that the correct SQL was executed
        execute_calls = mock_cursor.execute.call_args_list
        delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
        assert len(delete_calls) == 1
        delete_call = delete_calls[0]
        assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
        assert delete_call[0][1] == ("$.document_id", "dataset1")
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_vector_cosine(self, mock_pool_class):
        """Test vector search with cosine distance."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        mock_cursor.__iter__ = lambda self: iter(
            [{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 0.1}]
        )
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        query_vector = [0.1, 0.2, 0.3, 0.4]
        docs = vector_store.search_by_vector(query_vector, top_k=5)
        assert len(docs) == 1
        assert docs[0].page_content == "Test document 1"
        assert abs(docs[0].metadata["score"] - 0.9) < 0.1  # 1 - 0.1 = 0.9
        assert docs[0].metadata["distance"] == 0.1
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_vector_euclidean(self, mock_pool_class):
        """Test vector search with euclidean distance."""
        config = AlibabaCloudMySQLVectorConfig(
            host="localhost",
            port=3306,
            user="test_user",
            password="test_password",
            database="test_db",
            max_connection=5,
            distance_function="euclidean",
        )
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        mock_cursor.__iter__ = lambda self: iter(
            [{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 2.0}]
        )
        vector_store = AlibabaCloudMySQLVector(self.collection_name, config)
        query_vector = [0.1, 0.2, 0.3, 0.4]
        docs = vector_store.search_by_vector(query_vector, top_k=5)
        assert len(docs) == 1
        assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01  # 1/(1+2) = 1/3
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_vector_with_filter(self, mock_pool_class):
        """Test vector search with document ID filter."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        mock_cursor.__iter__ = lambda self: iter([])
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        query_vector = [0.1, 0.2, 0.3, 0.4]
        docs = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["dataset1"])
        # Verify the SQL contains the WHERE clause for filtering
        execute_calls = mock_cursor.execute.call_args_list
        search_calls = [call for call in execute_calls if "VEC_DISTANCE" in str(call)]
        assert len(search_calls) > 0
        search_call = search_calls[0]
        assert "WHERE JSON_UNQUOTE" in search_call[0][0]
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_vector_with_score_threshold(self, mock_pool_class):
        """Test vector search with score threshold."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        mock_cursor.__iter__ = lambda self: iter(
            [
                {
                    "meta": json.dumps({"doc_id": "doc1", "source": "test"}),
                    "text": "High similarity document",
                    "distance": 0.1,  # High similarity (score = 0.9)
                },
                {
                    "meta": json.dumps({"doc_id": "doc2", "source": "test"}),
                    "text": "Low similarity document",
                    "distance": 0.8,  # Low similarity (score = 0.2)
                },
            ]
        )
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        query_vector = [0.1, 0.2, 0.3, 0.4]
        docs = vector_store.search_by_vector(query_vector, top_k=5, score_threshold=0.5)
        # Only the high similarity document should be returned
        assert len(docs) == 1
        assert docs[0].page_content == "High similarity document"
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_vector_invalid_top_k(self, mock_pool_class):
        """Test vector search with invalid top_k."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        query_vector = [0.1, 0.2, 0.3, 0.4]
        with pytest.raises(ValueError):
            vector_store.search_by_vector(query_vector, top_k=0)
        with pytest.raises(ValueError):
            vector_store.search_by_vector(query_vector, top_k="invalid")
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_full_text(self, mock_pool_class):
        """Test full-text search."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        mock_cursor.__iter__ = lambda self: iter(
            [
                {
                    "meta": {"doc_id": "doc1", "source": "test"},
                    "text": "This document contains machine learning content",
                    "score": 1.5,
                }
            ]
        )
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        docs = vector_store.search_by_full_text("machine learning", top_k=5)
        assert len(docs) == 1
        assert docs[0].page_content == "This document contains machine learning content"
        assert docs[0].metadata["score"] == 1.5
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_full_text_with_filter(self, mock_pool_class):
        """Test full-text search with document ID filter."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        mock_cursor.__iter__ = lambda self: iter([])
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        docs = vector_store.search_by_full_text("machine learning", top_k=5, document_ids_filter=["dataset1"])
        # Verify the SQL contains the AND clause for filtering
        execute_calls = mock_cursor.execute.call_args_list
        search_calls = [call for call in execute_calls if "MATCH" in str(call)]
        assert len(search_calls) > 0
        search_call = search_calls[0]
        assert "AND JSON_UNQUOTE" in search_call[0][0]
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
        """Test full-text search with invalid top_k."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        with pytest.raises(ValueError):
            vector_store.search_by_full_text("test", top_k=0)
        with pytest.raises(ValueError):
            vector_store.search_by_full_text("test", top_k="invalid")
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_delete_collection(self, mock_pool_class):
        """Test deleting the entire collection."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        vector_store.delete()
        # Check that DROP TABLE SQL was executed
        execute_calls = mock_cursor.execute.call_args_list
        drop_calls = [call for call in execute_calls if "DROP TABLE" in str(call)]
        assert len(drop_calls) == 1
        drop_call = drop_calls[0]
        assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_unsupported_distance_function(self, mock_pool_class):
        """Test that Pydantic validation rejects unsupported distance functions."""
        # Test that creating config with unsupported distance function raises ValidationError
        with pytest.raises(ValueError) as context:
            AlibabaCloudMySQLVectorConfig(
                host="localhost",
                port=3306,
                user="test_user",
                password="test_password",
                database="test_db",
                max_connection=5,
                distance_function="manhattan",  # Unsupported - not in Literal["cosine", "euclidean"]
            )
        # The error should be related to validation
        assert "Input should be 'cosine' or 'euclidean'" in str(context.value) or "manhattan" in str(context.value)
    def _setup_mocks(self, mock_redis, mock_pool_class):
        """Helper method to setup common mocks."""
        # Mock Redis operations
        mock_redis.lock.return_value.__enter__ = MagicMock()
        mock_redis.lock.return_value.__exit__ = MagicMock()
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool
        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
@pytest.mark.parametrize(
    "invalid_config_override",
    [
        {"host": ""},  # Test empty host
        {"port": 0},  # Test invalid port
        {"max_connection": 0},  # Test invalid max_connection
    ],
)
def test_config_validation_parametrized(invalid_config_override):
    """Each invalid override on top of a valid base config must be rejected."""
    base_config = {
        "host": "localhost",
        "port": 3306,
        "user": "test",
        "password": "test",
        "database": "test",
        "max_connection": 5,
    }
    candidate = {**base_config, **invalid_config_override}
    with pytest.raises(ValueError):
        AlibabaCloudMySQLVectorConfig(**candidate)
# Allow running this module directly with the stdlib unittest runner.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,18 @@
import pytest
from pydantic import ValidationError
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
def test_default_value():
    """Every required Milvus setting must be present; `database` defaults to 'default'."""
    valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"}
    for missing in valid_config:
        # Drop one required key at a time and expect a tailored validation error.
        partial = {k: v for k, v in valid_config.items() if k != missing}
        with pytest.raises(ValidationError) as exc_info:
            MilvusConfig.model_validate(partial)
        expected_msg = f"Value error, config MILVUS_{missing.upper()} is required"
        assert exc_info.value.errors()[0]["msg"] == expected_msg
    assert MilvusConfig.model_validate(valid_config).database == "default"

View File

@@ -0,0 +1,27 @@
import os
from pytest_mock import MockerFixture
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
    """crawl_url should POST a crawl request and return the job id as a string."""
    url = "https://firecrawl.dev"
    # Fall back to a dummy key prefix so the test runs without real credentials.
    api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
    base_url = "https://api.firecrawl.dev"
    firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url)
    params = {
        "includePaths": [],
        "excludePaths": [],
        "maxDepth": 1,
        "limit": 1,
    }
    mocked_firecrawl = {
        "id": "test",
    }
    # Stub the HTTP layer so no network traffic happens.
    mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl))
    job_id = firecrawl_app.crawl_url(url, params)
    assert job_id is not None
    assert isinstance(job_id, str)

View File

@@ -0,0 +1,22 @@
from core.rag.extractor.markdown_extractor import MarkdownExtractor
def test_markdown_to_tups():
    """markdown_to_tups should split markdown into (header, text) tuples, keying
    text that precedes the first header by None."""
    markdown = """
this is some text without header
# title 1
this is balabala text
## title 2
this is more specific text.
    """
    extractor = MarkdownExtractor(file_path="dummy_path")
    updated_output = extractor.markdown_to_tups(markdown)
    assert len(updated_output) == 3
    key, header_value = updated_output[0]
    # Fixed: compare to None with `is`, not `==` (PEP 8 / E711).
    assert key is None
    assert header_value.strip() == "this is some text without header"
    title_1, value = updated_output[1]
    assert title_1.strip() == "title 1"
    assert value.strip() == "this is balabala text"

View File

@@ -0,0 +1,93 @@
from unittest import mock
from pytest_mock import MockerFixture
from core.rag.extractor import notion_extractor
# Shared identifiers referenced by the payload builders and tests below.
user_id = "user1"
database_id = "database1"
page_id = "page1"
# Module-level extractor with placeholder credentials; all HTTP traffic is
# mocked in the tests, so the values never reach a real Notion API.
extractor = notion_extractor.NotionExtractor(
    notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x"
)
def _generate_page(page_title: str):
    """Build a minimal Notion page payload whose title property is *page_title*."""
    title_entry = {"type": "text", "text": {"content": page_title}, "plain_text": page_title}
    return {
        "object": "page",
        "id": page_id,
        "properties": {
            "Page": {
                "type": "title",
                "title": [title_entry],
            }
        },
    }
def _generate_block(block_id: str, block_type: str, block_text: str):
    """Build a minimal Notion block payload of *block_type* holding *block_text*."""
    rich_text_entry = {
        "type": "text",
        "text": {"content": block_text},
        "plain_text": block_text,
    }
    return {
        "object": "block",
        "id": block_id,
        "parent": {"type": "page_id", "page_id": page_id},
        "type": block_type,
        "has_children": False,
        block_type: {"rich_text": [rich_text_entry]},
    }
def _mock_response(data):
response = mock.Mock()
response.status_code = 200
response.json.return_value = data
return response
def _remove_multiple_new_lines(text):
while "\n\n" in text:
text = text.replace("\n\n", "\n")
return text.strip()
def test_notion_page(mocker: MockerFixture):
    """Headings and paragraphs should render as markdown in block order."""
    texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
    mocked_notion_page = {
        "object": "list",
        "results": [
            _generate_block("b1", "heading_1", texts[0]),
            _generate_block("b2", "heading_2", texts[1]),
            _generate_block("b3", "paragraph", texts[2]),
            _generate_block("b4", "heading_3", texts[3]),
        ],
        "next_cursor": None,
    }
    # Stub the Notion HTTP call; no network traffic happens.
    mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page))
    page_docs = extractor._load_data_as_documents(page_id, "page")
    assert len(page_docs) == 1
    content = _remove_multiple_new_lines(page_docs[0].page_content)
    assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
def test_notion_database(mocker: MockerFixture):
    """A database should produce one document listing each page title on its own line."""
    page_title_list = ["page1", "page2", "page3"]
    mocked_notion_database = {
        "object": "list",
        "results": [_generate_page(i) for i in page_title_list],
        "next_cursor": None,
    }
    # Stub the Notion HTTP call; no network traffic happens.
    mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database))
    database_docs = extractor._load_data_as_documents(database_id, "database")
    assert len(database_docs) == 1
    content = _remove_multiple_new_lines(database_docs[0].page_content)
    assert content == "\n".join([f"Page:{i}" for i in page_title_list])

View File

@@ -0,0 +1,49 @@
"""Primarily used for testing merged cell scenarios"""
from docx import Document
from core.rag.extractor.word_extractor import WordExtractor
def _generate_table_with_merged_cells():
    """Build a 3x3 python-docx table containing two merged regions.

    The table looks like this:
    +-----+-----+-----+
    | 1-1 & 1-2 | 1-3 |
    +-----+-----+-----+
    | 2-1 | 2-2 | 2-3 |
    | & |-----+-----+
    | 3-1 | 3-2 | 3-3 |
    +-----+-----+-----+

    Returns:
        tuple[Table, list[list[str]]]: the table and the expected per-row
        cell texts once merged cells are expanded into each spanned slot.
    """
    # NOTE: the original placed this description as a bare string AFTER the
    # first statement, where it is a discarded expression rather than a
    # docstring; it now lives in the proper docstring position above.
    doc = Document()
    table = doc.add_table(rows=3, cols=3)
    table.style = "Table Grid"
    # Fill every cell with its "row-col" coordinate label first.
    for i in range(3):
        for j in range(3):
            cell = table.cell(i, j)
            cell.text = f"{i + 1}-{j + 1}"
    # Merge cells (0,0)+(0,1) horizontally across the first row.
    cell_0_0 = table.cell(0, 0)
    cell_0_1 = table.cell(0, 1)
    merged_cell_1 = cell_0_0.merge(cell_0_1)
    merged_cell_1.text = "1-1 & 1-2"
    # Merge cells (1,0)+(2,0) vertically down the first column.
    cell_1_0 = table.cell(1, 0)
    cell_2_0 = table.cell(2, 0)
    merged_cell_2 = cell_1_0.merge(cell_2_0)
    merged_cell_2.text = "2-1 & 3-1"
    # Horizontal merges leave an empty continuation slot; vertical merges
    # repeat the merged text on every spanned row.
    ground_truth = [["1-1 & 1-2", "", "1-3"], ["2-1 & 3-1", "2-2", "2-3"], ["2-1 & 3-1", "3-2", "3-3"]]
    return doc.tables[0], ground_truth
def test_parse_row():
    """_parse_row expands merged cells so each row reports full-width text."""
    table, expected_rows = _generate_table_with_merged_cells()
    # Bypass __init__ (which expects file paths); _parse_row needs no state.
    extractor = object.__new__(WordExtractor)
    for row, expected in zip(table.rows, expected_rows):
        assert extractor._parse_row(row, {}, 3) == expected

View File

@@ -0,0 +1,301 @@
"""
Unit tests for TenantIsolatedTaskQueue.
These tests verify the Redis-based task queue functionality for tenant-specific
task management with proper serialization and deserialization.
"""
import json
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from pydantic import ValidationError
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
class TestTaskWrapper:
    """Test cases for TaskWrapper serialization/deserialization."""

    def test_serialize_simple_data(self):
        """Test serialization of simple data types."""
        payload = {"key": "value", "number": 42, "list": [1, 2, 3]}
        serialized = TaskWrapper(data=payload).serialize()
        assert isinstance(serialized, str)
        # Round-trip through json.loads to prove the output is valid JSON.
        assert json.loads(serialized)["data"] == payload

    def test_serialize_complex_data(self):
        """Test serialization of complex nested data."""
        payload = {
            "nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5]}},
            "unicode": "测试中文",
            "special_chars": "!@#$%^&*()",
        }
        serialized = TaskWrapper(data=payload).serialize()
        assert json.loads(serialized)["data"] == payload

    def test_deserialize_valid_data(self):
        """Test deserialization of valid JSON data."""
        payload = {"key": "value", "number": 42}
        # Serialize via TaskWrapper to guarantee the canonical wire format.
        round_tripped = TaskWrapper.deserialize(TaskWrapper(data=payload).serialize())
        assert round_tripped.data == payload

    def test_deserialize_invalid_json(self):
        """Test deserialization handles invalid JSON gracefully."""
        # Pydantic surfaces malformed JSON as a ValidationError.
        with pytest.raises(ValidationError):
            TaskWrapper.deserialize("{invalid json}")

    def test_serialize_ensure_ascii_false(self):
        """Test that serialization preserves Unicode characters."""
        serialized = TaskWrapper(data={"chinese": "中文测试", "emoji": "🚀"}).serialize()
        for fragment in ("中文测试", "🚀"):
            assert fragment in serialized
class TestTenantIsolatedTaskQueue:
    """Test cases for TenantIsolatedTaskQueue functionality."""

    @pytest.fixture
    def mock_redis_client(self):
        """Mock Redis client for testing."""
        mock_redis = MagicMock()
        return mock_redis

    @pytest.fixture
    def sample_queue(self, mock_redis_client):
        """Create a sample TenantIsolatedTaskQueue instance."""
        return TenantIsolatedTaskQueue("tenant-123", "test-key")

    def test_initialization(self, sample_queue):
        """Test queue initialization with correct key generation."""
        assert sample_queue._tenant_id == "tenant-123"
        assert sample_queue._unique_key == "test-key"
        # Both Redis keys embed the unique key and tenant id, giving each
        # tenant an isolated namespace.
        assert sample_queue._queue == "tenant_self_test-key_task_queue:tenant-123"
        assert sample_queue._task_key == "tenant_test-key_task:tenant-123"

    @patch("core.rag.pipeline.queue.redis_client")
    def test_get_task_key_exists(self, mock_redis, sample_queue):
        """Test getting task key when it exists."""
        mock_redis.get.return_value = "1"
        result = sample_queue.get_task_key()
        assert result == "1"
        mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")

    @patch("core.rag.pipeline.queue.redis_client")
    def test_get_task_key_not_exists(self, mock_redis, sample_queue):
        """Test getting task key when it doesn't exist."""
        mock_redis.get.return_value = None
        result = sample_queue.get_task_key()
        assert result is None
        mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")

    @patch("core.rag.pipeline.queue.redis_client")
    def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue):
        """Test setting task waiting flag with default TTL."""
        sample_queue.set_task_waiting_time()
        mock_redis.setex.assert_called_once_with(
            "tenant_test-key_task:tenant-123",
            3600,  # DEFAULT_TASK_TTL
            1,
        )

    @patch("core.rag.pipeline.queue.redis_client")
    def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue):
        """Test setting task waiting flag with custom TTL."""
        custom_ttl = 1800
        sample_queue.set_task_waiting_time(custom_ttl)
        mock_redis.setex.assert_called_once_with("tenant_test-key_task:tenant-123", custom_ttl, 1)

    @patch("core.rag.pipeline.queue.redis_client")
    def test_delete_task_key(self, mock_redis, sample_queue):
        """Test deleting task key."""
        sample_queue.delete_task_key()
        mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123")

    @patch("core.rag.pipeline.queue.redis_client")
    def test_push_tasks_string_list(self, mock_redis, sample_queue):
        """Test pushing string tasks directly."""
        tasks = ["task1", "task2", "task3"]
        sample_queue.push_tasks(tasks)
        # Plain strings are pushed as-is (no TaskWrapper envelope), all in
        # a single lpush call.
        mock_redis.lpush.assert_called_once_with(
            "tenant_self_test-key_task_queue:tenant-123", "task1", "task2", "task3"
        )

    @patch("core.rag.pipeline.queue.redis_client")
    def test_push_tasks_mixed_types(self, mock_redis, sample_queue):
        """Test pushing mixed string and object tasks."""
        tasks = ["string_task", {"object_task": "data", "id": 123}, "another_string"]
        sample_queue.push_tasks(tasks)
        # Verify lpush was called
        mock_redis.lpush.assert_called_once()
        call_args = mock_redis.lpush.call_args
        # Check queue name
        assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123"
        # Check serialized tasks
        serialized_tasks = call_args[0][1:]
        assert len(serialized_tasks) == 3
        assert serialized_tasks[0] == "string_task"
        assert serialized_tasks[2] == "another_string"
        # Check object task is serialized as TaskWrapper JSON (without prefix)
        # It should be a valid JSON string that can be deserialized by TaskWrapper
        wrapper = TaskWrapper.deserialize(serialized_tasks[1])
        assert wrapper.data == {"object_task": "data", "id": 123}

    @patch("core.rag.pipeline.queue.redis_client")
    def test_push_tasks_empty_list(self, mock_redis, sample_queue):
        """Test pushing empty task list."""
        sample_queue.push_tasks([])
        mock_redis.lpush.assert_not_called()

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_default_count(self, mock_redis, sample_queue):
        """Test pulling tasks with default count (1)."""
        # Trailing None simulates the queue running dry after one element.
        mock_redis.rpop.side_effect = ["task1", None]
        result = sample_queue.pull_tasks()
        assert result == ["task1"]
        assert mock_redis.rpop.call_count == 1

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_custom_count(self, mock_redis, sample_queue):
        """Test pulling tasks with custom count."""
        # First test: pull 3 tasks
        mock_redis.rpop.side_effect = ["task1", "task2", "task3", None]
        result = sample_queue.pull_tasks(3)
        assert result == ["task1", "task2", "task3"]
        assert mock_redis.rpop.call_count == 3
        # Reset mock for second test
        mock_redis.reset_mock()
        mock_redis.rpop.side_effect = ["task1", "task2", None]
        result = sample_queue.pull_tasks(3)
        # Queue ran dry at 2 items, but rpop is still attempted `count` times.
        assert result == ["task1", "task2"]
        assert mock_redis.rpop.call_count == 3

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_zero_count(self, mock_redis, sample_queue):
        """Test pulling tasks with zero count returns empty list."""
        result = sample_queue.pull_tasks(0)
        assert result == []
        mock_redis.rpop.assert_not_called()

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_negative_count(self, mock_redis, sample_queue):
        """Test pulling tasks with negative count returns empty list."""
        result = sample_queue.pull_tasks(-1)
        assert result == []
        mock_redis.rpop.assert_not_called()

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue):
        """Test pulling tasks that include wrapped objects."""
        # Create a wrapped task
        task_data = {"task_id": 123, "data": "test"}
        wrapper = TaskWrapper(data=task_data)
        wrapped_task = wrapper.serialize()
        mock_redis.rpop.side_effect = [
            "string_task",
            wrapped_task.encode("utf-8"),  # Simulate bytes from Redis
            None,
        ]
        result = sample_queue.pull_tasks(2)
        assert len(result) == 2
        assert result[0] == "string_task"
        # Wrapped payloads are unwrapped back into the original object.
        assert result[1] == {"task_id": 123, "data": "test"}

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue):
        """Test pulling tasks with invalid JSON falls back to string."""
        # Invalid JSON string that cannot be deserialized
        invalid_json = "invalid json data"
        mock_redis.rpop.side_effect = [invalid_json, None]
        result = sample_queue.pull_tasks(1)
        assert result == [invalid_json]

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue):
        """Test pulling tasks handles bytes from Redis correctly."""
        mock_redis.rpop.side_effect = [
            b"task1",  # bytes
            "task2",  # string
            None,
        ]
        result = sample_queue.pull_tasks(2)
        assert result == ["task1", "task2"]

    @patch("core.rag.pipeline.queue.redis_client")
    def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue):
        """Test complex object serialization and deserialization roundtrip."""
        complex_task = {
            "id": uuid4().hex,
            "data": {"nested": {"deep": [1, 2, 3], "unicode": "测试中文", "special": "!@#$%^&*()"}},
            "metadata": {"created_at": "2024-01-01T00:00:00Z", "tags": ["tag1", "tag2", "tag3"]},
        }
        # Push the complex task
        sample_queue.push_tasks([complex_task])
        # Verify it was serialized as TaskWrapper JSON
        call_args = mock_redis.lpush.call_args
        wrapped_task = call_args[0][1]
        # Verify it's a valid TaskWrapper JSON (starts with {"data":)
        assert wrapped_task.startswith('{"data":')
        # Verify it can be deserialized
        wrapper = TaskWrapper.deserialize(wrapped_task)
        assert wrapper.data == complex_task
        # Simulate pulling it back
        mock_redis.rpop.return_value = wrapped_task
        result = sample_queue.pull_tasks(1)
        assert len(result) == 1
        assert result[0] == complex_task

View File

@@ -0,0 +1 @@
# Unit tests for core repositories module

View File

@@ -0,0 +1,247 @@
"""
Unit tests for CeleryWorkflowExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
@pytest.fixture
def mock_session_factory():
    """Mock SQLAlchemy session factory."""
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # A real sessionmaker bound to a throwaway in-memory SQLite engine keeps
    # the repository constructor happy without touching a live database.
    return sessionmaker(bind=create_engine("sqlite:///:memory:"))
@pytest.fixture
def mock_account():
    """Mock Account user with random id and tenant id."""
    stub = Mock(spec=Account)
    # Evaluated left-to-right: id first, then tenant id.
    stub.id, stub.current_tenant_id = str(uuid4()), str(uuid4())
    return stub
@pytest.fixture
def mock_end_user():
    """Mock EndUser with random id and tenant id."""
    stub = Mock(spec=EndUser)
    stub.id, stub.tenant_id = str(uuid4()), str(uuid4())
    return stub
@pytest.fixture
def sample_workflow_execution():
    """A minimal WorkflowExecution instance for repository tests."""
    attrs = {
        "id_": str(uuid4()),
        "workflow_id": str(uuid4()),
        "workflow_type": WorkflowType.WORKFLOW,
        "workflow_version": "1.0",
        "graph": {"nodes": [], "edges": []},
        "inputs": {"input1": "value1"},
        "started_at": naive_utc_now(),
    }
    return WorkflowExecution.new(**attrs)
class TestCeleryWorkflowExecutionRepository:
    """Test cases for CeleryWorkflowExecutionRepository."""

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Test repository initialization with sessionmaker."""
        app_id = "test-app-id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=app_id,
            triggered_from=triggered_from,
        )
        # Tenant and creator context are derived from the Account at
        # construction time.
        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role is not None

    def test_init_basic_functionality(self, mock_session_factory, mock_account):
        """Test repository initialization basic functionality."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )
        # Verify basic initialization
        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == "test-app"
        assert repo._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """Test repository initialization with EndUser."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        # EndUser exposes tenant_id (unlike Account's current_tenant_id).
        assert repo._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Test that initialization fails without tenant_id."""
        # Create a mock Account with no tenant_id
        user = Mock(spec=Account)
        user.current_tenant_id = None
        user.id = str(uuid4())
        with pytest.raises(ValueError, match="User must have a tenant_id"):
            CeleryWorkflowExecutionRepository(
                session_factory=mock_session_factory,
                user=user,
                app_id="test-app",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            )

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
        """Test that save operation queues a Celery task without tracking."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        repo.save(sample_workflow_execution)
        # Verify Celery task was queued with correct parameters
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]
        assert call_args["execution_data"] == sample_workflow_execution.model_dump()
        assert call_args["tenant_id"] == mock_account.current_tenant_id
        assert call_args["app_id"] == "test-app"
        assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN
        assert call_args["creator_user_id"] == mock_account.id
        # Verify no task tracking occurs (no _pending_saves attribute)
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Test that save operation handles Celery task failures."""
        mock_task.delay.side_effect = Exception("Celery is down")
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        # Broker failures propagate to the caller rather than being swallowed.
        with pytest.raises(Exception, match="Celery is down"):
            repo.save(sample_workflow_execution)

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_operation_fire_and_forget(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Test that save operation works in fire-and-forget mode."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        # Test that save doesn't block or maintain state
        repo.save(sample_workflow_execution)
        # Verify no pending saves are tracked (no _pending_saves attribute)
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account):
        """Test multiple save operations work correctly."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        # Create multiple executions
        exec1 = WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs={"input1": "value1"},
            started_at=naive_utc_now(),
        )
        exec2 = WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs={"input2": "value2"},
            started_at=naive_utc_now(),
        )
        # Save both executions
        repo.save(exec1)
        repo.save(exec2)
        # Should work without issues and not maintain state (no _pending_saves attribute)
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user):
        """Test save operation with different user types."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        execution = WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs={"input1": "value1"},
            started_at=naive_utc_now(),
        )
        repo.save(execution)
        # Verify task was called with EndUser context
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]
        assert call_args["tenant_id"] == mock_end_user.tenant_id
        assert call_args["creator_user_id"] == mock_end_user.id

View File

@@ -0,0 +1,349 @@
"""
Unit tests for CeleryWorkflowNodeExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
)
from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@pytest.fixture
def mock_session_factory():
    """Mock SQLAlchemy session factory."""
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # A real sessionmaker over an in-memory SQLite engine satisfies the
    # repository constructor without a live database.
    return sessionmaker(bind=create_engine("sqlite:///:memory:"))
@pytest.fixture
def mock_account():
    """Mock Account user with random id and tenant id."""
    stub = Mock(spec=Account)
    # Evaluated left-to-right: id first, then tenant id.
    stub.id, stub.current_tenant_id = str(uuid4()), str(uuid4())
    return stub
@pytest.fixture
def mock_end_user():
    """Mock EndUser with random id and tenant id."""
    stub = Mock(spec=EndUser)
    stub.id, stub.tenant_id = str(uuid4()), str(uuid4())
    return stub
@pytest.fixture
def sample_workflow_node_execution():
    """A RUNNING start-node WorkflowNodeExecution for repository tests."""
    attrs = {
        "id": str(uuid4()),
        "node_execution_id": str(uuid4()),
        "workflow_id": str(uuid4()),
        "workflow_execution_id": str(uuid4()),
        "index": 1,
        "node_id": "test_node",
        "node_type": NodeType.START,
        "title": "Test Node",
        "inputs": {"input1": "value1"},
        "status": WorkflowNodeExecutionStatus.RUNNING,
        "created_at": naive_utc_now(),
    }
    return WorkflowNodeExecution(**attrs)
class TestCeleryWorkflowNodeExecutionRepository:
    """Test cases for CeleryWorkflowNodeExecutionRepository."""

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Test repository initialization with sessionmaker."""
        app_id = "test-app-id"
        triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=app_id,
            triggered_from=triggered_from,
        )
        # Tenant and creator context are derived from the Account at
        # construction time.
        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role is not None

    def test_init_with_cache_initialized(self, mock_session_factory, mock_account):
        """Test repository initialization with cache properly initialized."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        # Both the per-execution cache and the workflow-run -> node-execution
        # mapping start out empty.
        assert repo._execution_cache == {}
        assert repo._workflow_execution_mapping == {}

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """Test repository initialization with EndUser."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        # EndUser exposes tenant_id (unlike Account's current_tenant_id).
        assert repo._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Test that initialization fails without tenant_id."""
        # Create a mock Account with no tenant_id
        user = Mock(spec=Account)
        user.current_tenant_id = None
        user.id = str(uuid4())
        with pytest.raises(ValueError, match="User must have a tenant_id"):
            CeleryWorkflowNodeExecutionRepository(
                session_factory=mock_session_factory,
                user=user,
                app_id="test-app",
                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
            )

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_caches_and_queues_celery_task(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that save operation caches execution and queues a Celery task."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        repo.save(sample_workflow_node_execution)
        # Verify Celery task was queued with correct parameters
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]
        assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
        assert call_args["tenant_id"] == mock_account.current_tenant_id
        assert call_args["app_id"] == "test-app"
        assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
        assert call_args["creator_user_id"] == mock_account.id
        # Verify execution is cached
        assert sample_workflow_node_execution.id in repo._execution_cache
        assert repo._execution_cache[sample_workflow_node_execution.id] == sample_workflow_node_execution
        # Verify workflow execution mapping is updated
        assert sample_workflow_node_execution.workflow_execution_id in repo._workflow_execution_mapping
        assert (
            sample_workflow_node_execution.id
            in repo._workflow_execution_mapping[sample_workflow_node_execution.workflow_execution_id]
        )

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that save operation handles Celery task failures."""
        mock_task.delay.side_effect = Exception("Celery is down")
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        # Broker failures propagate to the caller rather than being swallowed.
        with pytest.raises(Exception, match="Celery is down"):
            repo.save(sample_workflow_node_execution)

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_get_by_workflow_run_from_cache(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that get_by_workflow_run retrieves executions from cache."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        # Save execution to cache first
        repo.save(sample_workflow_node_execution)
        workflow_run_id = sample_workflow_node_execution.workflow_execution_id
        order_config = OrderConfig(order_by=["index"], order_direction="asc")
        result = repo.get_by_workflow_run(workflow_run_id, order_config)
        # Verify results were retrieved from cache
        assert len(result) == 1
        assert result[0].id == sample_workflow_node_execution.id
        # Identity check: the cache returns the very same object, not a copy.
        assert result[0] is sample_workflow_node_execution

    def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account):
        """Test get_by_workflow_run without order configuration."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        result = repo.get_by_workflow_run("workflow-run-id")
        # Should return empty list since nothing in cache
        assert len(result) == 0

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
        """Test cache operations work correctly."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        # Test saving to cache
        repo.save(sample_workflow_node_execution)
        # Verify cache contains the execution
        assert sample_workflow_node_execution.id in repo._execution_cache
        # Test retrieving from cache
        result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id)
        assert len(result) == 1
        assert result[0].id == sample_workflow_node_execution.id

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account):
        """Test multiple executions for the same workflow."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        # Create multiple executions for the same workflow
        workflow_run_id = str(uuid4())
        exec1 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=1,
            node_id="node1",
            node_type=NodeType.START,
            title="Node 1",
            inputs={"input1": "value1"},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=naive_utc_now(),
        )
        exec2 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=2,
            node_id="node2",
            node_type=NodeType.LLM,
            title="Node 2",
            inputs={"input2": "value2"},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=naive_utc_now(),
        )
        # Save both executions
        repo.save(exec1)
        repo.save(exec2)
        # Verify both are cached and mapped
        assert len(repo._execution_cache) == 2
        assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2
        # Test retrieval
        result = repo.get_by_workflow_run(workflow_run_id)
        assert len(result) == 2

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account):
        """Test ordering functionality works correctly."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        # Create executions with different indices
        workflow_run_id = str(uuid4())
        exec1 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=2,
            node_id="node2",
            node_type=NodeType.START,
            title="Node 2",
            inputs={},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=naive_utc_now(),
        )
        exec2 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=1,
            node_id="node1",
            node_type=NodeType.LLM,
            title="Node 1",
            inputs={},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=naive_utc_now(),
        )
        # Save in random order
        repo.save(exec1)
        repo.save(exec2)
        # Test ascending order
        order_config = OrderConfig(order_by=["index"], order_direction="asc")
        result = repo.get_by_workflow_run(workflow_run_id, order_config)
        assert len(result) == 2
        assert result[0].index == 1
        assert result[1].index == 2
        # Test descending order
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        result = repo.get_by_workflow_run(workflow_run_id, order_config)
        assert len(result) == 2
        assert result[0].index == 2
        assert result[1].index == 1

View File

@@ -0,0 +1,244 @@
"""
Unit tests for the RepositoryFactory.
This module tests the factory pattern implementation for creating repository instances
based on configuration, including error handling.
"""
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from libs.module_loading import import_string
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowNodeExecutionTriggeredFrom
class TestRepositoryFactory:
    """Test cases for RepositoryFactory.

    Covers ``import_string`` resolution, successful repository creation,
    and the wrapping of import/instantiation failures into
    ``RepositoryImportError``.
    """

    @staticmethod
    def _mock_repository(spec: type) -> tuple[MagicMock, MagicMock]:
        """Return a (repository_class, repository_instance) mock pair.

        Calling the class mock yields the instance mock, mirroring normal
        instantiation so tests can identity-check the factory's return value.
        """
        repository_class = MagicMock()
        repository_instance = MagicMock(spec=spec)
        repository_class.return_value = repository_instance
        return repository_class, repository_instance

    def test_import_string_success(self):
        """Test successful class import."""
        # Importing a class that always exists exercises the happy path.
        class_path = "unittest.mock.MagicMock"
        result = import_string(class_path)
        assert result is MagicMock

    def test_import_string_invalid_path(self):
        """Test import with invalid module path."""
        with pytest.raises(ImportError) as exc_info:
            import_string("invalid.module.path")
        assert "No module named" in str(exc_info.value)

    def test_import_string_invalid_class_name(self):
        """Test import with invalid class name."""
        with pytest.raises(ImportError) as exc_info:
            import_string("unittest.mock.NonExistentClass")
        assert "does not define" in str(exc_info.value)

    def test_import_string_malformed_path(self):
        """Test import with malformed path (no dots)."""
        with pytest.raises(ImportError) as exc_info:
            import_string("invalidpath")
        assert "doesn't look like a module path" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_execution_repository_success(self, mock_config):
        """Test successful WorkflowExecutionRepository creation."""
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=Account)
        app_id = "test-app-id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN
        repository_class, repository_instance = self._mock_repository(WorkflowExecutionRepository)
        with patch("core.repositories.factory.import_string", return_value=repository_class):
            result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                session_factory=mock_session_factory,
                user=mock_user,
                app_id=app_id,
                triggered_from=triggered_from,
            )
        # The factory must forward every argument unchanged to the imported class.
        repository_class.assert_called_once_with(
            session_factory=mock_session_factory,
            user=mock_user,
            app_id=app_id,
            triggered_from=triggered_from,
        )
        assert result is repository_instance

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_execution_repository_import_error(self, mock_config):
        """Test WorkflowExecutionRepository creation with import error."""
        # An unresolvable class path must surface as RepositoryImportError.
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
        with pytest.raises(RepositoryImportError) as exc_info:
            DifyCoreRepositoryFactory.create_workflow_execution_repository(
                session_factory=MagicMock(spec=sessionmaker),
                user=MagicMock(spec=Account),
                app_id="test-app-id",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            )
        assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
        """Test WorkflowExecutionRepository creation with instantiation error."""
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
        # A class that raises on instantiation must also be wrapped.
        repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
        with patch("core.repositories.factory.import_string", return_value=repository_class):
            with pytest.raises(RepositoryImportError) as exc_info:
                DifyCoreRepositoryFactory.create_workflow_execution_repository(
                    session_factory=MagicMock(spec=sessionmaker),
                    user=MagicMock(spec=Account),
                    app_id="test-app-id",
                    triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
                )
        assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_node_execution_repository_success(self, mock_config):
        """Test successful WorkflowNodeExecutionRepository creation."""
        mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_user = MagicMock(spec=EndUser)
        app_id = "test-app-id"
        triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP
        repository_class, repository_instance = self._mock_repository(WorkflowNodeExecutionRepository)
        with patch("core.repositories.factory.import_string", return_value=repository_class):
            result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                session_factory=mock_session_factory,
                user=mock_user,
                app_id=app_id,
                triggered_from=triggered_from,
            )
        # The factory must forward every argument unchanged to the imported class.
        repository_class.assert_called_once_with(
            session_factory=mock_session_factory,
            user=mock_user,
            app_id=app_id,
            triggered_from=triggered_from,
        )
        assert result is repository_instance

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_node_execution_repository_import_error(self, mock_config):
        """Test WorkflowNodeExecutionRepository creation with import error."""
        mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
        with pytest.raises(RepositoryImportError) as exc_info:
            DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                session_factory=MagicMock(spec=sessionmaker),
                user=MagicMock(spec=EndUser),
                app_id="test-app-id",
                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
            )
        assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
        """Test WorkflowNodeExecutionRepository creation with instantiation error."""
        mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
        repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
        with patch("core.repositories.factory.import_string", return_value=repository_class):
            with pytest.raises(RepositoryImportError) as exc_info:
                DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                    session_factory=MagicMock(spec=sessionmaker),
                    user=MagicMock(spec=EndUser),
                    app_id="test-app-id",
                    triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
                )
        assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)

    def test_repository_import_error_exception(self):
        """Test RepositoryImportError exception handling."""
        error_message = "Custom error message"
        assert str(RepositoryImportError(error_message)) == error_message

    @patch("core.repositories.factory.dify_config")
    def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
        """Test repository creation with Engine instead of sessionmaker."""
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
        # The factory accepts a raw Engine anywhere a sessionmaker is allowed.
        mock_engine = MagicMock(spec=Engine)
        mock_user = MagicMock(spec=Account)
        repository_class, repository_instance = self._mock_repository(WorkflowExecutionRepository)
        with patch("core.repositories.factory.import_string", return_value=repository_class):
            result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                session_factory=mock_engine,  # Using Engine instead of sessionmaker
                user=mock_user,
                app_id="test-app-id",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            )
        repository_class.assert_called_once_with(
            session_factory=mock_engine,
            user=mock_user,
            app_id="test-app-id",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        assert result is repository_instance

View File

@@ -0,0 +1,210 @@
"""Unit tests for workflow node execution conflict handling."""
from unittest.mock import MagicMock, Mock
import psycopg2.errors
import pytest
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository,
)
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
)
from core.workflow.enums import NodeType
from libs.datetime_utils import naive_utc_now
from models import Account, WorkflowNodeExecutionTriggeredFrom
class TestWorkflowNodeExecutionConflictHandling:
    """Test cases for handling duplicate key conflicts in workflow node execution."""

    def setup_method(self):
        """Set up test fixtures."""
        # Mocked account carrying the tenant id the repository reads.
        self.mock_user = Mock(spec=Account)
        self.mock_user.id = "test-user-id"
        self.mock_user.current_tenant_id = "test-tenant-id"
        # Create mock session factory
        self.mock_session_factory = Mock(spec=sessionmaker)
        # Create repository instance
        self.repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=self.mock_session_factory,
            user=self.mock_user,
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

    def _make_session(self) -> MagicMock:
        """Wire a context-manager-capable mock session into the session factory."""
        session = MagicMock()
        session.__enter__ = Mock(return_value=session)
        session.__exit__ = Mock(return_value=None)
        self.mock_session_factory.return_value = session
        return session

    @staticmethod
    def _duplicate_key_error() -> IntegrityError:
        """Build an IntegrityError whose origin is a psycopg2 UniqueViolation."""
        return IntegrityError(
            "duplicate key value violates unique constraint",
            params=None,
            orig=Mock(spec=psycopg2.errors.UniqueViolation),
        )

    @staticmethod
    def _make_execution(execution_id: str, status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecution:
        """Create a minimal WorkflowNodeExecution suitable for save() tests."""
        return WorkflowNodeExecution(
            id=execution_id,
            workflow_id="test-workflow-id",
            workflow_execution_id="test-workflow-execution-id",
            node_execution_id="test-node-execution-id",
            node_id="test-node-id",
            node_type=NodeType.START,
            title="Test Node",
            index=1,
            status=status,
            created_at=naive_utc_now(),
        )

    def test_save_with_duplicate_key_retries_with_new_uuid(self):
        """Test that save retries with a new UUID v7 when encountering duplicate key error."""
        mock_session = self._make_session()
        # No existing row, so save() takes the insert path.
        mock_session.get.return_value = None
        # First call to session.add raises IntegrityError, second succeeds
        mock_session.add.side_effect = [self._duplicate_key_error(), None]
        mock_session.commit.side_effect = [None, None]
        execution = self._make_execution("original-id", WorkflowNodeExecutionStatus.RUNNING)
        original_id = execution.id
        # Save should succeed after retry
        self.repository.save(execution)
        # Verify that session.add was called twice (initial attempt + retry)
        assert mock_session.add.call_count == 2
        # Verify that the ID was changed (new UUID v7 generated)
        assert execution.id != original_id

    def test_save_with_existing_record_updates_instead_of_insert(self):
        """Test that save updates existing record instead of inserting duplicate."""
        mock_session = self._make_session()
        # session.get finds a row, so save() must take the update path.
        mock_session.get.return_value = MagicMock()
        mock_session.commit.return_value = None
        execution = self._make_execution("existing-id", WorkflowNodeExecutionStatus.SUCCEEDED)
        # Save should update existing record
        self.repository.save(execution)
        # Verify that session.add was not called (update path)
        mock_session.add.assert_not_called()
        # Verify that session.commit was called
        mock_session.commit.assert_called_once()

    def test_save_exceeds_max_retries_raises_error(self):
        """Test that save raises error after exceeding max retries."""
        mock_session = self._make_session()
        mock_session.get.return_value = None
        # Every attempt fails with the same duplicate-key error.
        mock_session.add.side_effect = self._duplicate_key_error()
        execution = self._make_execution("test-id", WorkflowNodeExecutionStatus.RUNNING)
        # Save should raise IntegrityError after max retries
        with pytest.raises(IntegrityError):
            self.repository.save(execution)
        # Verify that session.add was called 3 times (max_retries)
        assert mock_session.add.call_count == 3

    def test_save_non_duplicate_integrity_error_raises_immediately(self):
        """Test that non-duplicate IntegrityErrors are raised immediately without retry."""
        mock_session = self._make_session()
        mock_session.get.return_value = None
        # A NOT NULL violation is not retriable and must propagate at once.
        mock_session.add.side_effect = IntegrityError(
            "null value in column violates not-null constraint",
            params=None,
            orig=None,
        )
        execution = self._make_execution("test-id", WorkflowNodeExecutionStatus.RUNNING)
        # Save should raise error immediately
        with pytest.raises(IntegrityError):
            self.repository.save(execution)
        # Verify that session.add was called only once (no retry)
        assert mock_session.add.call_count == 1

View File

@@ -0,0 +1,217 @@
"""
Unit tests for WorkflowNodeExecution truncation functionality.
Tests the truncation and offloading logic for large inputs and outputs
in the SQLAlchemyWorkflowNodeExecutionRepository.
"""
import json
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from unittest.mock import MagicMock
from sqlalchemy import Engine
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository,
)
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
)
from core.workflow.enums import NodeType
from models import Account, WorkflowNodeExecutionTriggeredFrom
from models.enums import ExecutionOffLoadType
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
@dataclass
class TruncationTestCase:
    """Test case data for truncation scenarios."""

    # Unique, human-readable identifier for the scenario.
    name: str
    # Inputs payload handed to the execution; None means "no inputs".
    inputs: dict[str, Any] | None
    # Outputs payload handed to the execution; None means "no outputs".
    outputs: dict[str, Any] | None
    # Expected truncation decision for the inputs payload.
    should_truncate_inputs: bool
    # Expected truncation decision for the outputs payload.
    should_truncate_outputs: bool
    # Free-text explanation of the scenario.
    description: str
def create_test_cases() -> list[TruncationTestCase]:
    """Create test cases for different truncation scenarios."""
    # Create large data that will definitely exceed the threshold (10KB)
    # NOTE(review): TRUNCATION_SIZE_THRESHOLD does not appear in this file's
    # visible import block -- confirm it is imported/defined elsewhere in the
    # file, otherwise this raises NameError at call time.
    large_data = {"data": "x" * (TRUNCATION_SIZE_THRESHOLD + 1000)}
    small_data = {"data": "small"}
    # Each case pairs concrete payloads with the expected truncation flags.
    return [
        TruncationTestCase(
            name="small_data_no_truncation",
            inputs=small_data,
            outputs=small_data,
            should_truncate_inputs=False,
            should_truncate_outputs=False,
            description="Small data should not be truncated",
        ),
        TruncationTestCase(
            name="large_inputs_truncation",
            inputs=large_data,
            outputs=small_data,
            should_truncate_inputs=True,
            should_truncate_outputs=False,
            description="Large inputs should be truncated",
        ),
        TruncationTestCase(
            name="large_outputs_truncation",
            inputs=small_data,
            outputs=large_data,
            should_truncate_inputs=False,
            should_truncate_outputs=True,
            description="Large outputs should be truncated",
        ),
        TruncationTestCase(
            name="large_both_truncation",
            inputs=large_data,
            outputs=large_data,
            should_truncate_inputs=True,
            should_truncate_outputs=True,
            description="Both large inputs and outputs should be truncated",
        ),
        TruncationTestCase(
            name="none_inputs_outputs",
            inputs=None,
            outputs=None,
            should_truncate_inputs=False,
            should_truncate_outputs=False,
            description="None inputs and outputs should not be truncated",
        ),
    ]
def create_workflow_node_execution(
    execution_id: str = "test-execution-id",
    inputs: dict[str, Any] | None = None,
    outputs: dict[str, Any] | None = None,
) -> WorkflowNodeExecution:
    """Build a WorkflowNodeExecution populated with fixed test defaults."""
    # Collect all constructor arguments first so the fixed fixture values
    # are visible in one place, then expand them into the constructor.
    fields: dict[str, Any] = {
        "id": execution_id,
        "node_execution_id": "test-node-execution-id",
        "workflow_id": "test-workflow-id",
        "workflow_execution_id": "test-workflow-execution-id",
        "index": 1,
        "node_id": "test-node-id",
        "node_type": NodeType.LLM,
        "title": "Test Node",
        "inputs": inputs,
        "outputs": outputs,
        "status": WorkflowNodeExecutionStatus.SUCCEEDED,
        "created_at": datetime.now(UTC),
    }
    return WorkflowNodeExecution(**fields)
def mock_user() -> Account:
    """Return a MagicMock standing in for an Account with fixed test ids."""
    from unittest.mock import MagicMock

    account = MagicMock(spec=Account)
    # configure_mock assigns both attributes in one call.
    account.configure_mock(id="test-user-id", current_tenant_id="test-tenant-id")
    return account
class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation:
    """Test class for truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository."""

    def create_repository(self) -> SQLAlchemyWorkflowNodeExecutionRepository:
        """Create a repository instance for testing."""
        # A bare Engine mock is enough here: these tests never open a session.
        return SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=MagicMock(spec=Engine),
            user=mock_user(),
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

    def test_to_domain_model_without_offload_data(self):
        """Test _to_domain_model correctly handles models without offload data."""
        repo = self.create_repository()
        # Create a mock database model without offload data
        db_model = WorkflowNodeExecutionModel()
        db_model.id = "test-id"
        db_model.node_execution_id = "node-exec-id"
        db_model.workflow_id = "workflow-id"
        db_model.workflow_run_id = "run-id"
        db_model.index = 1
        db_model.predecessor_node_id = None
        db_model.node_id = "node-id"
        db_model.node_type = NodeType.LLM
        db_model.title = "Test Node"
        # JSON-serialized columns, as they would be stored in the database.
        db_model.inputs = json.dumps({"value": "inputs"})
        db_model.process_data = json.dumps({"value": "process_data"})
        db_model.outputs = json.dumps({"value": "outputs"})
        db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
        db_model.error = None
        db_model.elapsed_time = 1.0
        db_model.execution_metadata = "{}"
        db_model.created_at = datetime.now(UTC)
        db_model.finished_at = None
        # An empty offload list means nothing was truncated/offloaded.
        db_model.offload_data = []
        domain_model = repo._to_domain_model(db_model)
        # Check that no truncated data was set
        assert domain_model.get_truncated_inputs() is None
        assert domain_model.get_truncated_outputs() is None
class TestWorkflowNodeExecutionModelTruncatedProperties:
    """Test the truncated properties on WorkflowNodeExecutionModel."""

    @staticmethod
    def _model_with_offloads(*types: ExecutionOffLoadType) -> WorkflowNodeExecutionModel:
        """Return a model whose offload_data holds one entry per given type."""
        model = WorkflowNodeExecutionModel()
        model.offload_data = [WorkflowNodeExecutionOffload(type_=t) for t in types]
        return model

    def test_inputs_truncated_with_offload_data(self):
        """inputs_truncated is True when an INPUTS offload entry exists."""
        model = self._model_with_offloads(ExecutionOffLoadType.INPUTS)
        assert model.inputs_truncated is True
        assert model.process_data_truncated is False
        assert model.outputs_truncated is False

    def test_outputs_truncated_with_offload_data(self):
        """outputs_truncated is True when an OUTPUTS offload entry exists."""
        model = self._model_with_offloads(ExecutionOffLoadType.OUTPUTS)
        assert model.inputs_truncated is False
        assert model.process_data_truncated is False
        assert model.outputs_truncated is True

    def test_process_data_truncated_with_offload_data(self):
        """process_data_truncated is True when a PROCESS_DATA offload entry exists."""
        model = self._model_with_offloads(ExecutionOffLoadType.PROCESS_DATA)
        assert model.process_data_truncated is True
        assert model.inputs_truncated is False
        assert model.outputs_truncated is False

    def test_truncated_properties_without_offload_data(self):
        """All flags stay False when offload_data is an empty list."""
        model = self._model_with_offloads()
        assert model.inputs_truncated is False
        assert model.outputs_truncated is False
        assert model.process_data_truncated is False

    def test_truncated_properties_without_offload_attribute(self):
        """All flags stay False when offload_data was never assigned."""
        # Deliberately leave offload_data unset on a fresh model.
        model = WorkflowNodeExecutionModel()
        assert model.inputs_truncated is False
        assert model.outputs_truncated is False
        assert model.process_data_truncated is False

View File

@@ -0,0 +1 @@
# Core schemas unit tests

View File

@@ -0,0 +1,769 @@
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock, patch
import pytest
from core.schemas import resolve_dify_schema_refs
from core.schemas.registry import SchemaRegistry
from core.schemas.resolver import (
MaxDepthExceededError,
SchemaResolver,
_has_dify_refs,
_has_dify_refs_hybrid,
_has_dify_refs_recursive,
_is_dify_schema_ref,
_remove_metadata_fields,
parse_dify_schema_uri,
)
class TestSchemaResolver:
    """Test cases for schema reference resolution"""

    def setup_method(self):
        """Setup method to initialize test resources"""
        self.registry = SchemaRegistry.default_registry()
        # Clear cache before each test
        SchemaResolver.clear_cache()

    def teardown_method(self):
        """Cleanup after each test"""
        SchemaResolver.clear_cache()

    def test_simple_ref_resolution(self):
        """Test resolving a simple $ref to a complete schema"""
        schema_with_ref = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
        resolved = resolve_dify_schema_refs(schema_with_ref)
        # Should be resolved to the actual qa_structure schema
        assert resolved["type"] == "object"
        assert resolved["title"] == "Q&A Structure"
        assert "qa_chunks" in resolved["properties"]
        assert resolved["properties"]["qa_chunks"]["type"] == "array"
        # Metadata fields should be removed
        assert "$id" not in resolved
        assert "$schema" not in resolved
        assert "version" not in resolved

    def test_nested_object_with_refs(self):
        """Test resolving $refs within nested object structures"""
        nested_schema = {
            "type": "object",
            "properties": {
                "file_data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
                "metadata": {"type": "string", "description": "Additional metadata"},
            },
        }
        resolved = resolve_dify_schema_refs(nested_schema)
        # Original structure should be preserved
        assert resolved["type"] == "object"
        assert "metadata" in resolved["properties"]
        assert resolved["properties"]["metadata"]["type"] == "string"
        # $ref should be resolved
        file_schema = resolved["properties"]["file_data"]
        assert file_schema["type"] == "object"
        assert file_schema["title"] == "File"
        assert "name" in file_schema["properties"]
        # Metadata fields should be removed from resolved schema
        assert "$id" not in file_schema
        assert "$schema" not in file_schema
        assert "version" not in file_schema

    def test_array_items_ref_resolution(self):
        """Test resolving $refs in array items"""
        array_schema = {
            "type": "array",
            "items": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"},
            "description": "Array of general structures",
        }
        resolved = resolve_dify_schema_refs(array_schema)
        # Array structure should be preserved
        assert resolved["type"] == "array"
        assert resolved["description"] == "Array of general structures"
        # Items $ref should be resolved
        items_schema = resolved["items"]
        assert items_schema["type"] == "array"
        assert items_schema["title"] == "General Structure"

    def test_non_dify_ref_unchanged(self):
        """Test that non-Dify $refs are left unchanged"""
        external_ref_schema = {
            "type": "object",
            "properties": {
                "external_data": {"$ref": "https://example.com/external-schema.json"},
                "dify_data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
            },
        }
        resolved = resolve_dify_schema_refs(external_ref_schema)
        # External $ref should remain unchanged
        assert resolved["properties"]["external_data"]["$ref"] == "https://example.com/external-schema.json"
        # Dify $ref should be resolved
        assert resolved["properties"]["dify_data"]["type"] == "object"
        assert resolved["properties"]["dify_data"]["title"] == "File"

    def test_no_refs_schema_unchanged(self):
        """Test that schemas without $refs are returned unchanged"""
        simple_schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name field"},
                "items": {"type": "array", "items": {"type": "number"}},
            },
            "required": ["name"],
        }
        resolved = resolve_dify_schema_refs(simple_schema)
        # Should be identical to input
        assert resolved == simple_schema
        assert resolved["type"] == "object"
        assert resolved["properties"]["name"]["type"] == "string"
        assert resolved["properties"]["items"]["items"]["type"] == "number"
        assert resolved["required"] == ["name"]

    def test_recursion_depth_protection(self):
        """Test that excessive recursion depth is prevented"""
        # Create a moderately nested structure
        deep_schema = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}
        # Wrap it in fewer layers to make the test more reasonable
        for _ in range(2):
            deep_schema = {"type": "object", "properties": {"nested": deep_schema}}
        # Should handle normal cases fine with reasonable depth
        resolved = resolve_dify_schema_refs(deep_schema, max_depth=25)
        assert resolved is not None
        assert resolved["type"] == "object"
        # Should raise error with very low max_depth
        with pytest.raises(MaxDepthExceededError) as exc_info:
            resolve_dify_schema_refs(deep_schema, max_depth=5)
        # The raised error reports the limit that was exceeded.
        assert exc_info.value.max_depth == 5

    def test_circular_reference_detection(self):
        """Test that circular references are detected and handled"""
        # Mock registry with circular reference
        # Every lookup returns a schema that points back at the same URI.
        mock_registry = MagicMock()
        mock_registry.get_schema.side_effect = lambda uri: {
            "$ref": "https://dify.ai/schemas/v1/circular.json",
            "type": "object",
        }
        schema = {"$ref": "https://dify.ai/schemas/v1/circular.json"}
        resolved = resolve_dify_schema_refs(schema, registry=mock_registry)
        # Should mark circular reference
        assert "$circular_ref" in resolved

    def test_schema_not_found_handling(self):
        """Test handling of missing schemas"""
        # Mock registry that returns None for unknown schemas
        mock_registry = MagicMock()
        mock_registry.get_schema.return_value = None
        schema = {"$ref": "https://dify.ai/schemas/v1/unknown.json"}
        resolved = resolve_dify_schema_refs(schema, registry=mock_registry)
        # Should keep the original $ref when schema not found
        assert resolved["$ref"] == "https://dify.ai/schemas/v1/unknown.json"

    def test_primitive_types_unchanged(self):
        """Test that primitive types are returned unchanged"""
        assert resolve_dify_schema_refs("string") == "string"
        assert resolve_dify_schema_refs(123) == 123
        assert resolve_dify_schema_refs(True) is True
        assert resolve_dify_schema_refs(None) is None
        assert resolve_dify_schema_refs(3.14) == 3.14

    def test_cache_functionality(self):
        """Test that caching works correctly"""
        schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}
        # First resolution should fetch from registry
        resolved1 = resolve_dify_schema_refs(schema)
        # Mock the registry to return different data
        with patch.object(self.registry, "get_schema") as mock_get:
            mock_get.return_value = {"type": "different"}
            # Second resolution should use cache
            resolved2 = resolve_dify_schema_refs(schema)
            # Should be the same as first resolution (from cache)
            assert resolved1 == resolved2
            # Mock should not have been called
            mock_get.assert_not_called()
        # Clear cache and try again
        SchemaResolver.clear_cache()
        # Now it should fetch again
        resolved3 = resolve_dify_schema_refs(schema)
        assert resolved3 == resolved1

    def test_thread_safety(self):
        """Test that the resolver is thread-safe"""
        schema = {
            "type": "object",
            "properties": {f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} for i in range(10)},
        }
        results = []

        def resolve_in_thread():
            # Append either a resolved schema or the raised exception;
            # the boolean return feeds the overall success check below.
            try:
                result = resolve_dify_schema_refs(schema)
                results.append(result)
                return True
            except Exception as e:
                results.append(e)
                return False

        # Run multiple threads concurrently
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(resolve_in_thread) for _ in range(20)]
            success = all(f.result() for f in futures)
        assert success
        # All results should be the same
        first_result = results[0]
        assert all(r == first_result for r in results if not isinstance(r, Exception))

    def test_mixed_nested_structures(self):
        """Test resolving refs in complex mixed structures"""
        complex_schema = {
            "type": "object",
            "properties": {
                "files": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}},
                "nested": {
                    "type": "object",
                    "properties": {
                        "qa": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"},
                        "data": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "general": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"}
                                },
                            },
                        },
                    },
                },
            },
        }
        resolved = resolve_dify_schema_refs(complex_schema, max_depth=20)
        # Check structure is preserved
        assert resolved["type"] == "object"
        assert "files" in resolved["properties"]
        assert "nested" in resolved["properties"]
        # Check refs are resolved
        assert resolved["properties"]["files"]["items"]["type"] == "object"
        assert resolved["properties"]["files"]["items"]["title"] == "File"
        assert resolved["properties"]["nested"]["properties"]["qa"]["type"] == "object"
        assert resolved["properties"]["nested"]["properties"]["qa"]["title"] == "Q&A Structure"
class TestUtilityFunctions:
    """Tests for the module-level schema-resolver helpers.

    Covers `_is_dify_schema_ref`, `_has_dify_refs` (plus its hybrid and
    recursive variants), `parse_dify_schema_uri`, and `_remove_metadata_fields`.
    """

    def test_is_dify_schema_ref(self):
        """Test _is_dify_schema_ref function"""
        # Valid Dify refs: https://dify.ai/schemas/v<digits>/<name>.json
        assert _is_dify_schema_ref("https://dify.ai/schemas/v1/file.json")
        assert _is_dify_schema_ref("https://dify.ai/schemas/v2/complex_name.json")
        assert _is_dify_schema_ref("https://dify.ai/schemas/v999/test-file.json")
        # Invalid refs: wrong host/path, malformed URI, or empty string
        assert not _is_dify_schema_ref("https://example.com/schema.json")
        assert not _is_dify_schema_ref("https://dify.ai/other/path.json")
        assert not _is_dify_schema_ref("not a uri")
        assert not _is_dify_schema_ref("")
        # Non-string inputs must be rejected gracefully rather than raising
        assert not _is_dify_schema_ref(None)
        assert not _is_dify_schema_ref(123)
        assert not _is_dify_schema_ref(["list"])

    def test_has_dify_refs(self):
        """Test _has_dify_refs function"""
        # Schemas with Dify refs — at the root, nested in properties,
        # inside list items, and nested inside array item schemas.
        assert _has_dify_refs({"$ref": "https://dify.ai/schemas/v1/file.json"})
        assert _has_dify_refs(
            {"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}}
        )
        assert _has_dify_refs([{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/file.json"}])
        assert _has_dify_refs(
            {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {"nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}},
                },
            }
        )
        # Schemas without Dify refs
        assert not _has_dify_refs({"type": "string"})
        assert not _has_dify_refs(
            {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}}
        )
        assert not _has_dify_refs(
            [{"type": "string"}, {"type": "number"}, {"type": "object", "properties": {"name": {"type": "string"}}}]
        )
        # Schemas with non-Dify refs (should return False)
        assert not _has_dify_refs({"$ref": "https://example.com/schema.json"})
        assert not _has_dify_refs(
            {"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}}
        )
        # Primitive types can never contain a ref
        assert not _has_dify_refs("string")
        assert not _has_dify_refs(123)
        assert not _has_dify_refs(True)
        assert not _has_dify_refs(None)

    def test_has_dify_refs_hybrid_vs_recursive(self):
        """Test that hybrid and recursive detection give same results"""
        test_schemas = [
            # No refs
            {"type": "string"},
            {"type": "object", "properties": {"name": {"type": "string"}}},
            [{"type": "string"}, {"type": "number"}],
            # With Dify refs
            {"$ref": "https://dify.ai/schemas/v1/file.json"},
            {"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}},
            [{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}],
            # With non-Dify refs
            {"$ref": "https://example.com/schema.json"},
            {"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}},
            # Complex nested
            {
                "type": "object",
                "properties": {
                    "level1": {
                        "type": "object",
                        "properties": {
                            "level2": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}}
                        },
                    }
                },
            },
            # Edge cases: "$ref" appearing only as prose, and a non-URL ref value
            {"description": "This mentions $ref but is not a reference"},
            {"$ref": "not-a-url"},
            # Primitive types
            "string",
            123,
            True,
            None,
            [],
        ]
        # Both implementations must agree on every schema shape above.
        for schema in test_schemas:
            hybrid_result = _has_dify_refs_hybrid(schema)
            recursive_result = _has_dify_refs_recursive(schema)
            assert hybrid_result == recursive_result, f"Mismatch for schema: {schema}"

    def test_parse_dify_schema_uri(self):
        """Test parse_dify_schema_uri function"""
        # Valid URIs: returns the (version, schema_name) pair parsed from the path
        assert parse_dify_schema_uri("https://dify.ai/schemas/v1/file.json") == ("v1", "file")
        assert parse_dify_schema_uri("https://dify.ai/schemas/v2/complex_name.json") == ("v2", "complex_name")
        assert parse_dify_schema_uri("https://dify.ai/schemas/v999/test-file.json") == ("v999", "test-file")
        # Invalid URIs fall back to a pair of empty strings rather than raising
        assert parse_dify_schema_uri("https://example.com/schema.json") == ("", "")
        assert parse_dify_schema_uri("invalid") == ("", "")
        assert parse_dify_schema_uri("") == ("", "")

    def test_remove_metadata_fields(self):
        """Test _remove_metadata_fields function"""
        schema = {
            "$id": "should be removed",
            "$schema": "should be removed",
            "version": "should be removed",
            "type": "object",
            "title": "should remain",
            "properties": {},
        }
        cleaned = _remove_metadata_fields(schema)
        # Metadata keys are stripped from the returned copy...
        assert "$id" not in cleaned
        assert "$schema" not in cleaned
        assert "version" not in cleaned
        # ...while ordinary schema keys survive.
        assert cleaned["type"] == "object"
        assert cleaned["title"] == "should remain"
        assert "properties" in cleaned
        # Original should be unchanged (the helper must not mutate its input)
        assert "$id" in schema
class TestSchemaResolverClass:
    """Tests for the SchemaResolver class: construction, shared caching,
    list-root resolution, and several timing-based performance checks.

    NOTE: the performance tests below measure wall-clock time with
    time.perf_counter, so their statement ordering (cache clears, warm-up
    loops) is load-bearing — do not reorder.
    """

    def test_resolver_initialization(self):
        """Test resolver initialization"""
        # Default initialization
        resolver = SchemaResolver()
        assert resolver.max_depth == 10
        assert resolver.registry is not None
        # Custom initialization: both the registry and max_depth are injectable
        custom_registry = MagicMock()
        resolver = SchemaResolver(registry=custom_registry, max_depth=5)
        assert resolver.max_depth == 5
        assert resolver.registry is custom_registry

    def test_cache_sharing(self):
        """Test that cache is shared between resolver instances"""
        SchemaResolver.clear_cache()
        schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}
        # First resolver populates cache
        resolver1 = SchemaResolver()
        result1 = resolver1.resolve(schema)
        # Second resolver should use the same (class-level) cache
        resolver2 = SchemaResolver()
        with patch.object(resolver2.registry, "get_schema") as mock_get:
            result2 = resolver2.resolve(schema)
            # Should not call registry since it's in cache
            mock_get.assert_not_called()
        assert result1 == result2

    def test_resolver_with_list_schema(self):
        """Test resolver with list as root schema"""
        list_schema = [
            {"$ref": "https://dify.ai/schemas/v1/file.json"},
            {"type": "string"},
            {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"},
        ]
        resolver = SchemaResolver()
        resolved = resolver.resolve(list_schema)
        assert isinstance(resolved, list)
        assert len(resolved) == 3
        # Dify refs are expanded in place; the plain schema passes through.
        assert resolved[0]["type"] == "object"
        assert resolved[0]["title"] == "File"
        assert resolved[1] == {"type": "string"}
        assert resolved[2]["type"] == "object"
        assert resolved[2]["title"] == "Q&A Structure"

    def test_cache_performance(self):
        """Test that caching improves performance"""
        SchemaResolver.clear_cache()
        # Create a schema with many references to the same schema
        schema = {
            "type": "object",
            "properties": {
                f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
                for i in range(50)  # Reduced to avoid depth issues
            },
        }
        # First run (no cache) - run multiple times to warm up
        results1 = []
        for _ in range(3):
            SchemaResolver.clear_cache()
            start = time.perf_counter()
            result1 = resolve_dify_schema_refs(schema)
            time_no_cache = time.perf_counter() - start
            results1.append(time_no_cache)
        avg_time_no_cache = sum(results1) / len(results1)
        # Second run (with cache) - run multiple times
        results2 = []
        for _ in range(3):
            start = time.perf_counter()
            result2 = resolve_dify_schema_refs(schema)
            time_with_cache = time.perf_counter() - start
            results2.append(time_with_cache)
        avg_time_with_cache = sum(results2) / len(results2)
        # Cache should make it faster (more lenient check)
        assert result1 == result2
        # Cache should provide some performance benefit (allow for measurement variance)
        # We expect cache to be faster, but allow for small timing variations
        performance_ratio = avg_time_with_cache / avg_time_no_cache if avg_time_no_cache > 0 else 1.0
        assert performance_ratio <= 2.0, f"Cache performance degraded too much: {performance_ratio}"

    def test_fast_path_performance_no_refs(self):
        """Test that schemas without $refs use fast path and avoid deep copying"""
        # Create a moderately complex schema without any $refs (typical plugin output_schema)
        no_refs_schema = {
            "type": "object",
            "properties": {
                f"property_{i}": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "value": {"type": "number"},
                        "items": {"type": "array", "items": {"type": "string"}},
                    },
                }
                for i in range(50)
            },
        }
        # Measure fast path (no refs) performance
        fast_times = []
        for _ in range(10):
            start = time.perf_counter()
            result_fast = resolve_dify_schema_refs(no_refs_schema)
            elapsed = time.perf_counter() - start
            fast_times.append(elapsed)
        avg_fast_time = sum(fast_times) / len(fast_times)
        # Most importantly: result should be identical to input (no copying)
        assert result_fast is no_refs_schema
        # Create schema with $refs for comparison (same structure size)
        with_refs_schema = {
            "type": "object",
            "properties": {
                f"property_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
                for i in range(20)  # Fewer to avoid depth issues but still comparable
            },
        }
        # Measure slow path (with refs) performance
        SchemaResolver.clear_cache()
        slow_times = []
        for _ in range(10):
            SchemaResolver.clear_cache()
            start = time.perf_counter()
            result_slow = resolve_dify_schema_refs(with_refs_schema, max_depth=50)
            elapsed = time.perf_counter() - start
            slow_times.append(elapsed)
        avg_slow_time = sum(slow_times) / len(slow_times)
        # The key benefit: fast path should be reasonably fast (main goal is no deep copy)
        # and definitely avoid the expensive BFS resolution
        # Even if detection has some overhead, it should still be faster for typical cases
        print(f"Fast path (no refs): {avg_fast_time:.6f}s")
        print(f"Slow path (with refs): {avg_slow_time:.6f}s")
        # More lenient check: fast path should be at least somewhat competitive
        # The main benefit is avoiding deep copy and BFS, not necessarily being 5x faster
        assert avg_fast_time < avg_slow_time * 2  # Should not be more than 2x slower

    def test_batch_processing_performance(self):
        """Test performance improvement for batch processing of schemas without refs"""
        # Simulate the plugin tool scenario: many schemas, most without refs
        schemas_without_refs = [
            {
                "type": "object",
                "properties": {f"field_{j}": {"type": "string" if j % 2 else "number"} for j in range(10)},
            }
            for i in range(100)
        ]
        # Test batch processing performance
        start = time.perf_counter()
        results = [resolve_dify_schema_refs(schema) for schema in schemas_without_refs]
        batch_time = time.perf_counter() - start
        # Verify all results are identical to inputs (fast path used)
        for original, result in zip(schemas_without_refs, results):
            assert result is original
        # Should be very fast - each schema should take < 0.001 seconds on average
        avg_time_per_schema = batch_time / len(schemas_without_refs)
        assert avg_time_per_schema < 0.001

    def test_has_dify_refs_performance(self):
        """Test that _has_dify_refs is fast for large schemas without refs"""
        # Create a very large schema without refs
        large_schema = {"type": "object", "properties": {}}
        # Add many nested properties (a 100-level-deep chain of objects)
        current = large_schema
        for i in range(100):
            current["properties"][f"level_{i}"] = {"type": "object", "properties": {}}
            current = current["properties"][f"level_{i}"]
        # _has_dify_refs should be fast even for large schemas
        times = []
        for _ in range(50):
            start = time.perf_counter()
            has_refs = _has_dify_refs(large_schema)
            elapsed = time.perf_counter() - start
            times.append(elapsed)
        avg_time = sum(times) / len(times)
        # Should be False and fast
        assert not has_refs
        assert avg_time < 0.01  # Should complete in less than 10ms

    def test_hybrid_vs_recursive_performance(self):
        """Test performance comparison between hybrid and recursive detection"""
        # Create test schemas of different types and sizes
        test_cases = [
            # Case 1: Small schema without refs (most common case)
            {
                "name": "small_no_refs",
                "schema": {"type": "object", "properties": {"name": {"type": "string"}, "value": {"type": "number"}}},
                "expected": False,
            },
            # Case 2: Medium schema without refs
            {
                "name": "medium_no_refs",
                "schema": {
                    "type": "object",
                    "properties": {
                        f"field_{i}": {
                            "type": "object",
                            "properties": {
                                "name": {"type": "string"},
                                "value": {"type": "number"},
                                "items": {"type": "array", "items": {"type": "string"}},
                            },
                        }
                        for i in range(20)
                    },
                },
                "expected": False,
            },
            # Case 3: Large schema without refs (deep nesting is appended below)
            {"name": "large_no_refs", "schema": {"type": "object", "properties": {}}, "expected": False},
            # Case 4: Schema with Dify refs
            {
                "name": "with_dify_refs",
                "schema": {
                    "type": "object",
                    "properties": {
                        "file": {"$ref": "https://dify.ai/schemas/v1/file.json"},
                        "data": {"type": "string"},
                    },
                },
                "expected": True,
            },
            # Case 5: Schema with non-Dify refs
            {
                "name": "with_external_refs",
                "schema": {
                    "type": "object",
                    "properties": {"external": {"$ref": "https://example.com/schema.json"}, "data": {"type": "string"}},
                },
                "expected": False,
            },
        ]
        # Add deep nesting to large schema
        current = test_cases[2]["schema"]
        for i in range(50):
            current["properties"][f"level_{i}"] = {"type": "object", "properties": {}}
            current = current["properties"][f"level_{i}"]
        # Performance comparison
        for test_case in test_cases:
            schema = test_case["schema"]
            expected = test_case["expected"]
            name = test_case["name"]
            # Test correctness first
            assert _has_dify_refs_hybrid(schema) == expected
            assert _has_dify_refs_recursive(schema) == expected
            # Measure hybrid performance
            hybrid_times = []
            for _ in range(10):
                start = time.perf_counter()
                result_hybrid = _has_dify_refs_hybrid(schema)
                elapsed = time.perf_counter() - start
                hybrid_times.append(elapsed)
            # Measure recursive performance
            recursive_times = []
            for _ in range(10):
                start = time.perf_counter()
                result_recursive = _has_dify_refs_recursive(schema)
                elapsed = time.perf_counter() - start
                recursive_times.append(elapsed)
            avg_hybrid = sum(hybrid_times) / len(hybrid_times)
            avg_recursive = sum(recursive_times) / len(recursive_times)
            print(f"{name}: hybrid={avg_hybrid:.6f}s, recursive={avg_recursive:.6f}s")
            # Results should be identical
            assert result_hybrid == result_recursive == expected
            # For schemas without refs, hybrid should be competitive or better
            if not expected:  # No refs case
                # Hybrid might be slightly slower due to JSON serialization overhead,
                # but should not be dramatically worse
                assert avg_hybrid < avg_recursive * 5  # At most 5x slower

    def test_string_matching_edge_cases(self):
        """Test edge cases for string-based detection"""
        # Case 1: False positive potential - $ref in description
        schema_false_positive = {
            "type": "object",
            "properties": {
                "description": {"type": "string", "description": "This field explains how $ref works in JSON Schema"}
            },
        }
        # Both methods should return False
        assert not _has_dify_refs_hybrid(schema_false_positive)
        assert not _has_dify_refs_recursive(schema_false_positive)
        # Case 2: Complex URL patterns — a dify.ai URL as a plain value must not
        # trigger detection, but a genuine $ref next to it must.
        complex_schema = {
            "type": "object",
            "properties": {
                "config": {
                    "type": "object",
                    "properties": {
                        "dify_url": {"type": "string", "default": "https://dify.ai/schemas/info"},
                        "actual_ref": {"$ref": "https://dify.ai/schemas/v1/file.json"},
                    },
                }
            },
        }
        # Both methods should return True (due to actual_ref)
        assert _has_dify_refs_hybrid(complex_schema)
        assert _has_dify_refs_recursive(complex_schema)
        # Case 3: Non-JSON serializable objects (should fall back to recursive)
        import datetime

        non_serializable = {
            "type": "object",
            "timestamp": datetime.datetime.now(),
            "data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
        }
        # Hybrid should fall back to recursive and still work
        assert _has_dify_refs_hybrid(non_serializable)
        assert _has_dify_refs_recursive(non_serializable)

View File

@@ -0,0 +1,56 @@
import json
from core.file import File, FileTransferMethod, FileType, FileUploadConfig
from models.workflow import Workflow
def test_file_to_dict():
    """File.to_dict must expose a public `url` while hiding the private storage key."""
    image_file = File(
        id="file1",
        tenant_id="tenant1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/image1.jpg",
        storage_key="storage_key",
    )
    serialized = image_file.to_dict()
    # Public URL is present; the internal storage key must never leak out.
    assert "url" in serialized
    assert "_storage_key" not in serialized
def test_workflow_features_with_image():
    """Legacy image-only `file_upload` features are converted to the new
    FileUploadConfig shape when read back through `Workflow.features`."""
    # Old-style feature payload: image-specific upload settings.
    legacy_features = json.dumps(
        {
            "file_upload": {
                "image": {"enabled": True, "number_limits": 5, "transfer_methods": ["remote_url", "local_file"]}
            }
        }
    )
    workflow = Workflow(
        tenant_id="tenant-1",
        app_id="app-1",
        type="chat",
        version="1.0",
        graph="{}",
        features=legacy_features,
        created_by="user-1",
        environment_variables=[],
        conversation_variables=[],
    )
    # Reading the `features` property yields the converted structure.
    converted = json.loads(workflow.features)
    upload_config = FileUploadConfig.model_validate(converted["file_upload"])
    # The image limits/methods map onto the generalized config fields.
    assert upload_config.number_limits == 5
    assert list(upload_config.allowed_file_types) == [FileType.IMAGE]
    assert list(upload_config.allowed_file_upload_methods) == [
        FileTransferMethod.REMOTE_URL,
        FileTransferMethod.LOCAL_FILE,
    ]
    assert list(upload_config.allowed_file_extensions) == []

View File

@@ -0,0 +1,73 @@
from unittest.mock import MagicMock, patch
import pytest
import redis
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_redis import redis_client
@pytest.fixture
def lb_model_manager():
    """LBModelManager with three load-balancing configs.

    Cooldown bookkeeping is stubbed out: only the config with id "id1"
    reports as being in cooldown.
    """
    balancing_configs = [
        ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}),
        ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}),
        ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}),
    ]
    manager = LBModelManager(
        tenant_id="tenant_id",
        provider="openai",
        model_type=ModelType.LLM,
        model="gpt-4",
        load_balancing_configs=balancing_configs,
        managed_credentials={"openai_api_key": "fake_key"},
    )
    # Replace the Redis-backed cooldown machinery with deterministic stubs.
    manager.cooldown = MagicMock(return_value=None)
    manager.in_cooldown = MagicMock(side_effect=lambda config: config.id == "id1")
    return manager
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
    """fetch_next must round-robin over the configs, skipping any in cooldown."""
    # initialize redis client
    redis_client.initialize(redis.Redis())

    all_configs = lb_model_manager._load_balancing_configs
    assert len(all_configs) == 3
    first, second, third = all_configs

    # Per the fixture, only the first config ("id1") is cooling down.
    assert lb_model_manager.in_cooldown(first) is True
    assert lb_model_manager.in_cooldown(second) is False
    assert lb_model_manager.in_cooldown(third) is False

    counter = 0

    def fake_incr(key):
        # Simulates Redis INCR: a monotonically increasing round-robin index.
        nonlocal counter
        counter += 1
        return counter

    with (
        patch.object(redis_client, "incr", side_effect=fake_incr),
        patch.object(redis_client, "set", return_value=None),
        patch.object(redis_client, "expire", return_value=None),
    ):
        # The cooled-down first config is skipped on both fetches.
        assert lb_model_manager.fetch_next() == second
        assert lb_model_manager.fetch_next() == third

View File

@@ -0,0 +1,485 @@
from unittest.mock import Mock, patch
import pytest
from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus
from core.entities.provider_entities import (
CustomConfiguration,
ModelSettings,
ProviderQuotaType,
QuotaConfiguration,
QuotaUnit,
RestrictModel,
SystemConfiguration,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormOption,
FormType,
ProviderEntity,
)
from models.provider import Provider, ProviderType
@pytest.fixture
def mock_provider_entity():
    """Minimal OpenAI ProviderEntity: LLM-only, predefined models, no credential schemas."""
    return ProviderEntity(
        provider="openai",
        label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
        description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"),
        icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
        icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
        background="background.png",
        help=None,
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
        provider_credential_schema=None,
        model_credential_schema=None,
    )
@pytest.fixture
def mock_system_configuration():
    """System configuration with a single valid TRIAL quota that restricts gpt-4."""
    trial_quota = QuotaConfiguration(
        quota_type=ProviderQuotaType.TRIAL,
        quota_unit=QuotaUnit.TOKENS,
        quota_limit=1000,
        quota_used=0,
        is_valid=True,
        restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)],
    )
    return SystemConfiguration(
        enabled=True,
        credentials={"openai_api_key": "test_key"},
        quota_configurations=[trial_quota],
        current_quota_type=ProviderQuotaType.TRIAL,
    )
@pytest.fixture
def mock_custom_configuration():
    """Empty custom configuration: no provider credentials, no model overrides."""
    return CustomConfiguration(provider=None, models=[])
@pytest.fixture
def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration):
    """ProviderConfiguration under test, pinned to the SYSTEM provider type."""
    # Reset the module-level cache of original configurate methods so each
    # test builds the configuration from a clean slate.
    with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
        return ProviderConfiguration(
            tenant_id="test_tenant",
            provider=mock_provider_entity,
            preferred_provider_type=ProviderType.SYSTEM,
            using_provider_type=ProviderType.SYSTEM,
            system_configuration=mock_system_configuration,
            custom_configuration=mock_custom_configuration,
            model_settings=[],
        )
class TestProviderConfiguration:
    """Test cases for ProviderConfiguration class"""

    def test_get_current_credentials_system_provider_success(self, provider_configuration):
        """Test successfully getting credentials from system provider"""
        # Arrange
        provider_configuration.using_provider_type = ProviderType.SYSTEM
        # Act
        credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
        # Assert: credentials come from the system configuration fixture
        assert credentials == {"openai_api_key": "test_key"}

    def test_get_current_credentials_model_disabled(self, provider_configuration):
        """Test getting credentials when model is disabled"""
        # Arrange
        model_setting = ModelSettings(
            model="gpt-4",
            model_type=ModelType.LLM,
            enabled=False,
            load_balancing_configs=[],
            has_invalid_load_balancing_configs=False,
        )
        provider_configuration.model_settings = [model_setting]
        # Act & Assert
        with pytest.raises(ValueError, match="Model gpt-4 is disabled"):
            provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")

    def test_get_current_credentials_custom_provider_with_models(self, provider_configuration):
        """Test getting credentials from custom provider with model configurations"""
        # Arrange
        provider_configuration.using_provider_type = ProviderType.CUSTOM
        mock_model_config = Mock()
        mock_model_config.model_type = ModelType.LLM
        mock_model_config.model = "gpt-4"
        mock_model_config.credentials = {"openai_api_key": "custom_key"}
        provider_configuration.custom_configuration.models = [mock_model_config]
        # Act
        credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
        # Assert: the per-model custom credentials win over the system ones
        assert credentials == {"openai_api_key": "custom_key"}

    def test_get_system_configuration_status_active(self, provider_configuration):
        """Test getting active system configuration status"""
        # Arrange
        provider_configuration.system_configuration.enabled = True
        # Act
        status = provider_configuration.get_system_configuration_status()
        # Assert
        assert status == SystemConfigurationStatus.ACTIVE

    def test_get_system_configuration_status_unsupported(self, provider_configuration):
        """Test getting unsupported system configuration status"""
        # Arrange
        provider_configuration.system_configuration.enabled = False
        # Act
        status = provider_configuration.get_system_configuration_status()
        # Assert
        assert status == SystemConfigurationStatus.UNSUPPORTED

    def test_get_system_configuration_status_quota_exceeded(self, provider_configuration):
        """Test getting quota exceeded system configuration status"""
        # Arrange: an invalid quota configuration marks the quota as exceeded
        provider_configuration.system_configuration.enabled = True
        quota_config = provider_configuration.system_configuration.quota_configurations[0]
        quota_config.is_valid = False
        # Act
        status = provider_configuration.get_system_configuration_status()
        # Assert
        assert status == SystemConfigurationStatus.QUOTA_EXCEEDED

    def test_is_custom_configuration_available_with_provider(self, provider_configuration):
        """Test custom configuration availability with provider credentials"""
        # Arrange
        mock_provider = Mock()
        mock_provider.available_credentials = ["openai_api_key"]
        provider_configuration.custom_configuration.provider = mock_provider
        provider_configuration.custom_configuration.models = []
        # Act
        result = provider_configuration.is_custom_configuration_available()
        # Assert
        assert result is True

    def test_is_custom_configuration_available_with_models(self, provider_configuration):
        """Test custom configuration availability with model configurations"""
        # Arrange: no provider-level credentials, but at least one model config
        provider_configuration.custom_configuration.provider = None
        provider_configuration.custom_configuration.models = [Mock()]
        # Act
        result = provider_configuration.is_custom_configuration_available()
        # Assert
        assert result is True

    def test_is_custom_configuration_available_false(self, provider_configuration):
        """Test custom configuration not available"""
        # Arrange: neither provider credentials nor model configs
        provider_configuration.custom_configuration.provider = None
        provider_configuration.custom_configuration.models = []
        # Act
        result = provider_configuration.is_custom_configuration_available()
        # Assert
        assert result is False

    @patch("core.entities.provider_configuration.Session")
    def test_get_provider_record_found(self, mock_session, provider_configuration):
        """Test getting provider record successfully"""
        # Arrange: the mocked session is handed directly to the method,
        # its query result is a mocked Provider row.
        mock_provider = Mock(spec=Provider)
        mock_session_instance = Mock()
        mock_session.return_value.__enter__.return_value = mock_session_instance
        mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_provider
        # Act
        result = provider_configuration._get_provider_record(mock_session_instance)
        # Assert
        assert result == mock_provider

    @patch("core.entities.provider_configuration.Session")
    def test_get_provider_record_not_found(self, mock_session, provider_configuration):
        """Test getting provider record when not found"""
        # Arrange: the query yields no row
        mock_session_instance = Mock()
        mock_session.return_value.__enter__.return_value = mock_session_instance
        mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
        # Act
        result = provider_configuration._get_provider_record(mock_session_instance)
        # Assert
        assert result is None

    def test_init_with_customizable_model_only(
        self, mock_provider_entity, mock_system_configuration, mock_custom_configuration
    ):
        """Test initialization with customizable model only configuration"""
        # Arrange
        mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL]
        # Act
        with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
            config = ProviderConfiguration(
                tenant_id="test_tenant",
                provider=mock_provider_entity,
                preferred_provider_type=ProviderType.SYSTEM,
                using_provider_type=ProviderType.SYSTEM,
                system_configuration=mock_system_configuration,
                custom_configuration=mock_custom_configuration,
                model_settings=[],
            )
        # Assert: PREDEFINED_MODEL is added alongside CUSTOMIZABLE_MODEL
        assert ConfigurateMethod.PREDEFINED_MODEL in config.provider.configurate_methods

    def test_get_current_credentials_with_restricted_models(self, provider_configuration):
        """Test getting credentials with model restrictions"""
        # Arrange: gpt-3.5-turbo is NOT in the fixture's restrict_models list
        provider_configuration.using_provider_type = ProviderType.SYSTEM
        # Act
        credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo")
        # Assert
        assert credentials is not None
        assert "openai_api_key" in credentials

    @patch("core.entities.provider_configuration.Session")
    def test_get_specific_provider_credential_success(self, mock_session, provider_configuration):
        """Test getting specific provider credential successfully"""
        # NOTE(review): this test patches `_get_specific_provider_credential`
        # itself before calling it, so it only asserts the mock's return value —
        # the real implementation (and the Session/credential setup above) is
        # never exercised. Consider removing the patch.object wrapper.
        # Arrange
        credential_id = "test_credential_id"
        mock_credential = Mock()
        mock_credential.encrypted_config = '{"openai_api_key": "encrypted_key"}'
        mock_session_instance = Mock()
        mock_session.return_value.__enter__.return_value = mock_session_instance
        mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_credential
        # Act
        with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
            mock_get.return_value = {"openai_api_key": "test_key"}
            result = provider_configuration._get_specific_provider_credential(credential_id)
        # Assert
        assert result == {"openai_api_key": "test_key"}

    @patch("core.entities.provider_configuration.Session")
    def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration):
        """Test getting specific provider credential when not found"""
        # NOTE(review): as above, the method under test is patched, so only the
        # mock is exercised.
        # Arrange
        credential_id = "nonexistent_credential_id"
        mock_session_instance = Mock()
        mock_session.return_value.__enter__.return_value = mock_session_instance
        mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
        # Act & Assert
        with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
            mock_get.return_value = None
            result = provider_configuration._get_specific_provider_credential(credential_id)
            assert result is None
        # NOTE(review): the following Act/Assert pair looks copy-pasted from the
        # system-provider credentials test and is unrelated to the "not found"
        # scenario above — confirm intent and consider removing it.
        # Act
        credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
        # Assert
        assert credentials == {"openai_api_key": "test_key"}

    def test_extract_secret_variables_with_secret_input(self, provider_configuration):
        """Test extracting secret variables from credential form schemas"""
        # Arrange: two SECRET_INPUT fields and one TEXT_INPUT field
        credential_form_schemas = [
            CredentialFormSchema(
                variable="api_key",
                label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
                type=FormType.SECRET_INPUT,
                required=True,
            ),
            CredentialFormSchema(
                variable="model_name",
                label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
                type=FormType.TEXT_INPUT,
                required=True,
            ),
            CredentialFormSchema(
                variable="secret_token",
                label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
                type=FormType.SECRET_INPUT,
                required=False,
            ),
        ]
        # Act
        secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
        # Assert: only SECRET_INPUT variables are returned
        assert len(secret_variables) == 2
        assert "api_key" in secret_variables
        assert "secret_token" in secret_variables
        assert "model_name" not in secret_variables

    def test_extract_secret_variables_no_secret_input(self, provider_configuration):
        """Test extracting secret variables when no secret input fields exist"""
        # Arrange: TEXT_INPUT and SELECT fields only
        credential_form_schemas = [
            CredentialFormSchema(
                variable="model_name",
                label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
                type=FormType.TEXT_INPUT,
                required=True,
            ),
            CredentialFormSchema(
                variable="temperature",
                label=I18nObject(en_US="Temperature", zh_Hans="温度"),
                type=FormType.SELECT,
                required=True,
                options=[FormOption(label=I18nObject(en_US="0.1", zh_Hans="0.1"), value="0.1")],
            ),
        ]
        # Act
        secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
        # Assert
        assert len(secret_variables) == 0

    def test_extract_secret_variables_empty_list(self, provider_configuration):
        """Test extracting secret variables from empty credential form schemas"""
        # Arrange
        credential_form_schemas = []
        # Act
        secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
        # Assert
        assert len(secret_variables) == 0

    @patch("core.entities.provider_configuration.encrypter")
    def test_obfuscated_credentials_with_secret_variables(self, mock_encrypter, provider_configuration):
        """Test obfuscating credentials with secret variables"""
        # Arrange
        credentials = {
            "api_key": "sk-1234567890abcdef",
            "model_name": "gpt-4",
            "secret_token": "secret_value_123",
            "temperature": "0.7",
        }
        credential_form_schemas = [
            CredentialFormSchema(
                variable="api_key",
                label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
                type=FormType.SECRET_INPUT,
                required=True,
            ),
            CredentialFormSchema(
                variable="model_name",
                label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
                type=FormType.TEXT_INPUT,
                required=True,
            ),
            CredentialFormSchema(
                variable="secret_token",
                label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
                type=FormType.SECRET_INPUT,
                required=False,
            ),
            CredentialFormSchema(
                variable="temperature",
                label=I18nObject(en_US="Temperature", zh_Hans="温度"),
                type=FormType.TEXT_INPUT,
                required=True,
            ),
        ]
        # Stub obfuscation: keep only the last four characters visible.
        mock_encrypter.obfuscated_token.side_effect = lambda x: f"***{x[-4:]}"
        # Act
        obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
        # Assert
        assert obfuscated["api_key"] == "***cdef"
        assert obfuscated["model_name"] == "gpt-4"  # Not obfuscated
        assert obfuscated["secret_token"] == "***_123"
        assert obfuscated["temperature"] == "0.7"  # Not obfuscated
        # Verify encrypter was called for secret fields only
        assert mock_encrypter.obfuscated_token.call_count == 2
        mock_encrypter.obfuscated_token.assert_any_call("sk-1234567890abcdef")
        mock_encrypter.obfuscated_token.assert_any_call("secret_value_123")

    def test_obfuscated_credentials_no_secret_variables(self, provider_configuration):
        """Test obfuscating credentials when no secret variables exist"""
        # Arrange
        credentials = {
            "model_name": "gpt-4",
            "temperature": "0.7",
            "max_tokens": "1000",
        }
        credential_form_schemas = [
            CredentialFormSchema(
                variable="model_name",
                label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
                type=FormType.TEXT_INPUT,
                required=True,
            ),
            CredentialFormSchema(
                variable="temperature",
                label=I18nObject(en_US="Temperature", zh_Hans="温度"),
                type=FormType.TEXT_INPUT,
                required=True,
            ),
            CredentialFormSchema(
                variable="max_tokens",
                label=I18nObject(en_US="Max Tokens", zh_Hans="最大令牌数"),
                type=FormType.TEXT_INPUT,
                required=True,
            ),
        ]
        # Act
        obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
        # Assert
        assert obfuscated == credentials  # No changes expected

    def test_obfuscated_credentials_empty_credentials(self, provider_configuration):
        """Test obfuscating empty credentials"""
        # Arrange
        credentials = {}
        credential_form_schemas = []
        # Act
        obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
        # Assert
        assert obfuscated == {}

View File

@@ -0,0 +1,192 @@
import pytest
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelSettings
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
@pytest.fixture
def mock_provider_entity(mocker: MockerFixture):
    """Mocked provider entity resembling a predefined-model LLM provider ('openai')."""
    mock_entity = mocker.Mock()
    mock_entity.provider = "openai"
    mock_entity.configurate_methods = ["predefined-model"]
    mock_entity.supported_model_types = [ModelType.LLM]
    # Use PropertyMock to ensure credential_form_schemas is iterable
    provider_credential_schema = mocker.Mock()
    type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
    mock_entity.provider_credential_schema = provider_credential_schema
    model_credential_schema = mocker.Mock()
    type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
    mock_entity.model_credential_schema = model_credential_schema
    return mock_entity
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
    """Two LB configs (inherit + named) with load balancing enabled -> both are kept."""
    # Mocking the inputs
    ps = ProviderModelSetting(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        enabled=True,
        load_balancing_enabled=True,
    )
    ps.id = "id"
    provider_model_settings = [ps]
    load_balancing_model_configs = [
        LoadBalancingModelConfig(
            tenant_id="tenant_id",
            provider_name="openai",
            model_name="gpt-4",
            model_type="text-generation",
            name="__inherit__",
            encrypted_config=None,
            enabled=True,
        ),
        LoadBalancingModelConfig(
            tenant_id="tenant_id",
            provider_name="openai",
            model_name="gpt-4",
            model_type="text-generation",
            name="first",
            encrypted_config='{"openai_api_key": "fake_key"}',
            enabled=True,
        ),
    ]
    load_balancing_model_configs[0].id = "id1"
    load_balancing_model_configs[1].id = "id2"
    # Short-circuit credential decryption via the cache so no real crypto runs.
    mocker.patch(
        "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
    )
    provider_manager = ProviderManager()
    # Running the method
    result = provider_manager._to_model_settings(
        provider_entity=mock_provider_entity,
        provider_model_settings=provider_model_settings,
        load_balancing_model_configs=load_balancing_model_configs,
    )
    # Asserting that the result is as expected
    assert len(result) == 1
    assert isinstance(result[0], ModelSettings)
    assert result[0].model == "gpt-4"
    assert result[0].model_type == ModelType.LLM
    assert result[0].enabled is True
    assert len(result[0].load_balancing_configs) == 2
    assert result[0].load_balancing_configs[0].name == "__inherit__"
    assert result[0].load_balancing_configs[1].name == "first"
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
    """A single '__inherit__' LB entry yields no load-balancing configs in the result."""
    # Mocking the inputs
    ps = ProviderModelSetting(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        enabled=True,
        load_balancing_enabled=True,
    )
    ps.id = "id"
    provider_model_settings = [ps]
    load_balancing_model_configs = [
        LoadBalancingModelConfig(
            tenant_id="tenant_id",
            provider_name="openai",
            model_name="gpt-4",
            model_type="text-generation",
            name="__inherit__",
            encrypted_config=None,
            enabled=True,
        )
    ]
    load_balancing_model_configs[0].id = "id1"
    mocker.patch(
        "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
    )
    provider_manager = ProviderManager()
    # Running the method
    result = provider_manager._to_model_settings(
        provider_entity=mock_provider_entity,
        provider_model_settings=provider_model_settings,
        load_balancing_model_configs=load_balancing_model_configs,
    )
    # Asserting that the result is as expected
    assert len(result) == 1
    assert isinstance(result[0], ModelSettings)
    assert result[0].model == "gpt-4"
    assert result[0].model_type == ModelType.LLM
    assert result[0].enabled is True
    # One inherit-only entry is not enough to balance across, so none are returned.
    assert len(result[0].load_balancing_configs) == 0
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
    """With load_balancing_enabled=False, LB configs are ignored entirely."""
    # Mocking the inputs
    ps = ProviderModelSetting(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        enabled=True,
        load_balancing_enabled=False,
    )
    ps.id = "id"
    provider_model_settings = [ps]
    load_balancing_model_configs = [
        LoadBalancingModelConfig(
            tenant_id="tenant_id",
            provider_name="openai",
            model_name="gpt-4",
            model_type="text-generation",
            name="__inherit__",
            encrypted_config=None,
            enabled=True,
        ),
        LoadBalancingModelConfig(
            tenant_id="tenant_id",
            provider_name="openai",
            model_name="gpt-4",
            model_type="text-generation",
            name="first",
            encrypted_config='{"openai_api_key": "fake_key"}',
            enabled=True,
        ),
    ]
    load_balancing_model_configs[0].id = "id1"
    load_balancing_model_configs[1].id = "id2"
    mocker.patch(
        "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
    )
    provider_manager = ProviderManager()
    # Running the method
    result = provider_manager._to_model_settings(
        provider_entity=mock_provider_entity,
        provider_model_settings=provider_model_settings,
        load_balancing_model_configs=load_balancing_model_configs,
    )
    # Asserting that the result is as expected
    assert len(result) == 1
    assert isinstance(result[0], ModelSettings)
    assert result[0].model == "gpt-4"
    assert result[0].model_type == ModelType.LLM
    assert result[0].enabled is True
    assert len(result[0].load_balancing_configs) == 0

View File

@@ -0,0 +1,102 @@
import hashlib
import json
from datetime import UTC, datetime
import pytest
import pytz
from core.trigger.debug import event_selectors
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
class _DummyRedis:
def __init__(self):
self.store: dict[str, str] = {}
def get(self, key: str):
return self.store.get(key)
def setex(self, name: str, time: int, value: str):
self.store[name] = value
def expire(self, name: str, ttl: int):
# Expiration not required for these tests.
pass
def delete(self, name: str):
self.store.pop(name, None)
@pytest.fixture
def dummy_schedule_config() -> ScheduleConfig:
    """A minimal every-minute schedule in the Asia/Shanghai timezone."""
    config = ScheduleConfig(
        node_id="node-1",
        cron_expression="* * * * *",
        timezone="Asia/Shanghai",
    )
    return config
@pytest.fixture(autouse=True)
def patch_schedule_service(monkeypatch: pytest.MonkeyPatch, dummy_schedule_config: ScheduleConfig):
    """Force ScheduleService.to_schedule_config to always return the fixture config."""
    # Ensure poller always receives the deterministic config.
    monkeypatch.setattr(
        "services.trigger.schedule_service.ScheduleService.to_schedule_config",
        staticmethod(lambda *_args, **_kwargs: dummy_schedule_config),
    )
def _make_poller(
    monkeypatch: pytest.MonkeyPatch, redis_client: _DummyRedis
) -> event_selectors.ScheduleTriggerDebugEventPoller:
    """Build a ScheduleTriggerDebugEventPoller wired to the in-memory redis stub."""
    # Swap the module-level redis client for the dummy before construction.
    monkeypatch.setattr(event_selectors, "redis_client", redis_client)
    return event_selectors.ScheduleTriggerDebugEventPoller(
        tenant_id="tenant-1",
        user_id="user-1",
        app_id="app-1",
        node_config={"id": "node-1", "data": {"mode": "cron"}},
        node_id="node-1",
    )
def test_schedule_poller_handles_aware_next_run(monkeypatch: pytest.MonkeyPatch):
    """poll() must cope with a timezone-aware next_run_at and still emit an event."""
    redis_client = _DummyRedis()
    poller = _make_poller(monkeypatch, redis_client)
    # Naive "now" is 5 seconds past the aware next-run time, so the schedule is due.
    base_now = datetime(2025, 1, 1, 12, 0, 10)
    aware_next_run = datetime(2025, 1, 1, 12, 0, 5, tzinfo=UTC)
    monkeypatch.setattr(event_selectors, "naive_utc_now", lambda: base_now)
    monkeypatch.setattr(event_selectors, "calculate_next_run_at", lambda *_: aware_next_run)
    event = poller.poll()
    assert event is not None
    assert event.node_id == "node-1"
    assert event.workflow_args["inputs"] == {}
def test_schedule_runtime_cache_normalizes_timezone(
    monkeypatch: pytest.MonkeyPatch, dummy_schedule_config: ScheduleConfig
):
    """A cached next_run_at with an offset must be normalized to naive UTC."""
    redis_client = _DummyRedis()
    poller = _make_poller(monkeypatch, redis_client)
    # Seed the redis cache with a runtime entry carrying an Asia/Shanghai timestamp.
    localized_time = pytz.timezone("Asia/Shanghai").localize(datetime(2025, 1, 1, 20, 0, 0))
    cron_hash = hashlib.sha256(dummy_schedule_config.cron_expression.encode()).hexdigest()
    cache_key = poller.schedule_debug_runtime_key(cron_hash)
    redis_client.store[cache_key] = json.dumps(
        {
            "cache_key": cache_key,
            "timezone": dummy_schedule_config.timezone,
            "cron_expression": dummy_schedule_config.cron_expression,
            "next_run_at": localized_time.isoformat(),
        }
    )
    runtime = poller.get_or_create_schedule_debug_runtime()
    # Expected: converted to UTC, then the tzinfo dropped (naive UTC).
    expected = localized_time.astimezone(UTC).replace(tzinfo=None)
    assert runtime.next_run_at == expected
    assert runtime.next_run_at.tzinfo is None

View File

@@ -0,0 +1,29 @@
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
def _make_identity() -> ToolIdentity:
    """Build a minimal ToolIdentity for constructing test entities."""
    label = I18nObject(en_US="Label")
    return ToolIdentity(author="author", name="tool", label=label, provider="builtin")
def test_log_message_metadata_none_defaults_to_empty_dict():
    """Passing metadata=None must be normalized to an empty dict."""
    message = ToolInvokeMessage.LogMessage(
        id="log-1",
        label="Log entry",
        status=ToolInvokeMessage.LogMessage.LogStatus.START,
        data={},
        metadata=None,
    )
    assert message.metadata == {}
def test_tool_entity_output_schema_none_defaults_to_empty_dict():
    """Passing output_schema=None must be normalized to an empty dict."""
    tool = ToolEntity(identity=_make_identity(), output_schema=None)
    assert tool.output_schema == {}

View File

@@ -0,0 +1,49 @@
from core.tools.entities.tool_entities import ToolParameter
def test_get_parameter_type():
    """as_normal_type() maps each parameter type onto its plain type name."""
    expected = {
        ToolParameter.ToolParameterType.STRING: "string",
        ToolParameter.ToolParameterType.SELECT: "string",
        ToolParameter.ToolParameterType.SECRET_INPUT: "string",
        ToolParameter.ToolParameterType.BOOLEAN: "boolean",
        ToolParameter.ToolParameterType.NUMBER: "number",
        ToolParameter.ToolParameterType.FILE: "file",
        ToolParameter.ToolParameterType.FILES: "files",
    }
    for parameter_type, normal_type in expected.items():
        assert parameter_type.as_normal_type() == normal_type
def test_cast_parameter_by_type():
    """cast_value() coerces raw inputs according to the declared parameter type."""
    # string
    assert ToolParameter.ToolParameterType.STRING.cast_value("test") == "test"
    assert ToolParameter.ToolParameterType.STRING.cast_value(1) == "1"
    assert ToolParameter.ToolParameterType.STRING.cast_value(1.0) == "1.0"
    assert ToolParameter.ToolParameterType.STRING.cast_value(None) == ""
    # secret input
    assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value("test") == "test"
    assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1) == "1"
    assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(1.0) == "1.0"
    assert ToolParameter.ToolParameterType.SECRET_INPUT.cast_value(None) == ""
    # select
    assert ToolParameter.ToolParameterType.SELECT.cast_value("test") == "test"
    assert ToolParameter.ToolParameterType.SELECT.cast_value(1) == "1"
    assert ToolParameter.ToolParameterType.SELECT.cast_value(1.0) == "1.0"
    assert ToolParameter.ToolParameterType.SELECT.cast_value(None) == ""
    # boolean: any unrecognized non-empty string (e.g. "something") is truthy
    true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
    for value in true_values:
        assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is True
    false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
    for value in false_values:
        assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is False
    # number: numeric strings parse to int/float; None is passed through
    assert ToolParameter.ToolParameterType.NUMBER.cast_value("1") == 1
    assert ToolParameter.ToolParameterType.NUMBER.cast_value("1.0") == 1.0
    assert ToolParameter.ToolParameterType.NUMBER.cast_value("-1.0") == -1.0
    assert ToolParameter.ToolParameterType.NUMBER.cast_value(1) == 1
    assert ToolParameter.ToolParameterType.NUMBER.cast_value(1.0) == 1.0
    assert ToolParameter.ToolParameterType.NUMBER.cast_value(-1.0) == -1.0
    assert ToolParameter.ToolParameterType.NUMBER.cast_value(None) is None

View File

@@ -0,0 +1,181 @@
import copy
from unittest.mock import patch
import pytest
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_encryption import ProviderConfigEncrypter
# ---------------------------
# A no-op cache
# ---------------------------
class NoopCache:
    """Cache stub that never stores anything: every read is a miss."""

    def get(self):
        """Always report a cache miss."""
        return None

    def set(self, config):
        """Discard the value; nothing is persisted."""

    def delete(self):
        """Nothing to invalidate."""
@pytest.fixture
def secret_field() -> BasicProviderConfig:
    """A SECRET_INPUT field named 'password'."""
    return BasicProviderConfig(name="password", type=BasicProviderConfig.Type.SECRET_INPUT)
@pytest.fixture
def normal_field() -> BasicProviderConfig:
    """A TEXT_INPUT field named 'username'."""
    return BasicProviderConfig(name="username", type=BasicProviderConfig.Type.TEXT_INPUT)
@pytest.fixture
def encrypter_obj(secret_field, normal_field):
    """
    Build ProviderConfigEncrypter with:
      - tenant_id = tenant123
      - one secret field (password) and one normal field (username)
      - NoopCache as cache (so every call hits the real encrypt/decrypt path)
    """
    return ProviderConfigEncrypter(
        tenant_id="tenant123",
        config=[secret_field, normal_field],
        provider_config_cache=NoopCache(),
    )
# ============================================================
# ProviderConfigEncrypter.encrypt()
# ============================================================
def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj):
    """
    Secret field should be encrypted, non-secret field unchanged.
    Verify encrypt_token called only for secret field.
    Also check deep copy (input not modified).
    """
    data_in = {"username": "alice", "password": "plain_pwd"}
    data_copy = copy.deepcopy(data_in)
    # Patch the underlying token encrypter so no real crypto is exercised.
    with patch("core.helper.provider_encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt:
        out = encrypter_obj.encrypt(data_in)
        assert out["username"] == "alice"
        assert out["password"] == "CIPHERTEXT"
        mock_encrypt.assert_called_once_with("tenant123", "plain_pwd")
        assert data_in == data_copy  # deep copy semantics
def test_encrypt_missing_secret_key_is_ok(encrypter_obj):
    """If secret field missing in input, no error and no encryption called."""
    with patch("core.helper.provider_encryption.encrypter.encrypt_token") as mock_encrypt:
        out = encrypter_obj.encrypt({"username": "alice"})
        assert out["username"] == "alice"
        mock_encrypt.assert_not_called()
# ============================================================
# ProviderConfigEncrypter.mask_plugin_credentials()
# ============================================================
@pytest.mark.parametrize(
    ("raw", "prefix", "suffix"),
    [
        ("longsecret", "lo", "et"),
        ("abcdefg", "ab", "fg"),
        ("1234567", "12", "67"),
    ],
)
def test_mask_tool_credentials_long_secret(encrypter_obj, raw, prefix, suffix):
    """
    For length > 6: keep first 2 and last 2, mask middle with '*'.
    """
    data_in = {"username": "alice", "password": raw}
    data_copy = copy.deepcopy(data_in)
    out = encrypter_obj.mask_plugin_credentials(data_in)
    masked = out["password"]
    # First/last two characters survive; the middle is starred out at same length.
    assert masked.startswith(prefix)
    assert masked.endswith(suffix)
    assert "*" in masked
    assert len(masked) == len(raw)
    assert data_in == data_copy  # deep copy semantics
@pytest.mark.parametrize("raw", ["", "1", "12", "123", "123456"])
def test_mask_tool_credentials_short_secret(encrypter_obj, raw):
    """
    For length <= 6: fully mask with '*' of same length.
    """
    out = encrypter_obj.mask_plugin_credentials({"password": raw})
    assert out["password"] == ("*" * len(raw))
def test_mask_tool_credentials_missing_key_noop(encrypter_obj):
    """If secret key missing, leave other fields unchanged."""
    data_in = {"username": "alice"}
    data_copy = copy.deepcopy(data_in)
    out = encrypter_obj.mask_plugin_credentials(data_in)
    assert out["username"] == "alice"
    assert data_in == data_copy
# ============================================================
# ProviderConfigEncrypter.decrypt()
# ============================================================
def test_decrypt_normal_flow(encrypter_obj):
    """
    Normal decrypt flow:
      - decrypt_token called for secret field
      - secret replaced with decrypted value
      - non-secret unchanged
    """
    data_in = {"username": "alice", "password": "ENC"}
    data_copy = copy.deepcopy(data_in)
    with patch("core.helper.provider_encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt:
        out = encrypter_obj.decrypt(data_in)
        assert out["username"] == "alice"
        assert out["password"] == "PLAIN"
        mock_decrypt.assert_called_once_with("tenant123", "ENC")
        assert data_in == data_copy  # deep copy semantics
@pytest.mark.parametrize("empty_val", ["", None])
def test_decrypt_skip_empty_values(encrypter_obj, empty_val):
    """Skip decrypt if value is empty or None, keep original."""
    with patch("core.helper.provider_encryption.encrypter.decrypt_token") as mock_decrypt:
        out = encrypter_obj.decrypt({"password": empty_val})
        mock_decrypt.assert_not_called()
        assert out["password"] == empty_val
def test_decrypt_swallow_exception_and_keep_original(encrypter_obj):
    """
    If decrypt_token raises, exception should be swallowed,
    and original value preserved.
    """
    with patch("core.helper.provider_encryption.encrypter.decrypt_token", side_effect=Exception("boom")):
        out = encrypter_obj.decrypt({"password": "ENC_ERR"})
        assert out["password"] == "ENC_ERR"

View File

@@ -0,0 +1,191 @@
import pytest
from flask import Flask
from core.tools.utils.parser import ApiBasedToolSchemaParser
@pytest.fixture
def app():
    """A bare Flask app so the parser can run inside a request context."""
    return Flask(__name__)
def test_parse_openapi_to_tool_bundle_operation_id(app):
    """Missing operationIds are synthesized from the path; explicit ones are kept."""
    openapi = {
        "openapi": "3.0.0",
        "info": {"title": "Simple API", "version": "1.0.0"},
        "servers": [{"url": "http://localhost:3000"}],
        "paths": {
            "/": {
                "get": {
                    "summary": "Root endpoint",
                    "responses": {
                        "200": {
                            "description": "Successful response",
                        }
                    },
                }
            },
            "/api/resources": {
                "get": {
                    "summary": "Non-root endpoint without an operationId",
                    "responses": {
                        "200": {
                            "description": "Successful response",
                        }
                    },
                },
                "post": {
                    "summary": "Non-root endpoint with an operationId",
                    "operationId": "createResource",
                    "responses": {
                        "201": {
                            "description": "Resource created",
                        }
                    },
                },
            },
        },
    }
    with app.test_request_context():
        tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)
        assert len(tool_bundles) == 3
        # Root path becomes "<root>_get"; other paths are slug+method; explicit ids win.
        assert tool_bundles[0].operation_id == "<root>_get"
        assert tool_bundles[1].operation_id == "apiresources_get"
        assert tool_bundles[2].operation_id == "createResource"
def test_parse_openapi_to_tool_bundle_properties_all_of(app):
    """allOf composition ($ref + inline schema) resolves to a usable parameter type."""
    openapi = {
        "openapi": "3.0.0",
        "info": {"title": "Simple API", "version": "1.0.0"},
        "servers": [{"url": "http://localhost:3000"}],
        "paths": {
            "/api/resource": {
                "get": {
                    "requestBody": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "$ref": "#/components/schemas/Request",
                                },
                            },
                        },
                        "required": True,
                    },
                },
            },
        },
        "components": {
            "schemas": {
                "Request": {
                    "type": "object",
                    "properties": {
                        "prop1": {
                            "enum": ["option1"],
                            "description": "desc prop1",
                            "allOf": [
                                {"$ref": "#/components/schemas/AllOfItem"},
                                {
                                    "enum": ["option2"],
                                },
                            ],
                        },
                    },
                },
                "AllOfItem": {
                    "type": "string",
                    "enum": ["option3"],
                    "description": "desc allOf item",
                },
            }
        },
    }
    with app.test_request_context():
        tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)
        # Type comes from the referenced allOf item; description from the property itself.
        assert tool_bundles[0].parameters[0].type == "string"
        assert tool_bundles[0].parameters[0].llm_description == "desc prop1"
        # TODO: support enum in OpenAPI
        # assert set(tool_bundles[0].parameters[0].options) == {"option1", "option2", "option3"}
def test_parse_openapi_to_tool_bundle_default_value_type_casting(app):
    """
    Test that default values are properly cast to match parameter types.
    This addresses the issue where array default values like [] cause validation errors
    when parameter type is inferred as string/number/boolean.
    """
    openapi = {
        "openapi": "3.0.0",
        "info": {"title": "Test API", "version": "1.0.0"},
        "servers": [{"url": "https://example.com"}],
        "paths": {
            "/product/create": {
                "post": {
                    "operationId": "createProduct",
                    "summary": "Create a product",
                    "requestBody": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "type": "object",
                                    "properties": {
                                        "categories": {
                                            "description": "List of category identifiers",
                                            "default": [],
                                            "type": "array",
                                            "items": {"type": "string"},
                                        },
                                        "name": {
                                            "description": "Product name",
                                            "default": "Default Product",
                                            "type": "string",
                                        },
                                        "price": {"description": "Product price", "default": 0.0, "type": "number"},
                                        "available": {
                                            "description": "Product availability",
                                            "default": True,
                                            "type": "boolean",
                                        },
                                    },
                                }
                            }
                        }
                    },
                    "responses": {"200": {"description": "Default Response"}},
                }
            }
        },
    }
    with app.test_request_context():
        tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)
        assert len(tool_bundles) == 1
        bundle = tool_bundles[0]
        assert len(bundle.parameters) == 4
        # Find parameters by name
        params_by_name = {param.name: param for param in bundle.parameters}
        # Check categories parameter (array type with [] default)
        categories_param = params_by_name["categories"]
        assert categories_param.type == "array"  # Will be detected by _get_tool_parameter_type
        assert categories_param.default is None  # Array default [] is converted to None
        # Check name parameter (string type with string default)
        name_param = params_by_name["name"]
        assert name_param.type == "string"
        assert name_param.default == "Default Product"
        # Check price parameter (number type with number default)
        price_param = params_by_name["price"]
        assert price_param.type == "number"
        assert price_param.default == 0.0
        # Check available parameter (boolean type with boolean default)
        available_param = params_by_name["available"]
        assert available_param.type == "boolean"
        assert available_param.default is True

View File

@@ -0,0 +1,481 @@
import json
from datetime import date, datetime
from decimal import Decimal
from uuid import uuid4
import numpy as np
import pytest
import pytz
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_dict, safe_json_value
class TestSafeJsonValue:
    """Test suite for safe_json_value function to ensure proper serialization of complex types"""

    def test_datetime_conversion(self):
        """Test datetime conversion with timezone handling"""
        # Test datetime with UTC timezone
        dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC)
        result = safe_json_value(dt)
        assert isinstance(result, str)
        assert "2024-01-01T12:00:00+00:00" in result
        # Test datetime without timezone (should default to UTC)
        dt_no_tz = datetime(2024, 1, 1, 12, 0, 0)
        result = safe_json_value(dt_no_tz)
        assert isinstance(result, str)
        # The exact time will depend on the system's timezone, so we check the format
        assert "T" in result  # ISO format separator
        # Check that it's a valid ISO format datetime string
        assert len(result) >= 19  # At least YYYY-MM-DDTHH:MM:SS

    def test_date_conversion(self):
        """Test date conversion to ISO format"""
        test_date = date(2024, 1, 1)
        result = safe_json_value(test_date)
        assert result == "2024-01-01"

    def test_uuid_conversion(self):
        """Test UUID conversion to string"""
        test_uuid = uuid4()
        result = safe_json_value(test_uuid)
        assert isinstance(result, str)
        assert result == str(test_uuid)

    def test_decimal_conversion(self):
        """Test Decimal conversion to float"""
        test_decimal = Decimal("123.456")
        result = safe_json_value(test_decimal)
        assert result == 123.456
        assert isinstance(result, float)

    def test_bytes_conversion(self):
        """Test bytes conversion with UTF-8 decoding"""
        # Test valid UTF-8 bytes
        test_bytes = b"Hello, World!"
        result = safe_json_value(test_bytes)
        assert result == "Hello, World!"
        # Test invalid UTF-8 bytes (should fall back to hex)
        invalid_bytes = b"\xff\xfe\xfd"
        result = safe_json_value(invalid_bytes)
        assert result == "fffefd"

    def test_memoryview_conversion(self):
        """Test memoryview conversion to hex string"""
        test_bytes = b"test data"
        test_memoryview = memoryview(test_bytes)
        result = safe_json_value(test_memoryview)
        assert result == "746573742064617461"  # hex of "test data"

    def test_numpy_ndarray_conversion(self):
        """Test numpy ndarray conversion to list"""
        # Test 1D array
        test_array = np.array([1, 2, 3, 4])
        result = safe_json_value(test_array)
        assert result == [1, 2, 3, 4]
        # Test 2D array
        test_2d_array = np.array([[1, 2], [3, 4]])
        result = safe_json_value(test_2d_array)
        assert result == [[1, 2], [3, 4]]
        # Test array with float values
        test_float_array = np.array([1.5, 2.7, 3.14])
        result = safe_json_value(test_float_array)
        assert result == [1.5, 2.7, 3.14]

    def test_dict_conversion(self):
        """Test dictionary conversion using safe_json_dict"""
        # A dict of already-serializable values must come back unchanged.
        test_dict = {
            "string": "value",
            "number": 42,
            "float": 3.14,
            "boolean": True,
            "list": [1, 2, 3],
            "nested": {"key": "value"},
        }
        result = safe_json_value(test_dict)
        assert isinstance(result, dict)
        assert result == test_dict

    def test_list_conversion(self):
        """Test list conversion with mixed types"""
        test_list = [
            "string",
            42,
            3.14,
            True,
            [1, 2, 3],
            {"key": "value"},
            datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
            Decimal("123.456"),
            uuid4(),
        ]
        result = safe_json_value(test_list)
        assert isinstance(result, list)
        assert len(result) == len(test_list)
        assert isinstance(result[6], str)  # datetime should be converted to string
        assert isinstance(result[7], float)  # Decimal should be converted to float
        assert isinstance(result[8], str)  # UUID should be converted to string

    def test_tuple_conversion(self):
        """Test tuple conversion to list"""
        test_tuple = (1, "string", 3.14)
        result = safe_json_value(test_tuple)
        assert isinstance(result, list)
        assert result == [1, "string", 3.14]

    def test_set_conversion(self):
        """Test set conversion to list"""
        test_set = {1, "string", 3.14}
        result = safe_json_value(test_set)
        assert isinstance(result, list)
        # Note: set order is not guaranteed, so we check length and content
        assert len(result) == 3
        assert 1 in result
        assert "string" in result
        assert 3.14 in result

    def test_basic_types_passthrough(self):
        """Test that basic types are passed through unchanged"""
        assert safe_json_value("string") == "string"
        assert safe_json_value(42) == 42
        assert safe_json_value(3.14) == 3.14
        assert safe_json_value(True) is True
        assert safe_json_value(False) is False
        assert safe_json_value(None) is None

    def test_nested_complex_structure(self):
        """Test complex nested structure with all types"""
        complex_data = {
            "dates": [date(2024, 1, 1), date(2024, 1, 2)],
            "timestamps": [
                datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
                datetime(2024, 1, 2, 12, 0, 0, tzinfo=pytz.UTC),
            ],
            "numbers": [Decimal("123.456"), Decimal("789.012")],
            "identifiers": [uuid4(), uuid4()],
            "binary_data": [b"hello", b"world"],
            "arrays": [np.array([1, 2, 3]), np.array([4, 5, 6])],
        }
        result = safe_json_value(complex_data)
        # Verify structure is maintained
        assert isinstance(result, dict)
        assert "dates" in result
        assert "timestamps" in result
        assert "numbers" in result
        assert "identifiers" in result
        assert "binary_data" in result
        assert "arrays" in result
        # Verify conversions
        assert all(isinstance(d, str) for d in result["dates"])
        assert all(isinstance(t, str) for t in result["timestamps"])
        assert all(isinstance(n, float) for n in result["numbers"])
        assert all(isinstance(i, str) for i in result["identifiers"])
        assert all(isinstance(b, str) for b in result["binary_data"])
        assert all(isinstance(a, list) for a in result["arrays"])
class TestSafeJsonDict:
    """Test suite for safe_json_dict function"""

    def test_valid_dict_conversion(self):
        """Test valid dictionary conversion"""
        test_dict = {
            "string": "value",
            "number": 42,
            "datetime": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
            "decimal": Decimal("123.456"),
        }
        result = safe_json_dict(test_dict)
        assert isinstance(result, dict)
        assert result["string"] == "value"
        assert result["number"] == 42
        assert isinstance(result["datetime"], str)
        assert isinstance(result["decimal"], float)

    def test_invalid_input_type(self):
        """Test that invalid input types raise TypeError"""
        # Regex-escaped match against the exact error message for non-dict inputs.
        with pytest.raises(TypeError, match="safe_json_dict\\(\\) expects a dictionary \\(dict\\) as input"):
            safe_json_dict("not a dict")
        with pytest.raises(TypeError, match="safe_json_dict\\(\\) expects a dictionary \\(dict\\) as input"):
            safe_json_dict([1, 2, 3])
        with pytest.raises(TypeError, match="safe_json_dict\\(\\) expects a dictionary \\(dict\\) as input"):
            safe_json_dict(42)

    def test_empty_dict(self):
        """Test empty dictionary handling"""
        result = safe_json_dict({})
        assert result == {}

    def test_nested_dict_conversion(self):
        """Test nested dictionary conversion"""
        test_dict = {
            "level1": {
                "level2": {"datetime": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), "decimal": Decimal("123.456")}
            }
        }
        result = safe_json_dict(test_dict)
        assert isinstance(result["level1"]["level2"]["datetime"], str)
        assert isinstance(result["level1"]["level2"]["decimal"], float)
class TestToolInvokeMessageJsonSerialization:
"""Test suite for ToolInvokeMessage JSON serialization through safe_json_value"""
def test_json_message_serialization(self):
"""Test JSON message serialization with complex data"""
complex_data = {
"timestamp": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
"amount": Decimal("123.45"),
"id": uuid4(),
"binary": b"test data",
"array": np.array([1, 2, 3]),
}
# Create JSON message
json_message = ToolInvokeMessage.JsonMessage(json_object=complex_data)
message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message)
# Apply safe_json_value transformation
transformed_data = safe_json_value(message.message.json_object)
# Verify transformations
assert isinstance(transformed_data["timestamp"], str)
assert isinstance(transformed_data["amount"], float)
assert isinstance(transformed_data["id"], str)
assert isinstance(transformed_data["binary"], str)
assert isinstance(transformed_data["array"], list)
# Verify JSON serialization works
json_string = json.dumps(transformed_data, ensure_ascii=False)
assert isinstance(json_string, str)
# Verify we can deserialize back
deserialized = json.loads(json_string)
assert deserialized["amount"] == 123.45
assert deserialized["array"] == [1, 2, 3]
def test_json_message_with_nested_structures(self):
"""Test JSON message with deeply nested complex structures"""
nested_data = {
"level1": {
"level2": {
"level3": {
"dates": [date(2024, 1, 1), date(2024, 1, 2)],
"timestamps": [datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC)],
"numbers": [Decimal("1.1"), Decimal("2.2")],
"arrays": [np.array([1, 2]), np.array([3, 4])],
}
}
}
}
json_message = ToolInvokeMessage.JsonMessage(json_object=nested_data)
message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message)
# Transform the data
transformed_data = safe_json_value(message.message.json_object)
# Verify nested transformations
level3 = transformed_data["level1"]["level2"]["level3"]
assert all(isinstance(d, str) for d in level3["dates"])
assert all(isinstance(t, str) for t in level3["timestamps"])
assert all(isinstance(n, float) for n in level3["numbers"])
assert all(isinstance(a, list) for a in level3["arrays"])
# Test JSON serialization
json_string = json.dumps(transformed_data, ensure_ascii=False)
assert isinstance(json_string, str)
# Verify deserialization
deserialized = json.loads(json_string)
assert deserialized["level1"]["level2"]["level3"]["numbers"] == [1.1, 2.2]
def test_json_message_transformer_integration(self):
"""Test integration with ToolFileMessageTransformer for JSON messages"""
complex_data = {
"metadata": {
"created_at": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
"version": Decimal("1.0"),
"tags": ["tag1", "tag2"],
},
"data": {"values": np.array([1.1, 2.2, 3.3]), "binary": b"binary content"},
}
# Create message generator
def message_generator():
json_message = ToolInvokeMessage.JsonMessage(json_object=complex_data)
message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message)
yield message
# Transform messages
transformed_messages = list(
ToolFileMessageTransformer.transform_tool_invoke_messages(
message_generator(), user_id="test_user", tenant_id="test_tenant"
)
)
assert len(transformed_messages) == 1
transformed_message = transformed_messages[0]
assert transformed_message.type == ToolInvokeMessage.MessageType.JSON
# Verify the JSON object was transformed
json_obj = transformed_message.message.json_object
assert isinstance(json_obj["metadata"]["created_at"], str)
assert isinstance(json_obj["metadata"]["version"], float)
assert isinstance(json_obj["data"]["values"], list)
assert isinstance(json_obj["data"]["binary"], str)
# Test final JSON serialization
final_json = json.dumps(json_obj, ensure_ascii=False)
assert isinstance(final_json, str)
# Verify we can deserialize
deserialized = json.loads(final_json)
assert deserialized["metadata"]["version"] == 1.0
assert deserialized["data"]["values"] == [1.1, 2.2, 3.3]
def test_edge_cases_and_error_handling(self):
    """safe_json_value copes with falsy values and extreme numeric magnitudes."""
    # Falsy-but-valid values must survive untouched.
    falsy = {"null_value": None, "empty_string": "", "zero": 0, "false_value": False}
    msg = ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.JSON,
        message=ToolInvokeMessage.JsonMessage(json_object=falsy),
    )
    encoded = json.dumps(safe_json_value(msg.message.json_object), ensure_ascii=False)
    assert encoded is not None
    decoded = json.loads(encoded)
    assert decoded["null_value"] is None
    assert decoded["empty_string"] == ""
    assert decoded["zero"] == 0
    assert decoded["false_value"] is False

    # Numeric extremes: max int64 plus the largest/smallest normal doubles.
    extremes = {
        "large_int": 2**63 - 1,
        "large_float": 1.7976931348623157e308,
        "small_float": 2.2250738585072014e-308,
    }
    msg = ToolInvokeMessage(
        type=ToolInvokeMessage.MessageType.JSON,
        message=ToolInvokeMessage.JsonMessage(json_object=extremes),
    )
    encoded = json.dumps(safe_json_value(msg.message.json_object), ensure_ascii=False)
    decoded = json.loads(encoded)
    assert decoded["large_int"] == 2**63 - 1
    assert decoded["large_float"] == 1.7976931348623157e308
    assert decoded["small_float"] == 2.2250738585072014e-308
class TestEndToEndSerialization:
    """End-to-end flow: complex data -> safe_json_value -> JSON string -> Python and back."""

    def test_complete_workflow_with_real_data(self):
        """A realistic nested payload survives the full serialization pipeline intact."""
        payload = {
            "user_profile": {
                "id": uuid4(),
                "name": "John Doe",
                "email": "john@example.com",
                "created_at": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
                "last_login": datetime(2024, 1, 15, 14, 30, 0, tzinfo=pytz.UTC),
                "preferences": {"theme": "dark", "language": "en", "timezone": "UTC"},
            },
            "analytics": {
                "session_count": 42,
                "total_time": Decimal("123.45"),
                "metrics": np.array([1.1, 2.2, 3.3, 4.4, 5.5]),
                "events": [
                    {
                        "timestamp": datetime(2024, 1, 1, 10, 0, 0, tzinfo=pytz.UTC),
                        "action": "login",
                        "duration": Decimal("5.67"),
                    },
                    {
                        "timestamp": datetime(2024, 1, 1, 11, 0, 0, tzinfo=pytz.UTC),
                        "action": "logout",
                        "duration": Decimal("3600.0"),
                    },
                ],
            },
            "files": [
                {
                    "id": uuid4(),
                    "name": "document.pdf",
                    "size": 1024,
                    "uploaded_at": datetime(2024, 1, 1, 9, 0, 0, tzinfo=pytz.UTC),
                    "checksum": b"abc123def456",
                }
            ],
        }

        # Wrap in a ToolInvokeMessage, sanitize, then serialize.
        message = ToolInvokeMessage(
            type=ToolInvokeMessage.MessageType.JSON,
            message=ToolInvokeMessage.JsonMessage(json_object=payload),
        )
        sanitized = safe_json_value(message.message.json_object)
        encoded = json.dumps(sanitized, ensure_ascii=False)

        # The result is a valid JSON object literal.
        assert isinstance(encoded, str)
        assert encoded.startswith("{")
        assert encoded.endswith("}")

        # Deserialize and spot-check data integrity.
        decoded = json.loads(encoded)
        assert decoded["user_profile"]["name"] == "John Doe"
        assert decoded["user_profile"]["email"] == "john@example.com"
        assert isinstance(decoded["user_profile"]["created_at"], str)
        assert isinstance(decoded["analytics"]["total_time"], float)
        assert decoded["analytics"]["total_time"] == 123.45
        assert isinstance(decoded["analytics"]["metrics"], list)
        assert decoded["analytics"]["metrics"] == [1.1, 2.2, 3.3, 4.4, 5.5]
        assert isinstance(decoded["files"][0]["checksum"], str)

        # Walk the entire tree to confirm every complex type was converted.
        self._verify_all_complex_types_converted(decoded)

    def _verify_all_complex_types_converted(self, data):
        """Recursively assert UUIDs/bytes/datetimes/Decimals/ndarrays became JSON primitives."""
        if isinstance(data, dict):
            for key, value in data.items():
                if key in ("id", "checksum"):
                    assert isinstance(value, str)  # UUID / bytes -> str
                elif key in ("created_at", "last_login", "timestamp", "uploaded_at"):
                    assert isinstance(value, str)  # datetime -> string
                elif key in ("total_time", "duration"):
                    assert isinstance(value, float)  # Decimal -> float
                elif key == "metrics":
                    assert isinstance(value, list)  # ndarray -> list
                else:
                    self._verify_all_complex_types_converted(value)
        elif isinstance(data, list):
            for item in data:
                self._verify_all_complex_types_converted(item)

View File

@@ -0,0 +1,312 @@
import pytest
from core.tools.utils.web_reader_tool import (
extract_using_readabilipy,
get_image_upload_file_ids,
get_url,
page_result,
)
class FakeResponse:
    """Tiny stand-in for the response objects returned by ssrf_proxy / cloudscraper."""

    def __init__(self, *, status_code=200, headers=None, content=b"", text=""):
        # Keyword-only construction mirrors how the tests build responses.
        self.status_code = status_code
        self.headers = {} if headers is None else headers
        self.content = content
        # When no explicit text is supplied, derive it from the raw body,
        # silently dropping undecodable bytes.
        self.text = text if text else content.decode("utf-8", errors="ignore")
# ---------------------------
# Tests: page_result
# ---------------------------
@pytest.mark.parametrize(
    ("source", "start", "limit", "want"),
    [
        ("abcdef", 0, 3, "abc"),
        ("abcdef", 2, 10, "cdef"),  # limit reaches past the end of the text
        ("abcdef", 6, 5, ""),  # cursor exactly at the end
        ("abcdef", 7, 5, ""),  # cursor beyond the end
        ("", 0, 5, ""),  # empty input
    ],
)
def test_page_result(source, start, limit, want):
    """page_result returns up to `limit` characters from `start`, clamped to the text."""
    assert page_result(source, start, limit) == want
# ---------------------------
# Tests: get_url
# ---------------------------
@pytest.fixture
def stub_support_types(monkeypatch: pytest.MonkeyPatch):
    """Pin ExtractProcessor's supported content types to a known, small list."""
    import core.tools.utils.web_reader_tool as mod

    # Binary types that ExtractProcessor is allowed to handle in these tests.
    monkeypatch.setattr(
        mod.extract_processor, "SUPPORT_URL_CONTENT_TYPES", ["application/pdf", "text/plain"]
    )
    return mod
def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """A 200 HEAD whose content type is neither supported nor text/html yields an error string."""
    mod = stub_support_types

    def head_stub(url, headers=None, follow_redirects=True, timeout=None):
        # image/png is not in SUPPORT_URL_CONTENT_TYPES and is not HTML.
        return FakeResponse(status_code=200, headers={"Content-Type": "image/png"})

    monkeypatch.setattr(mod.ssrf_proxy, "head", head_stub)

    assert get_url("https://x.test/file.png") == "Unsupported content-type [image/png] of URL."
def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """A content type listed in SUPPORT_URL_CONTENT_TYPES is routed to ExtractProcessor.load_from_url."""
    mod = stub_support_types
    load_calls = []

    def head_stub(url, headers=None, follow_redirects=True, timeout=None):
        return FakeResponse(status_code=200, headers={"Content-Type": "application/pdf"})

    def load_stub(url, return_text=False):
        load_calls.append(url)
        # get_url must always ask for text output.
        assert return_text is True
        return "PDF extracted text"

    monkeypatch.setattr(mod.ssrf_proxy, "head", head_stub)
    monkeypatch.setattr(mod.ExtractProcessor, "load_from_url", staticmethod(load_stub))

    assert get_url("https://x.test/doc.pdf") == "PDF extracted text"
    assert len(load_calls) == 1
def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """text/html path: GET the page, detect the encoding, extract via readability, render the template."""
    import core.tools.utils.web_reader_tool as mod

    page = b"<html><head><title>x</title></head><body>hello</body></html>"
    monkeypatch.setattr(
        mod.ssrf_proxy,
        "head",
        lambda url, headers=None, follow_redirects=True, timeout=None: FakeResponse(
            status_code=200, headers={"Content-Type": "text/html"}
        ),
    )
    monkeypatch.setattr(
        mod.ssrf_proxy,
        "get",
        lambda url, headers=None, follow_redirects=True, timeout=None: FakeResponse(
            status_code=200, headers={"Content-Type": "text/html"}, content=page
        ),
    )
    # chardet reports utf-8 for the fetched bytes.
    monkeypatch.setattr(mod.chardet, "detect", lambda raw: {"encoding": "utf-8"})
    # readability yields a dict that maps onto Article and then into FULL_TEMPLATE.
    monkeypatch.setattr(
        mod,
        "simple_json_from_html_string",
        lambda html, use_readability=True: {
            "title": "My Title",
            "byline": "Bob",
            "plain_text": [{"type": "text", "text": "Hello world"}],
        },
    )

    rendered = get_url("https://x.test/page")
    assert "TITLE: My Title" in rendered
    assert "AUTHOR: Bob" in rendered
    assert "Hello world" in rendered
def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """When readability extracts no text at all, get_url yields an empty string."""
    import core.tools.utils.web_reader_tool as mod

    monkeypatch.setattr(
        mod.ssrf_proxy,
        "head",
        lambda url, headers=None, follow_redirects=True, timeout=None: FakeResponse(
            status_code=200, headers={"Content-Type": "text/html"}
        ),
    )
    monkeypatch.setattr(
        mod.ssrf_proxy,
        "get",
        lambda url, headers=None, follow_redirects=True, timeout=None: FakeResponse(
            status_code=200, headers={"Content-Type": "text/html"}, content=b"<html/>"
        ),
    )
    monkeypatch.setattr(mod.chardet, "detect", lambda raw: {"encoding": "utf-8"})
    # readability finds nothing to extract.
    monkeypatch.setattr(mod, "simple_json_from_html_string", lambda html, use_readability=True: {"plain_text": []})

    assert get_url("https://x.test/empty") == ""
def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """HEAD 403 falls back to a cloudscraper scraper whose GET then succeeds."""
    import core.tools.utils.web_reader_tool as mod

    class FakeScraper:
        # No state is needed; the previous explicit empty __init__ was redundant.
        def get(self, url, headers=None, follow_redirects=True, timeout=None):
            # Mimic a successful HTML 200 response.
            html = b"<html><body>hi</body></html>"
            return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}, content=html)

    monkeypatch.setattr(
        mod.ssrf_proxy,
        "head",
        lambda url, headers=None, follow_redirects=True, timeout=None: FakeResponse(status_code=403, headers={}),
    )
    monkeypatch.setattr(mod.cloudscraper, "create_scraper", lambda: FakeScraper())
    monkeypatch.setattr(mod.chardet, "detect", lambda raw: {"encoding": "utf-8"})
    monkeypatch.setattr(
        mod,
        "simple_json_from_html_string",
        lambda html, use_readability=True: {"title": "T", "byline": "A", "plain_text": [{"type": "text", "text": "X"}]},
    )

    out = get_url("https://x.test/403")
    assert "TITLE: T" in out
    assert "AUTHOR: A" in out
    assert "X" in out
def test_get_url_head_non_200_returns_status(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """Any HEAD status other than 200/403 is reported straight back as a message."""
    import core.tools.utils.web_reader_tool as mod

    monkeypatch.setattr(
        mod.ssrf_proxy,
        "head",
        lambda url, headers=None, follow_redirects=True, timeout=None: FakeResponse(status_code=500),
    )

    assert get_url("https://x.test/fail") == "URL returned status code 500."
def test_get_url_content_disposition_filename_detection(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """A supported extension in the Content-Disposition filename routes to ExtractProcessor."""
    import core.tools.utils.web_reader_tool as mod

    seen_urls = []

    def head_stub(url, headers=None, follow_redirects=True, timeout=None):
        # No Content-Type header; only the attachment filename hints at a PDF.
        return FakeResponse(status_code=200, headers={"Content-Disposition": 'attachment; filename="doc.pdf"'})

    def load_stub(url, return_text=False):
        seen_urls.append(url)
        return "From ExtractProcessor via filename"

    monkeypatch.setattr(mod.ssrf_proxy, "head", head_stub)
    monkeypatch.setattr(mod.ExtractProcessor, "load_from_url", staticmethod(load_stub))

    assert get_url("https://x.test/fname") == "From ExtractProcessor via filename"
    assert len(seen_urls) == 1
def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """If decoding with the detected encoding raises, get_url falls back to response.text."""
    import core.tools.utils.web_reader_tool as mod

    monkeypatch.setattr(
        mod.ssrf_proxy,
        "head",
        lambda url, headers=None, follow_redirects=True, timeout=None: FakeResponse(
            status_code=200, headers={"Content-Type": "text/html"}
        ),
    )
    # Bytes chosen to blow up under strict utf-8 decoding; .text supplies the fallback.
    monkeypatch.setattr(
        mod.ssrf_proxy,
        "get",
        lambda url, headers=None, follow_redirects=True, timeout=None: FakeResponse(
            status_code=200,
            headers={"Content-Type": "text/html"},
            content=b"\xff\xfe\xfa",
            text="<html>fallback text</html>",
        ),
    )
    monkeypatch.setattr(mod.chardet, "detect", lambda raw: {"encoding": "utf-8"})
    monkeypatch.setattr(
        mod,
        "simple_json_from_html_string",
        lambda html, use_readability=True: {"title": "", "byline": "", "plain_text": [{"type": "text", "text": "ok"}]},
    )

    assert "ok" in get_url("https://x.test/enc-fallback")
# ---------------------------
# Tests: extract_using_readabilipy
# ---------------------------
def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch: pytest.MonkeyPatch):
    """readabilipy fields map onto the Article: title, byline -> author, plain_text -> text."""
    import core.tools.utils.web_reader_tool as mod

    monkeypatch.setattr(
        mod,
        "simple_json_from_html_string",
        lambda html, use_readability=True: {
            "title": "Hello",
            "byline": "Alice",
            "plain_text": [{"type": "text", "text": "world"}],
        },
    )

    article = extract_using_readabilipy("<html>...</html>")
    assert article.title == "Hello"
    assert article.author == "Alice"
    assert isinstance(article.text, list)
    assert article.text
    assert article.text[0]["text"] == "world"
def test_extract_using_readabilipy_defaults_when_missing(monkeypatch: pytest.MonkeyPatch):
    """Fields absent from the readability result fall back to empty defaults."""
    import core.tools.utils.web_reader_tool as mod

    # Everything missing: title, byline and plain_text all default.
    monkeypatch.setattr(mod, "simple_json_from_html_string", lambda html, use_readability=True: {})

    article = extract_using_readabilipy("<html>...</html>")
    assert article.title == ""
    assert article.author == ""
    assert article.text == []
# ---------------------------
# Tests: get_image_upload_file_ids
# ---------------------------
def test_get_image_upload_file_ids():
    """File ids are extracted from http(s) file-/image-preview markdown links, in document order."""
    # https scheme + file-preview suffix
    assert get_image_upload_file_ids("![image](https://example.com/a/b/files/abc123/file-preview)") == ["abc123"]
    # http scheme + image-preview suffix
    assert get_image_upload_file_ids("![image](http://host/files/xyz789/image-preview)") == ["xyz789"]
    # malformed scheme 'htt://' must not match
    assert get_image_upload_file_ids("![image](htt://host/files/bad/file-preview)") == []
    # multiple links keep their order of appearance
    multi = """
    some text
    ![image](https://h/files/id1/file-preview)
    middle
    ![image](http://h/files/id2/image-preview)
    end
    """
    assert get_image_upload_file_ids(multi) == ["id1", "id2"]

View File

@@ -0,0 +1,53 @@
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
from core.tools.errors import ToolInvokeError
from core.tools.workflow_as_tool.tool import WorkflowTool
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch):
    """Ensure that WorkflowTool raises `ToolInvokeError` when
    `WorkflowAppGenerator.generate` returns a result carrying an `error` key
    inside its `data` element.
    """
    # Hoisted from mid-function: keep all imports at the top of the test body.
    from unittest.mock import Mock

    entity = ToolEntity(
        identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
        parameters=[],
        description=None,
        has_runtime_parameters=False,
    )
    runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
    tool = WorkflowTool(
        workflow_app_id="",
        workflow_as_tool_id="",
        version="1",
        workflow_entities={},
        workflow_call_depth=1,
        entity=entity,
        runtime=runtime,
    )

    # Patch database-touching lookups with inert stand-ins.
    monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
    monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
    # Mock user resolution to avoid database access.
    mock_user = Mock()
    monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)

    # Force the generator to report an error payload.
    monkeypatch.setattr(
        "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
        lambda *args, **kwargs: {"data": {"error": "oops"}},
    )

    with pytest.raises(ToolInvokeError) as exc_info:
        # WorkflowTool always returns a generator, so iterate to actually run it.
        list(tool.invoke("test_user", {}))
    assert exc_info.value.args == ("oops",)

View File

@@ -0,0 +1,382 @@
import dataclasses
from pydantic import BaseModel
from core.file import File, FileTransferMethod, FileType
from core.helper import encrypter
from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
SegmentUnion,
StringSegment,
get_segment_discriminator,
)
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
VariableUnion,
)
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
def test_segment_group_to_text():
    """convert_template renders variables into text, and obfuscates secret values in the log."""
    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="fake-user-id"),
        user_inputs={},
        environment_variables=[
            SecretVariable(name="secret_key", value="fake-secret-key"),
        ],
        conversation_variables=[],
    )
    variable_pool.add(("node_id", "custom_query"), "fake-user-query")
    template = (
        "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
    )
    segments_group = variable_pool.convert_template(template)
    # .text exposes the raw resolved values...
    assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key."
    # ...while .log shows the obfuscated token in place of the raw secret.
    assert segments_group.log == (
        # fixed: removed extraneous f-prefix from the placeholder-free literal (F541)
        "Hello, fake-user-id! Your query is fake-user-query."
        f" And your key is {encrypter.obfuscated_token('fake-secret-key')}."
    )
def test_convert_constant_to_segment_group():
    """A template with no variable markers passes through unchanged."""
    pool = VariablePool(
        system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    group = pool.convert_template("Hello, world!")
    assert group.text == "Hello, world!"
    assert group.log == "Hello, world!"
def test_convert_variable_to_segment_group():
    """A lone variable marker resolves to a single StringVariable segment."""
    pool = VariablePool(
        system_variables=SystemVariable(user_id="fake-user-id"),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    group = pool.convert_template("{{#sys.user_id#}}")
    assert group.text == "fake-user-id"
    assert group.log == "fake-user-id"
    first = group.value[0]
    assert isinstance(first, StringVariable)
    assert first.value == "fake-user-id"
class _Segments(BaseModel):
    """Wrapper model used to round-trip a heterogeneous list of segments through JSON."""
    segments: list[SegmentUnion]
class _Variables(BaseModel):
    """Wrapper model used to round-trip a heterogeneous list of variables through JSON."""
    variables: list[VariableUnion]
def create_test_file(
    file_type: FileType = FileType.DOCUMENT,
    transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
    filename: str = "test.txt",
    extension: str = ".txt",
    mime_type: str = "text/plain",
    size: int = 1024,
) -> File:
    """Build a File for tests.

    Remote-URL transfers carry a remote_url and no related_id; every other
    transfer method carries a related_id and no remote_url.
    """
    is_remote = transfer_method == FileTransferMethod.REMOTE_URL
    return File(
        tenant_id="test-tenant",
        type=file_type,
        transfer_method=transfer_method,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=size,
        related_id=None if is_remote else "test-file-id",
        remote_url="https://example.com/file.txt" if is_remote else None,
        storage_key="test-storage-key",
    )
class TestSegmentDumpAndLoad:
    """Serialization/deserialization round trips for every segment and variable type.

    Changes over the previous revision:
    - `type(a) == type(b)` comparisons replaced with `is` (type objects are
      singletons; comparing them with `==` is flagged as E721).
    - Local variables named `json` renamed so they no longer shadow the
      `json` module imported at file level.
    """

    def test_segments(self):
        """A small mixed segment list survives a JSON round trip."""
        model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
        dumped = model.model_dump_json()
        loaded = _Segments.model_validate_json(dumped)
        assert loaded == model

    def test_segment_number(self):
        """Integer and float segments round-trip without type confusion."""
        model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
        dumped = model.model_dump_json()
        loaded = _Segments.model_validate_json(dumped)
        assert loaded == model

    def test_variables(self):
        """Named variables round-trip with values and names intact."""
        model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
        dumped = model.model_dump_json()
        restored = _Variables.model_validate_json(dumped)
        assert restored == model

    def test_all_segments_serialization(self):
        """Every segment type serializes and deserializes to an equivalent object."""
        test_file = create_test_file()
        all_segments: list[SegmentUnion] = [
            NoneSegment(),
            StringSegment(value="test string"),
            IntegerSegment(value=42),
            FloatSegment(value=3.14),
            ObjectSegment(value={"key": "value", "number": 123}),
            FileSegment(value=test_file),
            ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
            ArrayStringSegment(value=["hello", "world"]),
            ArrayNumberSegment(value=[1, 2.5, 3]),
            ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
            ArrayFileSegment(value=[]),  # empty array avoids file-equality complexity
        ]
        model = _Segments(segments=all_segments)
        json_str = model.model_dump_json()
        loaded = _Segments.model_validate_json(json_str)

        # All segments are preserved, in order, with their concrete types.
        assert len(loaded.segments) == len(all_segments)
        for original, loaded_segment in zip(all_segments, loaded.segments):
            # `is` is the correct comparison for type objects (E721 fix).
            assert type(loaded_segment) is type(original)
            assert loaded_segment.value_type == original.value_type
            if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment):
                # File equality is structural: compare key identifying fields.
                orig_file = original.value
                loaded_file = loaded_segment.value
                assert isinstance(orig_file, File)
                assert isinstance(loaded_file, File)
                assert loaded_file.tenant_id == orig_file.tenant_id
                assert loaded_file.type == orig_file.type
                assert loaded_file.filename == orig_file.filename
            else:
                assert loaded_segment.value == original.value

    def test_all_variables_serialization(self):
        """Every variable type serializes and deserializes to an equivalent object."""
        test_file = create_test_file()
        all_variables: list[VariableUnion] = [
            NoneVariable(name="none_var"),
            StringVariable(value="test string", name="string_var"),
            IntegerVariable(value=42, name="int_var"),
            FloatVariable(value=3.14, name="float_var"),
            ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
            FileVariable(value=test_file, name="file_var"),
            ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
            ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
            ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
            ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
            ArrayFileVariable(value=[], name="array_file_var"),  # empty array avoids file complexity
        ]
        model = _Variables(variables=all_variables)
        json_str = model.model_dump_json()
        loaded = _Variables.model_validate_json(json_str)

        assert len(loaded.variables) == len(all_variables)
        for original, loaded_variable in zip(all_variables, loaded.variables):
            assert type(loaded_variable) is type(original)  # E721 fix
            assert loaded_variable.value_type == original.value_type
            assert loaded_variable.name == original.name
            if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable):
                orig_file = original.value
                loaded_file = loaded_variable.value
                assert isinstance(orig_file, File)
                assert isinstance(loaded_file, File)
                assert loaded_file.tenant_id == orig_file.tenant_id
                assert loaded_file.type == orig_file.type
                assert loaded_file.filename == orig_file.filename
            else:
                assert loaded_variable.value == original.value

    def test_segment_discriminator_function_for_segment_types(self):
        """get_segment_discriminator resolves both live segments and their dumped dicts."""

        @dataclasses.dataclass
        class TestCase:
            segment: Segment
            expected_segment_type: SegmentType

        file1 = create_test_file()
        file2 = create_test_file(filename="test2.txt")
        cases = [
            TestCase(NoneSegment(), SegmentType.NONE),
            TestCase(StringSegment(value=""), SegmentType.STRING),
            TestCase(FloatSegment(value=0.0), SegmentType.FLOAT),
            TestCase(IntegerSegment(value=0), SegmentType.INTEGER),
            TestCase(ObjectSegment(value={}), SegmentType.OBJECT),
            TestCase(FileSegment(value=file1), SegmentType.FILE),
            TestCase(ArrayAnySegment(value=[0, 0.0, ""]), SegmentType.ARRAY_ANY),
            TestCase(ArrayStringSegment(value=[""]), SegmentType.ARRAY_STRING),
            TestCase(ArrayNumberSegment(value=[0, 0.0]), SegmentType.ARRAY_NUMBER),
            TestCase(ArrayObjectSegment(value=[{}]), SegmentType.ARRAY_OBJECT),
            TestCase(ArrayFileSegment(value=[file1, file2]), SegmentType.ARRAY_FILE),
        ]
        for test_case in cases:
            segment = test_case.segment
            assert get_segment_discriminator(segment) == test_case.expected_segment_type, (
                f"get_segment_discriminator failed for type {type(segment)}"
            )
            # The discriminator must also resolve the serialized (dict) form.
            model_dict = segment.model_dump(mode="json")
            assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
                f"get_segment_discriminator failed for serialized form of type {type(segment)}"
            )

    def test_variable_discriminator_function_for_variable_types(self):
        """get_segment_discriminator resolves both live variables and their dumped dicts."""

        @dataclasses.dataclass
        class TestCase:
            variable: Variable
            expected_segment_type: SegmentType

        file1 = create_test_file()
        file2 = create_test_file(filename="test2.txt")
        cases = [
            TestCase(NoneVariable(name="none_var"), SegmentType.NONE),
            TestCase(StringVariable(value="test", name="string_var"), SegmentType.STRING),
            TestCase(FloatVariable(value=0.0, name="float_var"), SegmentType.FLOAT),
            TestCase(IntegerVariable(value=0, name="int_var"), SegmentType.INTEGER),
            TestCase(ObjectVariable(value={}, name="object_var"), SegmentType.OBJECT),
            TestCase(FileVariable(value=file1, name="file_var"), SegmentType.FILE),
            TestCase(SecretVariable(value="secret", name="secret_var"), SegmentType.SECRET),
            TestCase(ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"), SegmentType.ARRAY_ANY),
            TestCase(ArrayStringVariable(value=[""], name="array_string_var"), SegmentType.ARRAY_STRING),
            TestCase(ArrayNumberVariable(value=[0, 0.0], name="array_number_var"), SegmentType.ARRAY_NUMBER),
            TestCase(ArrayObjectVariable(value=[{}], name="array_object_var"), SegmentType.ARRAY_OBJECT),
            TestCase(ArrayFileVariable(value=[file1, file2], name="array_file_var"), SegmentType.ARRAY_FILE),
        ]
        for test_case in cases:
            variable = test_case.variable
            assert get_segment_discriminator(variable) == test_case.expected_segment_type, (
                f"get_segment_discriminator failed for type {type(variable)}"
            )
            model_dict = variable.model_dump(mode="json")
            assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
                f"get_segment_discriminator failed for serialized form of type {type(variable)}"
            )

    def test_invalid_value_for_discriminator(self):
        """Unknown or non-dict inputs yield None rather than raising."""
        assert get_segment_discriminator({"value_type": "invalid"}) is None
        assert get_segment_discriminator({}) is None
        assert get_segment_discriminator("not_a_dict") is None
        assert get_segment_discriminator(42) is None
        assert get_segment_discriminator(object) is None

View File

@@ -0,0 +1,165 @@
import pytest
from core.variables.types import ArrayValidation, SegmentType
class TestSegmentTypeIsArrayType:
    """Exhaustive coverage of SegmentType.is_array_type across the whole enum."""

    def test_is_array_type(self):
        """Each SegmentType is classified as array / non-array exactly as expected,
        and the two lists together exhaust the enum.
        """
        array_members = [
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
            SegmentType.ARRAY_BOOLEAN,
        ]
        scalar_members = [
            SegmentType.INTEGER,
            SegmentType.FLOAT,
            SegmentType.NUMBER,
            SegmentType.STRING,
            SegmentType.OBJECT,
            SegmentType.SECRET,
            SegmentType.FILE,
            SegmentType.NONE,
            SegmentType.GROUP,
            SegmentType.BOOLEAN,
        ]

        for member in array_members:
            assert member.is_array_type()
        for member in scalar_members:
            assert not member.is_array_type()

        # If a new SegmentType is added without updating this test, fail loudly.
        assert set(array_members) | set(scalar_members) == set(SegmentType), (
            "All SegmentType values should be covered in tests"
        )

    def test_all_enum_values_are_supported(self):
        """is_array_type returns a bool for every member of the enum."""
        for member in list(SegmentType):
            verdict = member.is_array_type()
            assert isinstance(verdict, bool), f"is_array_type does not return a boolean for segment type {member}"
class TestSegmentTypeIsValidArrayValidation:
    """SegmentType.is_valid behaviour under each ArrayValidation strategy."""

    def test_array_validation_all_success(self):
        # Every element is a string, so ALL-element validation passes.
        assert SegmentType.ARRAY_STRING.is_valid(["hello", "world", "foo"], array_validation=ArrayValidation.ALL)

    def test_array_validation_all_fail(self):
        # 123 is not a string, so ALL-element validation rejects the list.
        assert not SegmentType.ARRAY_STRING.is_valid(["hello", 123, "world"], array_validation=ArrayValidation.ALL)

    def test_array_validation_first(self):
        # FIRST only inspects the leading element; later bad values are ignored.
        assert SegmentType.ARRAY_STRING.is_valid(["hello", 123, None], array_validation=ArrayValidation.FIRST)

    def test_array_validation_none(self):
        # NONE skips element validation entirely.
        assert SegmentType.ARRAY_STRING.is_valid([1, 2, 3], array_validation=ArrayValidation.NONE)
class TestSegmentTypeGetZeroValue:
    """Zero-value construction via SegmentType.get_zero_value for every supported type."""

    def test_array_types_return_empty_list(self):
        """All array types produce an empty-list segment of the same type."""
        for seg_type in (
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_BOOLEAN,
        ):
            zero = SegmentType.get_zero_value(seg_type)
            assert zero.value == []
            assert zero.value_type == seg_type

    def test_object_returns_empty_dict(self):
        """OBJECT produces an empty-dict segment."""
        zero = SegmentType.get_zero_value(SegmentType.OBJECT)
        assert zero.value == {}
        assert zero.value_type == SegmentType.OBJECT

    def test_string_returns_empty_string(self):
        """STRING produces an empty-string segment."""
        zero = SegmentType.get_zero_value(SegmentType.STRING)
        assert zero.value == ""
        assert zero.value_type == SegmentType.STRING

    def test_integer_returns_zero(self):
        """INTEGER produces a zero segment."""
        zero = SegmentType.get_zero_value(SegmentType.INTEGER)
        assert zero.value == 0
        assert zero.value_type == SegmentType.INTEGER

    def test_float_returns_zero_point_zero(self):
        """FLOAT produces a 0.0 segment."""
        zero = SegmentType.get_zero_value(SegmentType.FLOAT)
        assert zero.value == 0.0
        assert zero.value_type == SegmentType.FLOAT

    def test_number_returns_zero(self):
        """NUMBER produces zero, realized as an INTEGER segment."""
        zero = SegmentType.get_zero_value(SegmentType.NUMBER)
        assert zero.value == 0
        # NUMBER is a union of INTEGER/FLOAT; a zero int materializes as INTEGER...
        assert zero.value_type == SegmentType.INTEGER
        # ...but is still exposed to the frontend as NUMBER.
        assert zero.value_type.exposed_type() == SegmentType.NUMBER

    def test_boolean_returns_false(self):
        """BOOLEAN produces a False segment."""
        zero = SegmentType.get_zero_value(SegmentType.BOOLEAN)
        assert zero.value is False
        assert zero.value_type == SegmentType.BOOLEAN

    def test_unsupported_types_raise_value_error(self):
        """Types without a meaningful zero value raise ValueError."""
        for seg_type in (
            SegmentType.SECRET,
            SegmentType.FILE,
            SegmentType.NONE,
            SegmentType.GROUP,
            SegmentType.ARRAY_FILE,
        ):
            with pytest.raises(ValueError, match="unsupported variable type"):
                SegmentType.get_zero_value(seg_type)

View File

@@ -0,0 +1,849 @@
"""
Comprehensive unit tests for SegmentType.is_valid and SegmentType._validate_array methods.
This module provides thorough testing of the validation logic for all SegmentType values,
including edge cases, error conditions, and different ArrayValidation strategies.
"""
from dataclasses import dataclass
from typing import Any
import pytest
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables.segment_group import SegmentGroup
from core.variables.segments import (
ArrayFileSegment,
BooleanSegment,
FileSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from core.variables.types import ArrayValidation, SegmentType
def create_test_file(
    file_type: FileType = FileType.DOCUMENT,
    transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
    filename: str = "test.txt",
    extension: str = ".txt",
    mime_type: str = "text/plain",
    size: int = 1024,
) -> File:
    """Build a File instance with sensible defaults for use in validation tests."""
    # Remote files carry a URL instead of a local related_id; local files the reverse.
    is_remote = transfer_method == FileTransferMethod.REMOTE_URL
    return File(
        tenant_id="test-tenant",
        type=file_type,
        transfer_method=transfer_method,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=size,
        related_id=None if is_remote else "test-file-id",
        remote_url="https://example.com/file.txt" if is_remote else None,
        storage_key="test-storage-key",
    )
@dataclass
class ValidationTestCase:
    """Test case data structure for validation tests."""
    segment_type: SegmentType  # the SegmentType whose is_valid() is exercised
    value: Any  # the raw value handed to is_valid()
    expected: bool  # expected return value of is_valid()
    description: str  # human-readable case name; also used as the parametrize id
    def get_id(self):
        # Expose the description as a pytest id (mirrors the ids= lambdas below).
        return self.description
@dataclass
class ArrayValidationTestCase:
    """Test case data structure for array validation tests."""
    segment_type: SegmentType  # the array SegmentType whose is_valid() is exercised
    value: Any  # the raw value handed to is_valid()
    array_validation: ArrayValidation  # element-validation strategy to apply
    expected: bool  # expected return value of is_valid()
    description: str  # human-readable case name; also used as the parametrize id
    def get_id(self):
        # Expose the description as a pytest id (mirrors the ids= lambdas below).
        return self.description
# Test data construction functions
def get_boolean_cases() -> list[ValidationTestCase]:
    """Get test cases for BOOLEAN validation (only real bools are accepted)."""
    # (value, expected, description) rows; ints, strings, None and containers are rejected.
    rows = [
        (True, True, "True boolean"),
        (False, True, "False boolean"),
        (1, False, "Integer 1 (not boolean)"),
        (0, False, "Integer 0 (not boolean)"),
        ("true", False, "String 'true'"),
        ("false", False, "String 'false'"),
        (None, False, "None value"),
        ([], False, "Empty list"),
        ({}, False, "Empty dict"),
    ]
    return [ValidationTestCase(SegmentType.BOOLEAN, value, expected, desc) for value, expected, desc in rows]
def get_number_cases() -> list[ValidationTestCase]:
    """Get test cases for NUMBER validation (ints and floats, including special floats)."""
    # (value, expected, description) rows; numeric-looking strings do not count.
    rows = [
        (42, True, "Positive integer"),
        (-42, True, "Negative integer"),
        (0, True, "Zero integer"),
        (3.14, True, "Positive float"),
        (-3.14, True, "Negative float"),
        (0.0, True, "Zero float"),
        (float("inf"), True, "Positive infinity"),
        (float("-inf"), True, "Negative infinity"),
        (float("nan"), True, "float(NaN)"),
        ("42", False, "String number"),
        (None, False, "None value"),
        ([], False, "Empty list"),
        ({}, False, "Empty dict"),
        ("3.14", False, "String float"),
    ]
    return [ValidationTestCase(SegmentType.NUMBER, value, expected, desc) for value, expected, desc in rows]
def get_string_cases() -> list[ValidationTestCase]:
    """Get test cases for STRING validation (any str, including empty/unicode/multiline)."""
    rows = [
        ("", True, "Empty string"),
        ("hello", True, "Simple string"),
        ("🚀", True, "Unicode emoji"),
        ("line1\nline2", True, "Multiline string"),
        (123, False, "Integer"),
        (3.14, False, "Float"),
        (True, False, "Boolean"),
        (None, False, "None value"),
        ([], False, "Empty list"),
        ({}, False, "Empty dict"),
    ]
    return [ValidationTestCase(SegmentType.STRING, value, expected, desc) for value, expected, desc in rows]
def get_object_cases() -> list[ValidationTestCase]:
    """Get test cases for OBJECT validation (dicts of any shape; everything else rejected)."""
    rows = [
        ({}, True, "Empty dict"),
        ({"key": "value"}, True, "Simple dict"),
        ({"a": 1, "b": 2}, True, "Dict with numbers"),
        ({"nested": {"key": "value"}}, True, "Nested dict"),
        ({"list": [1, 2, 3]}, True, "Dict with list"),
        ({"mixed": [1, "two", {"three": 3}]}, True, "Complex dict"),
        ("not a dict", False, "String"),
        (123, False, "Integer"),
        (3.14, False, "Float"),
        (True, False, "Boolean"),
        (None, False, "None value"),
        ([], False, "Empty list"),
        ([1, 2, 3], False, "List with values"),
    ]
    return [ValidationTestCase(SegmentType.OBJECT, value, expected, desc) for value, expected, desc in rows]
def get_secret_cases() -> list[ValidationTestCase]:
    """Get test cases for SECRET validation (any string; non-strings rejected)."""
    rows = [
        ("", True, "Empty secret"),
        ("secret", True, "Simple secret"),
        ("api_key_123", True, "API key format"),
        ("very_long_secret_key_with_special_chars!@#", True, "Complex secret"),
        (123, False, "Integer"),
        (3.14, False, "Float"),
        (True, False, "Boolean"),
        (None, False, "None value"),
        ([], False, "Empty list"),
        ({}, False, "Empty dict"),
    ]
    return [ValidationTestCase(SegmentType.SECRET, value, expected, desc) for value, expected, desc in rows]
def get_file_cases() -> list[ValidationTestCase]:
    """Get test cases for FILE validation (only File instances are accepted)."""
    document = create_test_file()
    image = create_test_file(file_type=FileType.IMAGE, filename="image.jpg", extension=".jpg", mime_type="image/jpeg")
    remote = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="remote.pdf", extension=".pdf")
    rows = [
        (document, True, "Document file"),
        (image, True, "Image file"),
        (remote, True, "Remote file"),
        # A dict shaped like a file is still not a File instance.
        ("not a file", False, "String"),
        (123, False, "Integer"),
        ({"filename": "test.txt"}, False, "Dict resembling file"),
        (None, False, "None value"),
        ([], False, "Empty list"),
        (True, False, "Boolean"),
    ]
    return [ValidationTestCase(SegmentType.FILE, value, expected, desc) for value, expected, desc in rows]
def get_none_cases() -> list[ValidationTestCase]:
    """Get test cases for NONE validation (only the literal None passes; falsy stand-ins do not)."""
    rows = [
        (None, True, "None value"),
        ("", False, "Empty string"),
        (0, False, "Zero integer"),
        (0.0, False, "Zero float"),
        (False, False, "False boolean"),
        ([], False, "Empty list"),
        ({}, False, "Empty dict"),
        ("null", False, "String 'null'"),
    ]
    return [ValidationTestCase(SegmentType.NONE, value, expected, desc) for value, expected, desc in rows]
def get_group_cases() -> list[ValidationTestCase]:
    """Get test cases for GROUP validation (SegmentGroup instances or lists of Segment objects)."""
    test_file = create_test_file()
    mixed_segments = [
        StringSegment(value="hello"),
        IntegerSegment(value=42),
        BooleanSegment(value=True),
        ObjectSegment(value={"key": "value"}),
        FileSegment(value=test_file),
        NoneSegment(value=None),
    ]
    rows = [
        (SegmentGroup(value=mixed_segments), True, "Valid SegmentGroup with mixed segments"),
        ([StringSegment(value="test"), IntegerSegment(value=123)], True, "List of Segment objects"),
        (SegmentGroup(value=[]), True, "Empty SegmentGroup"),
        ([], True, "Empty list"),
        ("not a list", False, "String value"),
        (123, False, "Integer value"),
        (True, False, "Boolean value"),
        (None, False, "None value"),
        ({"key": "value"}, False, "Dict value"),
        (test_file, False, "File value"),
        (["string", 123, True], False, "List with non-Segment objects"),
        ([StringSegment(value="test"), "not a segment"], False, "Mixed list with some non-Segment objects"),
    ]
    return [ValidationTestCase(SegmentType.GROUP, value, expected, desc) for value, expected, desc in rows]
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
    """Get test cases for ARRAY_ANY validation (any list passes, whatever the strategy)."""
    # The same heterogeneous payload is accepted under every strategy.
    mixed = [1, "string", 3.14, {"key": "value"}, True]
    rows = [
        (mixed, ArrayValidation.NONE, True, "Mixed types with NONE validation"),
        (mixed, ArrayValidation.FIRST, True, "Mixed types with FIRST validation"),
        (mixed, ArrayValidation.ALL, True, "Mixed types with ALL validation"),
        ([None, None, None], ArrayValidation.ALL, True, "All None values"),
    ]
    return [
        ArrayValidationTestCase(SegmentType.ARRAY_ANY, value, strategy, expected, desc)
        for value, strategy, expected, desc in rows
    ]
def get_array_string_validation_none_cases() -> list[ArrayValidationTestCase]:
    """Get test cases for ARRAY_STRING with the NONE strategy (elements never inspected)."""
    rows = [
        (["hello", "world"], ArrayValidation.NONE, True, "Valid strings with NONE validation"),
        ([123, 456], ArrayValidation.NONE, True, "Invalid elements with NONE validation"),
        (["valid", 123, True], ArrayValidation.NONE, True, "Mixed types with NONE validation"),
    ]
    return [
        ArrayValidationTestCase(SegmentType.ARRAY_STRING, value, strategy, expected, desc)
        for value, strategy, expected, desc in rows
    ]
def get_array_string_validation_first_cases() -> list[ArrayValidationTestCase]:
    """Get test cases for ARRAY_STRING with the FIRST strategy (only element 0 is checked)."""
    rows = [
        (["hello", "world"], ArrayValidation.FIRST, True, "All valid strings"),
        (["hello", 123, True], ArrayValidation.FIRST, True, "First valid, others invalid"),
        ([123, "hello", "world"], ArrayValidation.FIRST, False, "First invalid, others valid"),
        ([None, "hello"], ArrayValidation.FIRST, False, "First None"),
    ]
    return [
        ArrayValidationTestCase(SegmentType.ARRAY_STRING, value, strategy, expected, desc)
        for value, strategy, expected, desc in rows
    ]
def get_array_string_validation_all_cases() -> list[ArrayValidationTestCase]:
    """Get test cases for ARRAY_STRING with the ALL strategy (every element must be a str)."""
    rows = [
        (["hello", "world", "test"], ArrayValidation.ALL, True, "All valid strings"),
        (["hello", 123, "world"], ArrayValidation.ALL, False, "One invalid element"),
        ([123, 456, 789], ArrayValidation.ALL, False, "All invalid elements"),
        (["valid", None, "also_valid"], ArrayValidation.ALL, False, "Contains None"),
    ]
    return [
        ArrayValidationTestCase(SegmentType.ARRAY_STRING, value, strategy, expected, desc)
        for value, strategy, expected, desc in rows
    ]
def get_array_number_validation_cases() -> list[ArrayValidationTestCase]:
    """Get test cases for ARRAY_NUMBER across all three validation strategies."""
    rows = [
        # NONE: element content is irrelevant.
        ([1, 2.5, 3], ArrayValidation.NONE, True, "Valid numbers with NONE"),
        (["not", "numbers"], ArrayValidation.NONE, True, "Invalid elements with NONE"),
        # FIRST: only element 0 decides the outcome.
        ([42, "not a number"], ArrayValidation.FIRST, True, "First valid number"),
        (["not a number", 42], ArrayValidation.FIRST, False, "First invalid"),
        ([3.14, 2.71, 1.41], ArrayValidation.FIRST, True, "All valid floats"),
        # ALL: every element must be numeric; special floats still count as numbers.
        ([1, 2, 3, 4.5], ArrayValidation.ALL, True, "All valid numbers"),
        ([1, "invalid", 3], ArrayValidation.ALL, False, "One invalid element"),
        ([float("inf"), float("-inf"), float("nan")], ArrayValidation.ALL, True, "Special float values"),
    ]
    return [
        ArrayValidationTestCase(SegmentType.ARRAY_NUMBER, value, strategy, expected, desc)
        for value, strategy, expected, desc in rows
    ]
def get_array_object_validation_cases() -> list[ArrayValidationTestCase]:
    """Get test cases for ARRAY_OBJECT across all three validation strategies."""
    rows = [
        # NONE: element content is irrelevant.
        ([{}, {"key": "value"}], ArrayValidation.NONE, True, "Valid objects with NONE"),
        (["not", "objects"], ArrayValidation.NONE, True, "Invalid elements with NONE"),
        # FIRST: only element 0 decides the outcome.
        ([{"valid": "object"}, "not an object"], ArrayValidation.FIRST, True, "First valid object"),
        (["not an object", {"valid": "object"}], ArrayValidation.FIRST, False, "First invalid"),
        # ALL: every element must be a dict.
        ([{}, {"a": 1}, {"nested": {"key": "value"}}], ArrayValidation.ALL, True, "All valid objects"),
        ([{"valid": "object"}, "invalid", {"another": "object"}], ArrayValidation.ALL, False, "One invalid element"),
    ]
    return [
        ArrayValidationTestCase(SegmentType.ARRAY_OBJECT, value, strategy, expected, desc)
        for value, strategy, expected, desc in rows
    ]
def get_array_file_validation_cases() -> list[ArrayValidationTestCase]:
    """Get test cases for ARRAY_FILE across all three validation strategies."""
    file1 = create_test_file(filename="file1.txt")
    file2 = create_test_file(filename="file2.txt")
    rows = [
        # NONE: element content is irrelevant.
        ([file1, file2], ArrayValidation.NONE, True, "Valid files with NONE"),
        (["not", "files"], ArrayValidation.NONE, True, "Invalid elements with NONE"),
        # FIRST: only element 0 decides the outcome.
        ([file1, "not a file"], ArrayValidation.FIRST, True, "First valid file"),
        (["not a file", file1], ArrayValidation.FIRST, False, "First invalid"),
        # ALL: every element must be a File instance.
        ([file1, file2], ArrayValidation.ALL, True, "All valid files"),
        ([file1, "invalid", file2], ArrayValidation.ALL, False, "One invalid element"),
    ]
    return [
        ArrayValidationTestCase(SegmentType.ARRAY_FILE, value, strategy, expected, desc)
        for value, strategy, expected, desc in rows
    ]
def get_array_boolean_validation_cases() -> list[ArrayValidationTestCase]:
    """Get test cases for ARRAY_BOOLEAN across all three validation strategies."""
    rows = [
        # NONE: element content is irrelevant.
        ([True, False, True], ArrayValidation.NONE, True, "Valid booleans with NONE"),
        ([1, 0, "true"], ArrayValidation.NONE, True, "Invalid elements with NONE"),
        # FIRST: only element 0 decides the outcome; ints 0/1 are not booleans.
        ([True, 1, 0], ArrayValidation.FIRST, True, "First valid boolean"),
        ([1, True, False], ArrayValidation.FIRST, False, "First invalid (integer 1)"),
        ([0, True, False], ArrayValidation.FIRST, False, "First invalid (integer 0)"),
        # ALL: every element must be a real bool.
        ([True, False, True, False], ArrayValidation.ALL, True, "All valid booleans"),
        ([True, 1, False], ArrayValidation.ALL, False, "One invalid element (integer)"),
        ([True, "false", False], ArrayValidation.ALL, False, "One invalid element (string)"),
    ]
    return [
        ArrayValidationTestCase(SegmentType.ARRAY_BOOLEAN, value, strategy, expected, desc)
        for value, strategy, expected, desc in rows
    ]
class TestSegmentTypeIsValid:
    """Test suite for SegmentType.is_valid method covering all non-array types.

    Each parametrized test drives is_valid() with the rows produced by the
    corresponding get_*_cases() factory and compares against the expected flag.
    """
    @pytest.mark.parametrize("case", get_boolean_cases(), ids=lambda case: case.description)
    def test_boolean_validation(self, case: ValidationTestCase):
        """BOOLEAN accepts only real bools; truthy/falsy stand-ins are rejected."""
        assert case.segment_type.is_valid(case.value) == case.expected
    @pytest.mark.parametrize("case", get_number_cases(), ids=lambda case: case.description)
    def test_number_validation(self, case: ValidationTestCase):
        """NUMBER accepts ints and floats (incl. inf/NaN); numeric strings are rejected."""
        assert case.segment_type.is_valid(case.value) == case.expected
    @pytest.mark.parametrize("case", get_string_cases(), ids=lambda case: case.description)
    def test_string_validation(self, case: ValidationTestCase):
        """STRING accepts any str; all non-str values are rejected."""
        assert case.segment_type.is_valid(case.value) == case.expected
    @pytest.mark.parametrize("case", get_object_cases(), ids=lambda case: case.description)
    def test_object_validation(self, case: ValidationTestCase):
        """OBJECT accepts dicts of any shape; non-dict values are rejected."""
        assert case.segment_type.is_valid(case.value) == case.expected
    @pytest.mark.parametrize("case", get_secret_cases(), ids=lambda case: case.description)
    def test_secret_validation(self, case: ValidationTestCase):
        """SECRET accepts any str; non-str values are rejected."""
        assert case.segment_type.is_valid(case.value) == case.expected
    @pytest.mark.parametrize("case", get_file_cases(), ids=lambda case: case.description)
    def test_file_validation(self, case: ValidationTestCase):
        """FILE accepts only File instances; file-shaped dicts are rejected."""
        assert case.segment_type.is_valid(case.value) == case.expected
    @pytest.mark.parametrize("case", get_none_cases(), ids=lambda case: case.description)
    def test_none_validation_valid_cases(self, case: ValidationTestCase):
        """NONE accepts only the literal None; other falsy values are rejected."""
        assert case.segment_type.is_valid(case.value) == case.expected
    @pytest.mark.parametrize("case", get_group_cases(), ids=lambda case: case.description)
    def test_group_validation(self, case: ValidationTestCase):
        """Test GROUP type validation with various inputs."""
        assert case.segment_type.is_valid(case.value) == case.expected
    def test_group_validation_edge_cases(self):
        """Test GROUP validation edge cases."""
        test_file = create_test_file()
        # Test with nested SegmentGroups (a SegmentGroup is itself a Segment)
        inner_group = SegmentGroup(value=[StringSegment(value="inner"), IntegerSegment(value=42)])
        outer_group = SegmentGroup(value=[StringSegment(value="outer"), inner_group])
        assert SegmentType.GROUP.is_valid(outer_group) is True
        # Test with ArrayFileSegment (which is also a Segment)
        file_segment = FileSegment(value=test_file)
        array_file_segment = ArrayFileSegment(value=[test_file, test_file])
        group_with_arrays = SegmentGroup(value=[file_segment, array_file_segment, StringSegment(value="test")])
        assert SegmentType.GROUP.is_valid(group_with_arrays) is True
        # Test performance with large number of segments (smoke check, not a benchmark)
        large_segment_list = [StringSegment(value=f"item_{i}") for i in range(1000)]
        large_group = SegmentGroup(value=large_segment_list)
        assert SegmentType.GROUP.is_valid(large_group) is True
    def test_no_truly_unsupported_segment_types_exist(self):
        """Test that all SegmentType enum values are properly handled in is_valid method.
        This test ensures there are no SegmentType values that would raise AssertionError.
        If this test fails, it means a new SegmentType was added without proper validation support.
        """
        # Test that ALL segment types are handled and don't raise AssertionError
        all_segment_types = set(SegmentType)
        for segment_type in all_segment_types:
            # Create a valid test value for each type.
            # NOTE: branch order matters — NUMBER/INTEGER share a branch and the
            # generic is_array_type() check must come after all specific types.
            test_value: Any = None
            if segment_type == SegmentType.STRING:
                test_value = "test"
            elif segment_type in {SegmentType.NUMBER, SegmentType.INTEGER}:
                test_value = 42
            elif segment_type == SegmentType.FLOAT:
                test_value = 3.14
            elif segment_type == SegmentType.BOOLEAN:
                test_value = True
            elif segment_type == SegmentType.OBJECT:
                test_value = {"key": "value"}
            elif segment_type == SegmentType.SECRET:
                test_value = "secret"
            elif segment_type == SegmentType.FILE:
                test_value = create_test_file()
            elif segment_type == SegmentType.NONE:
                test_value = None
            elif segment_type == SegmentType.GROUP:
                test_value = SegmentGroup(value=[StringSegment(value="test")])
            elif segment_type.is_array_type():
                test_value = []  # Empty array is valid for all array types
            else:
                # If we get here, there's a segment type we don't know how to test
                # This should prompt us to add validation logic
                pytest.fail(f"Unknown segment type {segment_type} needs validation logic and test case")
            # This should NOT raise AssertionError
            try:
                result = segment_type.is_valid(test_value)
                assert isinstance(result, bool), f"is_valid should return boolean for {segment_type}"
            except AssertionError as e:
                pytest.fail(
                    f"SegmentType.{segment_type.name}.is_valid() raised AssertionError: {e}. "
                    "This segment type needs to be handled in the is_valid method."
                )
class TestSegmentTypeArrayValidation:
    """Test suite for SegmentType._validate_array method and array type validation."""
    def test_array_validation_non_list_values(self):
        """Test that non-list values return False for all array types."""
        array_types = [
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
            SegmentType.ARRAY_BOOLEAN,
        ]
        # Scalars, None, dicts and File objects must all be rejected before
        # element-level validation is even considered.
        non_list_values = [
            "not a list",
            123,
            3.14,
            True,
            None,
            {"key": "value"},
            create_test_file(),
        ]
        for array_type in array_types:
            for value in non_list_values:
                assert array_type.is_valid(value) is False, f"{array_type} should reject {type(value).__name__}"
    def test_empty_array_validation(self):
        """Test that empty arrays are valid for all array types regardless of validation strategy."""
        array_types = [
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
            SegmentType.ARRAY_BOOLEAN,
        ]
        validation_strategies = [ArrayValidation.NONE, ArrayValidation.FIRST, ArrayValidation.ALL]
        for array_type in array_types:
            for strategy in validation_strategies:
                assert array_type.is_valid([], strategy) is True, (
                    f"{array_type} should accept empty array with {strategy}"
                )
    @pytest.mark.parametrize("case", get_array_any_validation_cases(), ids=lambda case: case.description)
    def test_array_any_validation(self, case: ArrayValidationTestCase):
        """Test ARRAY_ANY validation accepts any list regardless of content."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
    @pytest.mark.parametrize("case", get_array_string_validation_none_cases(), ids=lambda case: case.description)
    def test_array_string_validation_with_none_strategy(self, case: ArrayValidationTestCase):
        """Test ARRAY_STRING validation with NONE strategy (no element validation)."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
    @pytest.mark.parametrize("case", get_array_string_validation_first_cases(), ids=lambda case: case.description)
    def test_array_string_validation_with_first_strategy(self, case: ArrayValidationTestCase):
        """Test ARRAY_STRING validation with FIRST strategy (validate first element only)."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
    @pytest.mark.parametrize("case", get_array_string_validation_all_cases(), ids=lambda case: case.description)
    def test_array_string_validation_with_all_strategy(self, case: ArrayValidationTestCase):
        """Test ARRAY_STRING validation with ALL strategy (validate all elements)."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
    @pytest.mark.parametrize("case", get_array_number_validation_cases(), ids=lambda case: case.description)
    def test_array_number_validation_with_different_strategies(self, case: ArrayValidationTestCase):
        """Test ARRAY_NUMBER validation with different validation strategies."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
    @pytest.mark.parametrize("case", get_array_object_validation_cases(), ids=lambda case: case.description)
    def test_array_object_validation_with_different_strategies(self, case: ArrayValidationTestCase):
        """Test ARRAY_OBJECT validation with different validation strategies."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
    @pytest.mark.parametrize("case", get_array_file_validation_cases(), ids=lambda case: case.description)
    def test_array_file_validation_with_different_strategies(self, case: ArrayValidationTestCase):
        """Test ARRAY_FILE validation with different validation strategies."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
    @pytest.mark.parametrize("case", get_array_boolean_validation_cases(), ids=lambda case: case.description)
    def test_array_boolean_validation_with_different_strategies(self, case: ArrayValidationTestCase):
        """Test ARRAY_BOOLEAN validation with different validation strategies."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected
    def test_default_array_validation_strategy(self):
        """Test the default array validation behavior when no strategy is passed.

        NOTE(review): this test previously claimed the default is FIRST, but the
        assertions below expect lists with a VALID first element to fail, which
        matches ALL semantics — confirm against the default in is_valid's signature.
        """
        # When no array_validation parameter is provided
        assert SegmentType.ARRAY_STRING.is_valid(["valid", 123]) is False  # first valid, later element not
        assert SegmentType.ARRAY_STRING.is_valid([123, "valid"]) is False  # first element invalid
        assert SegmentType.ARRAY_NUMBER.is_valid([42, "invalid"]) is False  # first valid, later element not
        assert SegmentType.ARRAY_NUMBER.is_valid(["invalid", 42]) is False  # first element invalid
    def test_array_validation_edge_cases(self):
        """Test edge cases for array validation."""
        # Test with nested arrays (should be invalid for specific array types)
        nested_array = [["nested", "array"], ["another", "nested"]]
        assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.FIRST) is False
        assert SegmentType.ARRAY_STRING.is_valid(nested_array, ArrayValidation.ALL) is False
        assert SegmentType.ARRAY_ANY.is_valid(nested_array, ArrayValidation.ALL) is True
        # Test with very large arrays (performance consideration)
        large_valid_array = ["string"] * 1000
        large_mixed_array = ["string"] * 999 + [123]  # Last element invalid
        assert SegmentType.ARRAY_STRING.is_valid(large_valid_array, ArrayValidation.ALL) is True
        assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.ALL) is False
        # FIRST only looks at element 0, so the trailing invalid element is not seen
        assert SegmentType.ARRAY_STRING.is_valid(large_mixed_array, ArrayValidation.FIRST) is True
class TestSegmentTypeValidationIntegration:
"""Integration tests for SegmentType validation covering interactions between methods."""
def test_non_array_types_ignore_array_validation_parameter(self):
"""Test that non-array types ignore the array_validation parameter."""
non_array_types = [
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
SegmentType.GROUP,
]
for segment_type in non_array_types:
# Create appropriate valid value for each type
valid_value: Any
if segment_type == SegmentType.STRING:
valid_value = "test"
elif segment_type == SegmentType.NUMBER:
valid_value = 42
elif segment_type == SegmentType.BOOLEAN:
valid_value = True
elif segment_type == SegmentType.OBJECT:
valid_value = {"key": "value"}
elif segment_type == SegmentType.SECRET:
valid_value = "secret"
elif segment_type == SegmentType.FILE:
valid_value = create_test_file()
elif segment_type == SegmentType.NONE:
valid_value = None
elif segment_type == SegmentType.GROUP:
valid_value = SegmentGroup(value=[StringSegment(value="test")])
else:
continue # Skip unsupported types
# All array validation strategies should give the same result
result_none = segment_type.is_valid(valid_value, ArrayValidation.NONE)
result_first = segment_type.is_valid(valid_value, ArrayValidation.FIRST)
result_all = segment_type.is_valid(valid_value, ArrayValidation.ALL)
assert result_none == result_first == result_all == True, (
f"{segment_type} should ignore array_validation parameter"
)
def test_comprehensive_type_coverage(self):
"""Test that all SegmentType enum values are covered in validation tests."""
all_segment_types = set(SegmentType)
# Types that should be handled by is_valid method
handled_types = {
# Non-array types
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
SegmentType.GROUP,
# Array types
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
SegmentType.ARRAY_BOOLEAN,
}
# Types that are not handled by is_valid (should raise AssertionError)
unhandled_types = {
SegmentType.INTEGER, # Handled by NUMBER validation logic
SegmentType.FLOAT, # Handled by NUMBER validation logic
}
# Verify all types are accounted for
assert handled_types | unhandled_types == all_segment_types, "All SegmentType values should be categorized"
# Test that handled types work correctly
for segment_type in handled_types:
if segment_type.is_array_type():
# Test with empty array (should always be valid)
assert segment_type.is_valid([]) is True, f"{segment_type} should accept empty array"
else:
# Test with appropriate valid value
if segment_type == SegmentType.STRING:
assert segment_type.is_valid("test") is True
elif segment_type == SegmentType.NUMBER:
assert segment_type.is_valid(42) is True
elif segment_type == SegmentType.BOOLEAN:
assert segment_type.is_valid(True) is True
elif segment_type == SegmentType.OBJECT:
assert segment_type.is_valid({}) is True
elif segment_type == SegmentType.SECRET:
assert segment_type.is_valid("secret") is True
elif segment_type == SegmentType.FILE:
assert segment_type.is_valid(create_test_file()) is True
elif segment_type == SegmentType.NONE:
assert segment_type.is_valid(None) is True
elif segment_type == SegmentType.GROUP:
assert segment_type.is_valid(SegmentGroup(value=[StringSegment(value="test")])) is True
    def test_boolean_vs_integer_type_distinction(self):
        """Test the important distinction between boolean and integer types in validation.

        In Python `bool` is a subclass of `int`, so naive `isinstance` checks
        conflate the two; these assertions pin the intended asymmetry.
        """
        # Boolean type should only accept actual booleans, not integers
        assert SegmentType.BOOLEAN.is_valid(True) is True
        assert SegmentType.BOOLEAN.is_valid(False) is True
        assert SegmentType.BOOLEAN.is_valid(1) is False  # Integer 1, not boolean
        assert SegmentType.BOOLEAN.is_valid(0) is False  # Integer 0, not boolean
        # Number type should accept both integers and floats, including booleans (since bool is subclass of int)
        assert SegmentType.NUMBER.is_valid(42) is True
        assert SegmentType.NUMBER.is_valid(3.14) is True
        assert SegmentType.NUMBER.is_valid(True) is True  # bool is subclass of int
        assert SegmentType.NUMBER.is_valid(False) is True  # bool is subclass of int
    def test_array_validation_recursive_behavior(self):
        """Test that array validation correctly handles recursive validation calls.

        When validating array elements, _validate_array calls is_valid
        recursively with ArrayValidation.NONE to avoid infinite recursion.
        """
        # Nested lists exercise the element-level recursion path.
        nested_arrays = [["inner", "array"], ["another", "inner"]]
        # ARRAY_ANY should accept nested arrays (elements are unconstrained)
        assert SegmentType.ARRAY_ANY.is_valid(nested_arrays, ArrayValidation.ALL) is True
        # ARRAY_STRING should reject nested arrays (first element is not a string)
        assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.FIRST) is False
        assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.ALL) is False

View File

@@ -0,0 +1,91 @@
import pytest
from pydantic import ValidationError
from core.variables import (
ArrayFileVariable,
ArrayVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
SecretVariable,
SegmentType,
StringVariable,
)
from core.variables.variables import Variable
def test_frozen_variables():
    """Variable models are frozen: assigning to ``value`` must raise ValidationError."""
    frozen_cases = [
        (StringVariable(name="text", value="text"), "new value"),
        (IntegerVariable(name="integer", value=42), 100),
        (FloatVariable(name="float", value=3.14), 2.718),
        (SecretVariable(name="secret", value="secret_value"), "new_secret_value"),
    ]
    for variable, attempted_value in frozen_cases:
        with pytest.raises(ValidationError):
            variable.value = attempted_value
def test_variable_value_type_immutable():
    """Each Variable subclass pins its ``value_type``; overriding it must fail validation."""
    # Explicit keyword override is rejected at construction time.
    with pytest.raises(ValidationError):
        StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text")
    # The same holds when validating from a raw mapping.
    with pytest.raises(ValidationError):
        StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"})
    # Remaining scalar variable classes behave identically.
    for variable_cls, name, value in [
        (IntegerVariable, "integer", 42),
        (FloatVariable, "float", 3.14),
        (SecretVariable, "secret", "secret_value"),
    ]:
        var = variable_cls(name=name, value=value)
        with pytest.raises(ValidationError):
            variable_cls(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
def test_object_variable_to_object():
    """ObjectVariable.to_object should return the nested mapping unchanged."""
    payload = {
        "key1": {
            "key2": "value2",
        },
        "key2": ["value5_1", 42, {}],
    }
    var = ObjectVariable(name="object", value=payload)
    # Round-trip: nested dicts and heterogeneous lists survive intact.
    assert var.to_object() == payload
def test_variable_to_object():
    """Scalar variables expose their raw Python value through to_object()."""
    cases: list[tuple[Variable, object]] = [
        (StringVariable(name="text", value="text"), "text"),
        (IntegerVariable(name="integer", value=42), 42),
        (FloatVariable(name="float", value=3.14), 3.14),
        (SecretVariable(name="secret", value="secret_value"), "secret_value"),
    ]
    for variable, expected in cases:
        assert variable.to_object() == expected
def test_array_file_variable_is_array_variable():
    """ArrayFileVariable must participate in the ArrayVariable hierarchy."""
    var = ArrayFileVariable(name="files", value=[])
    assert isinstance(var, ArrayVariable)

View File

@@ -0,0 +1,281 @@
import json
from time import time
from unittest.mock import MagicMock, patch
import pytest
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
class StubCoordinator:
    """Minimal stand-in for a response coordinator supporting dumps/loads round-trips."""

    def __init__(self) -> None:
        # Mutable marker the serialization tests use to prove state survives.
        self.state = "initial"

    def dumps(self) -> str:
        """Serialize the coordinator state to a JSON string."""
        return json.dumps({"state": self.state})

    def loads(self, data: str) -> None:
        """Restore state from a JSON string previously produced by dumps()."""
        self.state = json.loads(data)["state"]
class TestGraphRuntimeState:
    """Unit tests for GraphRuntimeState: properties, immutability of exposed
    collections, validation, lazily-created sub-objects, and dumps()/loads()
    snapshot round-trips (including the response coordinator hook)."""

    def test_property_getters_and_setters(self):
        """Basic read/write behavior of start_at, total_tokens and node_run_steps."""
        # FIXME(-LAN-): Mock VariablePool if needed
        variable_pool = VariablePool()
        start_time = time()
        state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time)
        # Test variable_pool property (read-only)
        assert state.variable_pool == variable_pool
        # Test start_at property
        assert state.start_at == start_time
        new_time = time() + 100
        state.start_at = new_time
        assert state.start_at == new_time
        # Test total_tokens property
        assert state.total_tokens == 0
        state.total_tokens = 100
        assert state.total_tokens == 100
        # Test node_run_steps property
        assert state.node_run_steps == 0
        state.node_run_steps = 5
        assert state.node_run_steps == 5

    def test_outputs_immutability(self):
        """outputs must be returned as defensive copies; mutation goes through setters."""
        variable_pool = VariablePool()
        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
        # Test that getting outputs returns a copy
        outputs1 = state.outputs
        outputs2 = state.outputs
        assert outputs1 == outputs2
        assert outputs1 is not outputs2  # Different objects
        # Test that modifying retrieved outputs doesn't affect internal state
        outputs = state.outputs
        outputs["test"] = "value"
        assert "test" not in state.outputs
        # Test set_output method
        state.set_output("key1", "value1")
        assert state.get_output("key1") == "value1"
        # Test update_outputs method
        state.update_outputs({"key2": "value2", "key3": "value3"})
        assert state.get_output("key2") == "value2"
        assert state.get_output("key3") == "value3"

    def test_llm_usage_immutability(self):
        """llm_usage accessor must return a fresh copy on each read."""
        variable_pool = VariablePool()
        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
        # Test that getting llm_usage returns a copy
        usage1 = state.llm_usage
        usage2 = state.llm_usage
        assert usage1 is not usage2  # Different objects

    def test_type_validation(self):
        """Negative counters are rejected with ValueError."""
        variable_pool = VariablePool()
        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
        # Test total_tokens validation
        with pytest.raises(ValueError):
            state.total_tokens = -1
        # Test node_run_steps validation
        with pytest.raises(ValueError):
            state.node_run_steps = -1

    def test_helper_methods(self):
        """increment_node_run_steps/add_tokens mutate counters; add_tokens rejects negatives."""
        variable_pool = VariablePool()
        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
        # Test increment_node_run_steps
        initial_steps = state.node_run_steps
        state.increment_node_run_steps()
        assert state.node_run_steps == initial_steps + 1
        # Test add_tokens
        initial_tokens = state.total_tokens
        state.add_tokens(50)
        assert state.total_tokens == initial_tokens + 50
        # Test add_tokens validation
        with pytest.raises(ValueError):
            state.add_tokens(-1)

    def test_ready_queue_default_instantiation(self):
        """ready_queue defaults to an InMemoryReadyQueue and is cached per state."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
        queue = state.ready_queue
        from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue

        assert isinstance(queue, InMemoryReadyQueue)
        # Repeated access returns the same instance (lazy singleton per state).
        assert state.ready_queue is queue

    def test_graph_execution_lazy_instantiation(self):
        """graph_execution is created lazily with an empty workflow_id and cached."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
        execution = state.graph_execution
        from core.workflow.graph_engine.domain.graph_execution import GraphExecution

        assert isinstance(execution, GraphExecution)
        assert execution.workflow_id == ""
        assert state.graph_execution is execution

    def test_response_coordinator_configuration(self):
        """response_coordinator requires configure(graph=...); re-attaching a
        different graph is rejected."""
        variable_pool = VariablePool()
        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
        # Accessing before configuration must fail loudly.
        with pytest.raises(ValueError):
            _ = state.response_coordinator
        mock_graph = MagicMock()
        with patch("core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator") as coordinator_cls:
            coordinator_instance = MagicMock()
            coordinator_cls.return_value = coordinator_instance
            state.configure(graph=mock_graph)
            assert state.response_coordinator is coordinator_instance
            coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph)
            # Configure again with same graph should be idempotent
            state.configure(graph=mock_graph)
        other_graph = MagicMock()
        with pytest.raises(ValueError):
            state.attach_graph(other_graph)

    def test_read_only_wrapper_exposes_additional_state(self):
        """The read-only wrapper surfaces queue size and exception count."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
        state.configure()
        wrapper = ReadOnlyGraphRuntimeStateWrapper(state)
        assert wrapper.ready_queue_size == 0
        assert wrapper.exceptions_count == 0

    def test_read_only_wrapper_serializes_runtime_state(self):
        """wrapper.dumps() must produce the same snapshot as state.dumps()."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
        state.total_tokens = 5
        state.set_output("result", {"success": True})
        state.ready_queue.put("node-1")
        wrapper = ReadOnlyGraphRuntimeStateWrapper(state)
        wrapper_snapshot = json.loads(wrapper.dumps())
        state_snapshot = json.loads(state.dumps())
        assert wrapper_snapshot == state_snapshot

    def test_dumps_and_loads_roundtrip_with_response_coordinator(self):
        """Full snapshot round-trip: counters, outputs, usage, queue, variable
        pool, graph execution, and the response coordinator's own state."""
        variable_pool = VariablePool()
        variable_pool.add(("node1", "value"), "payload")
        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
        state.total_tokens = 10
        state.node_run_steps = 3
        state.set_output("final", {"result": True})
        usage = LLMUsage.from_metadata(
            {
                "prompt_tokens": 2,
                "completion_tokens": 3,
                "total_tokens": 5,
                "total_price": "1.23",
                "currency": "USD",
                "latency": 0.5,
            }
        )
        state.llm_usage = usage
        state.ready_queue.put("node-A")
        graph_execution = state.graph_execution
        graph_execution.workflow_id = "wf-123"
        graph_execution.exceptions_count = 4
        graph_execution.started = True
        mock_graph = MagicMock()
        stub = StubCoordinator()
        # Swap in the stub so coordinator state is observable in the snapshot.
        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
            state.attach_graph(mock_graph)
            stub.state = "configured"
            snapshot = state.dumps()
        restored = GraphRuntimeState.from_snapshot(snapshot)
        assert restored.total_tokens == 10
        assert restored.node_run_steps == 3
        assert restored.get_output("final") == {"result": True}
        assert restored.llm_usage.total_tokens == usage.total_tokens
        assert restored.ready_queue.qsize() == 1
        assert restored.ready_queue.get(timeout=0.01) == "node-A"
        restored_segment = restored.variable_pool.get(("node1", "value"))
        assert restored_segment is not None
        assert restored_segment.value == "payload"
        restored_execution = restored.graph_execution
        assert restored_execution.workflow_id == "wf-123"
        assert restored_execution.exceptions_count == 4
        assert restored_execution.started is True
        new_stub = StubCoordinator()
        # Attaching a graph on the restored state rehydrates coordinator state.
        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
            restored.attach_graph(mock_graph)
        assert new_stub.state == "configured"

    def test_loads_rehydrates_existing_instance(self):
        """loads() on an already-constructed state must overwrite its contents
        from the snapshot, including the attached coordinator's state."""
        variable_pool = VariablePool()
        variable_pool.add(("node", "key"), "value")
        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
        state.total_tokens = 7
        state.node_run_steps = 2
        state.set_output("foo", "bar")
        state.ready_queue.put("node-1")
        execution = state.graph_execution
        execution.workflow_id = "wf-456"
        execution.started = True
        mock_graph = MagicMock()
        original_stub = StubCoordinator()
        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub):
            state.attach_graph(mock_graph)
            original_stub.state = "configured"
            snapshot = state.dumps()
        new_stub = StubCoordinator()
        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
            # Fresh, empty state deliberately starts at 0.0 / empty pool.
            restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
            restored.attach_graph(mock_graph)
            restored.loads(snapshot)
        assert restored.total_tokens == 7
        assert restored.node_run_steps == 2
        assert restored.get_output("foo") == "bar"
        assert restored.ready_queue.qsize() == 1
        assert restored.ready_queue.get(timeout=0.01) == "node-1"
        restored_segment = restored.variable_pool.get(("node", "key"))
        assert restored_segment is not None
        assert restored_segment.value == "value"
        restored_execution = restored.graph_execution
        assert restored_execution.workflow_id == "wf-456"
        assert restored_execution.started is True
        assert new_stub.state == "configured"

View File

@@ -0,0 +1,171 @@
"""Tests for _PrivateWorkflowPauseEntity implementation."""
from datetime import datetime
from unittest.mock import MagicMock, patch
from models.workflow import WorkflowPause as WorkflowPauseModel
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
class TestPrivateWorkflowPauseEntity:
    """Test _PrivateWorkflowPauseEntity implementation.

    Covers construction, model-backed properties, and the lazy/cached
    state-loading path backed by the storage module.
    """

    def test_entity_initialization(self):
        """Test entity initialization with required parameters."""
        # Create mock models
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.id = "pause-123"
        mock_pause_model.workflow_run_id = "execution-456"
        mock_pause_model.resumed_at = None
        # Create entity
        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )
        # Verify initialization: model stored by reference, state not yet loaded
        assert entity._pause_model is mock_pause_model
        assert entity._cached_state is None

    def test_from_models_classmethod(self):
        """Test from_models class method."""
        # Create mock models
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.id = "pause-123"
        mock_pause_model.workflow_run_id = "execution-456"
        # Create entity using from_models
        entity = _PrivateWorkflowPauseEntity.from_models(
            workflow_pause_model=mock_pause_model,
        )
        # Verify entity creation
        assert isinstance(entity, _PrivateWorkflowPauseEntity)
        assert entity._pause_model is mock_pause_model

    def test_id_property(self):
        """Test id property returns pause model ID."""
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.id = "pause-123"
        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )
        assert entity.id == "pause-123"

    def test_workflow_execution_id_property(self):
        """Test workflow_execution_id property returns workflow run ID."""
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.workflow_run_id = "execution-456"
        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )
        assert entity.workflow_execution_id == "execution-456"

    def test_resumed_at_property(self):
        """Test resumed_at property returns pause model resumed_at."""
        resumed_at = datetime(2023, 12, 25, 15, 30, 45)
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.resumed_at = resumed_at
        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )
        assert entity.resumed_at == resumed_at

    def test_resumed_at_property_none(self):
        """Test resumed_at property returns None when not set."""
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.resumed_at = None
        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )
        assert entity.resumed_at is None

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_first_call(self, mock_storage):
        """Test get_state loads from storage on first call."""
        state_data = b'{"test": "data", "step": 5}'
        mock_storage.load.return_value = state_data
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.state_object_key = "test-state-key"
        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )
        # First call should load from storage (keyed by state_object_key)
        result = entity.get_state()
        assert result == state_data
        mock_storage.load.assert_called_once_with("test-state-key")
        assert entity._cached_state == state_data

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_cached_call(self, mock_storage):
        """Test get_state returns cached data on subsequent calls."""
        state_data = b'{"test": "data", "step": 5}'
        mock_storage.load.return_value = state_data
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.state_object_key = "test-state-key"
        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )
        # First call
        result1 = entity.get_state()
        # Second call should use cache
        result2 = entity.get_state()
        assert result1 == state_data
        assert result2 == state_data
        # Storage should only be called once
        mock_storage.load.assert_called_once_with("test-state-key")

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_with_pre_cached_data(self, mock_storage):
        """Test get_state returns pre-cached data."""
        state_data = b'{"test": "data", "step": 5}'
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )
        # Pre-cache data (simulates state injected by the repository layer)
        entity._cached_state = state_data
        # Should return cached data without calling storage
        result = entity.get_state()
        assert result == state_data
        mock_storage.load.assert_not_called()

    def test_entity_with_binary_state_data(self):
        """Test entity with binary state data."""
        # Test with binary data that's not valid JSON — get_state must pass
        # bytes through untouched, no decoding/parsing.
        binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"
        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = binary_data
            mock_pause_model = MagicMock(spec=WorkflowPauseModel)
            entity = _PrivateWorkflowPauseEntity(
                pause_model=mock_pause_model,
            )
            result = entity.get_state()
            assert result == binary_data

View File

@@ -0,0 +1,87 @@
"""Tests for template module."""
from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
class TestTemplate:
    """Test Template class functionality.

    Exercises the two constructors — ``from_answer_template`` (parses
    ``{{#node.var#}}`` placeholders) and ``from_end_outputs`` (builds a
    newline-joined template from End-node output configs) — plus ``__str__``.
    """

    def test_from_answer_template_simple(self):
        """Test parsing a simple answer template."""
        template_str = "Hello, {{#node1.name#}}!"
        template = Template.from_answer_template(template_str)
        # text / variable / text triple
        assert len(template.segments) == 3
        assert isinstance(template.segments[0], TextSegment)
        assert template.segments[0].text == "Hello, "
        assert isinstance(template.segments[1], VariableSegment)
        assert template.segments[1].selector == ["node1", "name"]
        assert isinstance(template.segments[2], TextSegment)
        assert template.segments[2].text == "!"

    def test_from_answer_template_multiple_vars(self):
        """Test parsing an answer template with multiple variables."""
        template_str = "Hello {{#node1.name#}}, your age is {{#node2.age#}}."
        template = Template.from_answer_template(template_str)
        # Alternating text/variable segments, 5 in total
        assert len(template.segments) == 5
        assert isinstance(template.segments[0], TextSegment)
        assert template.segments[0].text == "Hello "
        assert isinstance(template.segments[1], VariableSegment)
        assert template.segments[1].selector == ["node1", "name"]
        assert isinstance(template.segments[2], TextSegment)
        assert template.segments[2].text == ", your age is "
        assert isinstance(template.segments[3], VariableSegment)
        assert template.segments[3].selector == ["node2", "age"]
        assert isinstance(template.segments[4], TextSegment)
        assert template.segments[4].text == "."

    def test_from_answer_template_no_vars(self):
        """Test parsing an answer template with no variables."""
        template_str = "Hello, world!"
        template = Template.from_answer_template(template_str)
        assert len(template.segments) == 1
        assert isinstance(template.segments[0], TextSegment)
        assert template.segments[0].text == "Hello, world!"

    def test_from_end_outputs_single(self):
        """Test creating template from End node outputs with single variable."""
        outputs_config = [{"variable": "text", "value_selector": ["node1", "text"]}]
        template = Template.from_end_outputs(outputs_config)
        assert len(template.segments) == 1
        assert isinstance(template.segments[0], VariableSegment)
        assert template.segments[0].selector == ["node1", "text"]

    def test_from_end_outputs_multiple(self):
        """Test creating template from End node outputs with multiple variables."""
        outputs_config = [
            {"variable": "text", "value_selector": ["node1", "text"]},
            {"variable": "result", "value_selector": ["node2", "result"]},
        ]
        template = Template.from_end_outputs(outputs_config)
        # Variables are joined by a literal newline text segment.
        assert len(template.segments) == 3
        assert isinstance(template.segments[0], VariableSegment)
        assert template.segments[0].selector == ["node1", "text"]
        assert template.segments[0].variable_name == "text"
        assert isinstance(template.segments[1], TextSegment)
        assert template.segments[1].text == "\n"
        assert isinstance(template.segments[2], VariableSegment)
        assert template.segments[2].selector == ["node2", "result"]
        assert template.segments[2].variable_name == "result"

    def test_from_end_outputs_empty(self):
        """Test creating template from empty End node outputs."""
        outputs_config = []
        template = Template.from_end_outputs(outputs_config)
        assert len(template.segments) == 0

    def test_template_str_representation(self):
        """Test string representation of template."""
        # str() must reconstruct the original placeholder syntax exactly.
        template_str = "Hello, {{#node1.name#}}!"
        template = Template.from_answer_template(template_str)
        assert str(template) == template_str

View File

@@ -0,0 +1,136 @@
from core.variables.segments import (
BooleanSegment,
IntegerSegment,
NoneSegment,
StringSegment,
)
from core.workflow.runtime import VariablePool
class TestVariablePoolGetAndNestedAttribute:
    """Tests for VariablePool.get and the private _get_nested_attribute helper,
    with particular attention to falsy values (None, "", 0, False) that must
    not be confused with "missing"."""

    #
    # _get_nested_attribute tests
    #
    def test__get_nested_attribute_existing_key(self):
        """An existing dict key is wrapped in a segment carrying its value."""
        pool = VariablePool.empty()
        obj = {"a": 123}
        segment = pool._get_nested_attribute(obj, "a")
        assert segment is not None
        assert segment.value == 123

    def test__get_nested_attribute_missing_key(self):
        """A missing key yields None, not a segment."""
        pool = VariablePool.empty()
        obj = {"a": 123}
        segment = pool._get_nested_attribute(obj, "b")
        assert segment is None

    def test__get_nested_attribute_non_dict(self):
        """Non-mapping containers have no attributes to resolve."""
        pool = VariablePool.empty()
        obj = ["not", "a", "dict"]
        segment = pool._get_nested_attribute(obj, "a")
        assert segment is None

    def test__get_nested_attribute_with_none_value(self):
        """A key whose value is None resolves to a NoneSegment (key exists)."""
        pool = VariablePool.empty()
        obj = {"a": None}
        segment = pool._get_nested_attribute(obj, "a")
        assert segment is not None
        assert isinstance(segment, NoneSegment)

    def test__get_nested_attribute_with_empty_string(self):
        """An empty-string value is still a present StringSegment."""
        pool = VariablePool.empty()
        obj = {"a": ""}
        segment = pool._get_nested_attribute(obj, "a")
        assert segment is not None
        assert isinstance(segment, StringSegment)
        assert segment.value == ""

    #
    # get tests
    #
    def test_get_simple_variable(self):
        """A (node_id, var_name) selector retrieves what add() stored."""
        pool = VariablePool.empty()
        pool.add(("node1", "var1"), "value1")
        segment = pool.get(("node1", "var1"))
        assert segment is not None
        assert segment.value == "value1"

    def test_get_missing_variable(self):
        """Unknown variable names return None."""
        pool = VariablePool.empty()
        result = pool.get(("node1", "unknown"))
        assert result is None

    def test_get_with_too_short_selector(self):
        """Selectors need at least (node_id, var_name); shorter ones return None."""
        pool = VariablePool.empty()
        result = pool.get(("only_node",))
        assert result is None

    def test_get_nested_object_attribute(self):
        """A third selector element drills into an object variable's keys."""
        pool = VariablePool.empty()
        obj_value = {"inner": "hello"}
        pool.add(("node1", "obj"), obj_value)
        # simulate selector with nested attr
        segment = pool.get(("node1", "obj", "inner"))
        assert segment is not None
        assert segment.value == "hello"

    def test_get_nested_object_missing_attribute(self):
        """Drilling into a missing key returns None rather than raising."""
        pool = VariablePool.empty()
        obj_value = {"inner": "hello"}
        pool.add(("node1", "obj"), obj_value)
        result = pool.get(("node1", "obj", "not_exist"))
        assert result is None

    def test_get_nested_object_attribute_with_falsy_values(self):
        """Falsy nested values keep their concrete segment types on lookup."""
        pool = VariablePool.empty()
        obj_value = {
            "inner_none": None,
            "inner_empty": "",
            "inner_zero": 0,
            "inner_false": False,
        }
        pool.add(("node1", "obj"), obj_value)
        segment_none = pool.get(("node1", "obj", "inner_none"))
        assert segment_none is not None
        assert isinstance(segment_none, NoneSegment)
        segment_empty = pool.get(("node1", "obj", "inner_empty"))
        assert segment_empty is not None
        assert isinstance(segment_empty, StringSegment)
        assert segment_empty.value == ""
        segment_zero = pool.get(("node1", "obj", "inner_zero"))
        assert segment_zero is not None
        assert isinstance(segment_zero, IntegerSegment)
        assert segment_zero.value == 0
        segment_false = pool.get(("node1", "obj", "inner_false"))
        assert segment_false is not None
        assert isinstance(segment_false, BooleanSegment)
        assert segment_false.value is False
class TestVariablePoolGetNotModifyVariableDictionary:
    """Regression tests: read paths (get / convert_template) must never insert
    keys into the pool's internal variable_dictionary as a side effect."""

    # Shared selector parts used across the tests below.
    _NODE_ID = "start"
    _VAR_NAME = "name"

    def test_convert_to_template_should_not_introduce_extra_keys(self):
        """Rendering a template must not create phantom node entries."""
        pool = VariablePool.empty()
        pool.add([self._NODE_ID, self._VAR_NAME], 0)
        pool.convert_template("The start.name is {{#start.name#}}")
        # "The start" would appear if the literal text leaked into the dict.
        assert "The start" not in pool.variable_dictionary

    def test_get_should_not_modify_variable_dictionary(self):
        """Missing lookups must not auto-create node or variable entries."""
        pool = VariablePool.empty()
        pool.get([self._NODE_ID, self._VAR_NAME])
        assert len(pool.variable_dictionary) == 1  # only contains `sys` node id
        assert "start" not in pool.variable_dictionary
        pool = VariablePool.empty()
        pool.add([self._NODE_ID, self._VAR_NAME], "Joe")
        pool.get([self._NODE_ID, "count"])
        start_subdict = pool.variable_dictionary[self._NODE_ID]
        assert "count" not in start_subdict

View File

@@ -0,0 +1,225 @@
"""
Unit tests for WorkflowNodeExecution domain model, focusing on process_data truncation functionality.
"""
from dataclasses import dataclass
from datetime import datetime
from typing import Any
import pytest
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.enums import NodeType
class TestWorkflowNodeExecutionProcessDataTruncation:
    """Test process_data truncation functionality in WorkflowNodeExecution domain model."""

    def create_workflow_node_execution(
        self,
        process_data: dict[str, Any] | None = None,
    ) -> WorkflowNodeExecution:
        """Create a WorkflowNodeExecution instance for testing.

        Only ``process_data`` varies; the rest are fixed placeholder values.
        """
        return WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=process_data,
            created_at=datetime.now(),
        )

    def test_initial_process_data_truncated_state(self):
        """Test that process_data_truncated returns False initially."""
        execution = self.create_workflow_node_execution()
        assert execution.process_data_truncated is False
        assert execution.get_truncated_process_data() is None

    def test_set_and_get_truncated_process_data(self):
        """Test setting and getting truncated process_data."""
        execution = self.create_workflow_node_execution()
        test_truncated_data = {"truncated": True, "key": "value"}
        execution.set_truncated_process_data(test_truncated_data)
        assert execution.process_data_truncated is True
        assert execution.get_truncated_process_data() == test_truncated_data

    def test_set_truncated_process_data_to_none(self):
        """Test setting truncated process_data to None."""
        execution = self.create_workflow_node_execution()
        # First set some data
        execution.set_truncated_process_data({"key": "value"})
        assert execution.process_data_truncated is True
        # Then set to None — this must fully clear the truncated flag
        execution.set_truncated_process_data(None)
        assert execution.process_data_truncated is False
        assert execution.get_truncated_process_data() is None

    def test_get_response_process_data_with_no_truncation(self):
        """Test get_response_process_data when no truncation is set."""
        original_data = {"original": True, "data": "value"}
        execution = self.create_workflow_node_execution(process_data=original_data)
        response_data = execution.get_response_process_data()
        assert response_data == original_data
        assert execution.process_data_truncated is False

    def test_get_response_process_data_with_truncation(self):
        """Test get_response_process_data when truncation is set."""
        original_data = {"original": True, "large_data": "x" * 10000}
        truncated_data = {"original": True, "large_data": "[TRUNCATED]"}
        execution = self.create_workflow_node_execution(process_data=original_data)
        execution.set_truncated_process_data(truncated_data)
        response_data = execution.get_response_process_data()
        # Should return truncated data, not original
        assert response_data == truncated_data
        assert response_data != original_data
        assert execution.process_data_truncated is True

    def test_get_response_process_data_with_none_process_data(self):
        """Test get_response_process_data when process_data is None."""
        execution = self.create_workflow_node_execution(process_data=None)
        response_data = execution.get_response_process_data()
        assert response_data is None
        assert execution.process_data_truncated is False

    def test_consistency_with_inputs_outputs_pattern(self):
        """Test that process_data truncation follows the same pattern as inputs/outputs."""
        execution = self.create_workflow_node_execution()
        # Test that all truncation methods exist and behave consistently
        test_data = {"test": "data"}
        # Test inputs truncation
        execution.set_truncated_inputs(test_data)
        assert execution.inputs_truncated is True
        assert execution.get_truncated_inputs() == test_data
        # Test outputs truncation
        execution.set_truncated_outputs(test_data)
        assert execution.outputs_truncated is True
        assert execution.get_truncated_outputs() == test_data
        # Test process_data truncation
        execution.set_truncated_process_data(test_data)
        assert execution.process_data_truncated is True
        assert execution.get_truncated_process_data() == test_data

    @pytest.mark.parametrize(
        "test_data",
        [
            {"simple": "value"},
            {"nested": {"key": "value"}},
            {"list": [1, 2, 3]},
            {"mixed": {"string": "value", "number": 42, "list": [1, 2]}},
            {},  # empty dict
        ],
    )
    def test_truncated_process_data_with_various_data_types(self, test_data):
        """Test that truncated process_data works with various data types."""
        execution = self.create_workflow_node_execution()
        execution.set_truncated_process_data(test_data)
        assert execution.process_data_truncated is True
        assert execution.get_truncated_process_data() == test_data
        assert execution.get_response_process_data() == test_data
@dataclass
class ProcessDataScenario:
    """Test scenario data for process_data functionality."""

    # Human-readable scenario identifier (also used as the pytest test id).
    name: str
    # Raw process_data passed at construction, or None when absent.
    original_data: dict[str, Any] | None
    # Truncated variant installed via set_truncated_process_data; None = skip.
    truncated_data: dict[str, Any] | None
    # Expected value of execution.process_data_truncated afterwards.
    expected_truncated_flag: bool
    # Expected result of execution.get_response_process_data().
    expected_response_data: dict[str, Any] | None
class TestWorkflowNodeExecutionProcessDataScenarios:
    """Test various scenarios for process_data handling."""

    def get_process_data_scenarios(self) -> list[ProcessDataScenario]:
        """Create test scenarios for process_data functionality."""
        return [
            ProcessDataScenario(
                name="no_process_data",
                original_data=None,
                truncated_data=None,
                expected_truncated_flag=False,
                expected_response_data=None,
            ),
            ProcessDataScenario(
                name="process_data_without_truncation",
                original_data={"small": "data"},
                truncated_data=None,
                expected_truncated_flag=False,
                expected_response_data={"small": "data"},
            ),
            ProcessDataScenario(
                name="process_data_with_truncation",
                original_data={"large": "x" * 10000, "metadata": "info"},
                truncated_data={"large": "[TRUNCATED]", "metadata": "info"},
                expected_truncated_flag=True,
                expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
            ),
            ProcessDataScenario(
                name="empty_process_data",
                original_data={},
                truncated_data=None,
                expected_truncated_flag=False,
                expected_response_data={},
            ),
            ProcessDataScenario(
                name="complex_nested_data_with_truncation",
                original_data={
                    "config": {"setting": "value"},
                    "logs": ["log1", "log2"] * 1000,  # Large list
                    "status": "running",
                },
                truncated_data={"config": {"setting": "value"}, "logs": "[TRUNCATED: 2000 items]", "status": "running"},
                expected_truncated_flag=True,
                expected_response_data={
                    "config": {"setting": "value"},
                    "logs": "[TRUNCATED: 2000 items]",
                    "status": "running",
                },
            ),
        ]

    # NOTE: at class-definition time `get_process_data_scenarios` is still a
    # plain function, so it is called with None as `self` (the method never
    # reads self). It works, but a @staticmethod or module-level function
    # would be the cleaner way to share the scenario list with parametrize.
    @pytest.mark.parametrize(
        "scenario",
        get_process_data_scenarios(None),
        ids=[scenario.name for scenario in get_process_data_scenarios(None)],
    )
    def test_process_data_scenarios(self, scenario: ProcessDataScenario):
        """Test various process_data scenarios."""
        execution = WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=scenario.original_data,
            created_at=datetime.now(),
        )
        # Only install the truncated variant when the scenario supplies one.
        if scenario.truncated_data is not None:
            execution.set_truncated_process_data(scenario.truncated_data)
        assert execution.process_data_truncated == scenario.expected_truncated_flag
        assert execution.get_response_process_data() == scenario.expected_response_data

View File

@@ -0,0 +1,281 @@
"""Unit tests for Graph class methods."""
from unittest.mock import Mock
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.graph.edge import Edge
from core.workflow.graph.graph import Graph
from core.workflow.nodes.base.node import Node
def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node:
    """Build a Mock standing in for a Node, pre-populated for branch-marking tests."""
    mock_node = Mock(spec=Node)
    # Attributes the marking logic and assertions read directly.
    mock_node.id = node_id
    mock_node.execution_type = execution_type
    mock_node.state = state
    mock_node.node_type = NodeType.START
    return mock_node
class TestMarkInactiveRootBranches:
    """Test cases for _mark_inactive_root_branches method.

    Each test builds a small graph out of mock nodes and real Edge objects,
    invokes Graph._mark_inactive_root_branches with one active root id, and
    checks which nodes/edges were flipped to NodeState.SKIPPED.
    """

    def test_single_root_no_marking(self):
        """Test that single root graph doesn't mark anything as skipped."""
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
        }
        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
        }
        in_edges = {"child1": ["edge1"]}
        out_edges = {"root1": ["edge1"]}
        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
        # Only one root exists, so there is no inactive branch: states stay UNKNOWN.
        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["child1"].state == NodeState.UNKNOWN
        assert edges["edge1"].state == NodeState.UNKNOWN

    def test_multiple_roots_mark_inactive(self):
        """Test marking inactive root branches with multiple root nodes."""
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "root2": create_mock_node("root2", NodeExecutionType.ROOT),
            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
            "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
        }
        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
            "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
        }
        in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
        out_edges = {"root1": ["edge1"], "root2": ["edge2"]}
        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
        # root1 is the active root; root2 and everything reachable only from it
        # must be marked SKIPPED.
        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["root2"].state == NodeState.SKIPPED
        assert nodes["child1"].state == NodeState.UNKNOWN
        assert nodes["child2"].state == NodeState.SKIPPED
        assert edges["edge1"].state == NodeState.UNKNOWN
        assert edges["edge2"].state == NodeState.SKIPPED

    def test_shared_downstream_node(self):
        """Test that shared downstream nodes are not skipped if at least one path is active."""
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "root2": create_mock_node("root2", NodeExecutionType.ROOT),
            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
            "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
            "shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE),
        }
        # "shared" is reachable from both roots: via child1 (active path) and
        # via child2 (inactive path).
        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
            "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
            "edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"),
            "edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"),
        }
        in_edges = {
            "child1": ["edge1"],
            "child2": ["edge2"],
            "shared": ["edge3", "edge4"],
        }
        out_edges = {
            "root1": ["edge1"],
            "root2": ["edge2"],
            "child1": ["edge3"],
            "child2": ["edge4"],
        }
        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["root2"].state == NodeState.SKIPPED
        assert nodes["child1"].state == NodeState.UNKNOWN
        assert nodes["child2"].state == NodeState.SKIPPED
        assert nodes["shared"].state == NodeState.UNKNOWN  # Not skipped because edge3 is active
        assert edges["edge1"].state == NodeState.UNKNOWN
        assert edges["edge2"].state == NodeState.SKIPPED
        assert edges["edge3"].state == NodeState.UNKNOWN
        assert edges["edge4"].state == NodeState.SKIPPED

    def test_deep_branch_marking(self):
        """Test marking deep branches with multiple levels."""
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "root2": create_mock_node("root2", NodeExecutionType.ROOT),
            "level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE),
            "level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE),
            "level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE),
            "level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE),
            "level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE),
        }
        # Two parallel chains: root1 -> level1_a -> level2_a (active) and
        # root2 -> level1_b -> level2_b -> level3 (inactive, three levels deep).
        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"),
            "edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"),
            "edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"),
            "edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"),
            "edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"),
        }
        in_edges = {
            "level1_a": ["edge1"],
            "level1_b": ["edge2"],
            "level2_a": ["edge3"],
            "level2_b": ["edge4"],
            "level3": ["edge5"],
        }
        out_edges = {
            "root1": ["edge1"],
            "root2": ["edge2"],
            "level1_a": ["edge3"],
            "level1_b": ["edge4"],
            "level2_b": ["edge5"],
        }
        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
        # The skip must propagate through every level of the inactive chain.
        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["root2"].state == NodeState.SKIPPED
        assert nodes["level1_a"].state == NodeState.UNKNOWN
        assert nodes["level1_b"].state == NodeState.SKIPPED
        assert nodes["level2_a"].state == NodeState.UNKNOWN
        assert nodes["level2_b"].state == NodeState.SKIPPED
        assert nodes["level3"].state == NodeState.SKIPPED
        assert edges["edge1"].state == NodeState.UNKNOWN
        assert edges["edge2"].state == NodeState.SKIPPED
        assert edges["edge3"].state == NodeState.UNKNOWN
        assert edges["edge4"].state == NodeState.SKIPPED
        assert edges["edge5"].state == NodeState.SKIPPED

    def test_non_root_execution_type(self):
        """Test that nodes with non-ROOT execution type are not treated as root nodes."""
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE),
            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
            "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
        }
        # "non_root" has no incoming edges but is EXECUTABLE, so it must not be
        # treated as a competing root branch.
        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
            "edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"),
        }
        in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
        out_edges = {"root1": ["edge1"], "non_root": ["edge2"]}
        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["non_root"].state == NodeState.UNKNOWN  # Not marked as skipped
        assert nodes["child1"].state == NodeState.UNKNOWN
        assert nodes["child2"].state == NodeState.UNKNOWN
        assert edges["edge1"].state == NodeState.UNKNOWN
        assert edges["edge2"].state == NodeState.UNKNOWN

    def test_empty_graph(self):
        """Test handling of empty graph structures."""
        nodes = {}
        edges = {}
        in_edges = {}
        out_edges = {}
        # Should not raise any errors
        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent")

    def test_three_roots_mark_two_inactive(self):
        """Test with three root nodes where two should be marked inactive."""
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "root2": create_mock_node("root2", NodeExecutionType.ROOT),
            "root3": create_mock_node("root3", NodeExecutionType.ROOT),
            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
            "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
            "child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE),
        }
        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
            "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
            "edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"),
        }
        in_edges = {
            "child1": ["edge1"],
            "child2": ["edge2"],
            "child3": ["edge3"],
        }
        out_edges = {
            "root1": ["edge1"],
            "root2": ["edge2"],
            "root3": ["edge3"],
        }
        # Here the MIDDLE root is the active one; both siblings get skipped.
        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2")
        assert nodes["root1"].state == NodeState.SKIPPED
        assert nodes["root2"].state == NodeState.UNKNOWN  # Active root
        assert nodes["root3"].state == NodeState.SKIPPED
        assert nodes["child1"].state == NodeState.SKIPPED
        assert nodes["child2"].state == NodeState.UNKNOWN
        assert nodes["child3"].state == NodeState.SKIPPED
        assert edges["edge1"].state == NodeState.SKIPPED
        assert edges["edge2"].state == NodeState.UNKNOWN
        assert edges["edge3"].state == NodeState.SKIPPED

    def test_convergent_paths(self):
        """Test convergent paths where multiple inactive branches lead to same node."""
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "root2": create_mock_node("root2", NodeExecutionType.ROOT),
            "root3": create_mock_node("root3", NodeExecutionType.ROOT),
            "mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE),
            "mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE),
            "convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE),
        }
        # "convergent" has three incoming edges: one from the active branch
        # (root1 -> mid1) and two from inactive branches (root2/root3).
        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"),
            "edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"),
            "edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"),
            "edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"),
            "edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"),
        }
        in_edges = {
            "mid1": ["edge1"],
            "mid2": ["edge2"],
            "convergent": ["edge3", "edge4", "edge5"],
        }
        out_edges = {
            "root1": ["edge1"],
            "root2": ["edge2"],
            "root3": ["edge3"],
            "mid1": ["edge4"],
            "mid2": ["edge5"],
        }
        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["root2"].state == NodeState.SKIPPED
        assert nodes["root3"].state == NodeState.SKIPPED
        assert nodes["mid1"].state == NodeState.UNKNOWN
        assert nodes["mid2"].state == NodeState.SKIPPED
        assert nodes["convergent"].state == NodeState.UNKNOWN  # Not skipped due to active path from root1
        assert edges["edge1"].state == NodeState.UNKNOWN
        assert edges["edge2"].state == NodeState.SKIPPED
        assert edges["edge3"].state == NodeState.SKIPPED
        assert edges["edge4"].state == NodeState.UNKNOWN
        assert edges["edge5"].state == NodeState.SKIPPED

View File

@@ -0,0 +1,59 @@
from unittest.mock import MagicMock
import pytest
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.nodes.base.node import Node
def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node:
    """Return a MagicMock posing as a Node with the given id and type."""
    stub = MagicMock(spec=Node)
    stub.id = node_id
    stub.node_type = node_type
    # The builder path under test does not use execution_type; keep it None.
    stub.execution_type = None
    return stub
def test_graph_builder_creates_linear_graph():
    """Chained add_root/add_node calls should yield root -> mid -> end with auto edges."""
    start_node = _make_node("root", NodeType.START)
    llm_node = _make_node("mid", NodeType.LLM)
    end_node = _make_node("end", NodeType.END)
    # Split the fluent chain to show each builder step returns the builder.
    builder = Graph.new().add_root(start_node)
    builder = builder.add_node(llm_node)
    graph = builder.add_node(end_node).build()
    assert graph.root_node is start_node
    assert graph.nodes == {"root": start_node, "mid": llm_node, "end": end_node}
    assert len(graph.edges) == 2
    # The first registered edge connects the root to its immediate successor.
    leading_edge = next(iter(graph.edges.values()))
    assert leading_edge.tail == "root"
    assert leading_edge.head == "mid"
    # out_edges for "mid" must list exactly the edge ids whose tail is "mid".
    mid_edge_ids = [eid for eid, e in graph.edges.items() if e.tail == "mid"]
    assert graph.out_edges["mid"] == mid_edge_ids
def test_graph_builder_supports_custom_predecessor():
    """add_node(from_node_id=...) should attach the node to the named predecessor."""
    root_node = _make_node("root")
    first_child = _make_node("branch")
    second_child = _make_node("other")
    graph = (
        Graph.new()
        .add_root(root_node)
        .add_node(first_child)
        .add_node(second_child, from_node_id="root")
        .build()
    )
    root_edge_ids = graph.out_edges["root"]
    assert len(root_edge_ids) == 2
    # Both children hang directly off the root.
    heads = {graph.edges[edge_id].head for edge_id in root_edge_ids}
    assert heads == {"branch", "other"}
def test_graph_builder_validates_usage():
    """The builder rejects add_node before a root exists and duplicate node ids."""
    builder = Graph.new()
    original = _make_node("node")
    # Adding a non-root node before add_root must fail.
    with pytest.raises(ValueError, match="Root node"):
        builder.add_node(original)
    builder.add_root(original)
    # A second node reusing the same id must also be rejected.
    clone = _make_node("node")
    with pytest.raises(ValueError, match="Duplicate"):
        builder.add_node(clone)

Some files were not shown because too many files have changed in this diff Show More