dify
This commit is contained in:
0
dify/api/tests/unit_tests/core/__init__.py
Normal file
0
dify/api/tests/unit_tests/core/__init__.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta
|
||||
|
||||
|
||||
def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]:
    """Yield *text* one character at a time as LLM result chunks.

    Simulates a streaming LLM response so the parser under test receives
    the input in single-character deltas, the worst case for its buffering.
    """
    # Iterate characters directly instead of indexing via range(len(text)).
    for char in text:
        yield LLMResultChunk(
            model="model",
            prompt_messages=[],
            # index is always 0 here; the parser under test does not use it.
            delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=char, tool_calls=[])),
        )
|
||||
|
||||
|
||||
def test_cot_output_parser():
    """Exercise CotAgentOutputParser.handle_react_stream_output over the code-block variants.

    Each case feeds the parser a character-streamed ReAct response and checks
    both the parsed Action dict and the reassembled textual output.
    """
    test_cases = [
        {
            "input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```',
            "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
            "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
        },
        # code block with json
        {
            "input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {'
            '}\n```"}```',
            "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
            "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
        },
        # code block with JSON
        {
            "input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {'
            '}\n```"}```',
            "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
            "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
        },
        # list
        {
            "input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```',
            "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
            "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
        },
        # no code block
        {
            "input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}',
            "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
            "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
        },
        # no code block and json
        {"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"},
    ]

    parser = CotAgentOutputParser()
    usage_dict = {}
    for case in test_cases:
        # Stream the input to the parser one character at a time.
        chunks: Generator[LLMResultChunk, None, None] = mock_llm_response(case["input"])
        pieces: list[str] = []
        for item in parser.handle_react_stream_output(chunks, usage_dict):
            if isinstance(item, str):
                pieces.append(item)
            elif isinstance(item, AgentScratchpadUnit.Action):
                if case["action"]:
                    assert item.to_dict() == case["action"]
                pieces.append(json.dumps(item.to_dict()))
        if case["output"]:
            assert "".join(pieces) == case["output"]
|
||||
@@ -0,0 +1,65 @@
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file.models import FileTransferMethod, FileUploadConfig, ImageConfig
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
|
||||
def test_convert_with_vision():
    """Converting with is_vision=True should populate the image detail setting."""
    file_upload_settings = {
        "enabled": True,
        "number_limits": 5,
        "allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
        "image": {"detail": "high"},
    }

    converted = FileUploadConfigManager.convert({"file_upload": file_upload_settings}, is_vision=True)

    assert converted == FileUploadConfig(
        image_config=ImageConfig(
            number_limits=5,
            transfer_methods=[FileTransferMethod.REMOTE_URL],
            detail=ImagePromptMessageContent.DETAIL.HIGH,
        ),
        allowed_file_upload_methods=[FileTransferMethod.REMOTE_URL],
        number_limits=5,
    )
|
||||
|
||||
|
||||
def test_convert_without_vision():
    """Converting with is_vision=False should leave the image detail unset."""
    file_upload_settings = {
        "enabled": True,
        "number_limits": 5,
        "allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
    }

    converted = FileUploadConfigManager.convert({"file_upload": file_upload_settings}, is_vision=False)

    assert converted == FileUploadConfig(
        image_config=ImageConfig(number_limits=5, transfer_methods=[FileTransferMethod.REMOTE_URL]),
        allowed_file_upload_methods=[FileTransferMethod.REMOTE_URL],
        number_limits=5,
    )
|
||||
|
||||
|
||||
def test_validate_and_set_defaults():
    """An empty config gains a default file_upload section."""
    validated, affected_keys = FileUploadConfigManager.validate_and_set_defaults({})

    assert "file_upload" in validated
    assert affected_keys == ["file_upload"]
|
||||
|
||||
|
||||
def test_validate_and_set_defaults_with_existing_config():
    """An existing file_upload section is validated and preserved unchanged."""
    original_settings = {
        "enabled": True,
        "number_limits": 5,
        "allowed_file_upload_methods": [FileTransferMethod.REMOTE_URL],
    }

    validated, affected_keys = FileUploadConfigManager.validate_and_set_defaults(
        {"file_upload": original_settings}
    )

    assert "file_upload" in validated
    assert affected_keys == ["file_upload"]
    # The caller-supplied values must survive validation untouched.
    validated_settings = validated["file_upload"]
    assert validated_settings["enabled"] is True
    assert validated_settings["number_limits"] == 5
    assert validated_settings["allowed_file_upload_methods"] == [FileTransferMethod.REMOTE_URL]
|
||||
@@ -0,0 +1,443 @@
|
||||
"""Test conversation variable handling in AdvancedChatAppRunner."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.variables import SegmentType
|
||||
from factories import variable_factory
|
||||
from models import ConversationVariable, Workflow
|
||||
|
||||
|
||||
class TestAdvancedChatAppRunnerConversationVariables:
    """Test that AdvancedChatAppRunner correctly handles conversation variables.

    The three tests share almost all of their mock scaffolding; it is factored
    into private static helpers so each test only states what differs: which
    variables the workflow declares, which rows already exist in the DB, and
    what should be added as a result.
    """

    @staticmethod
    def _build_conversation_variables(specs):
        """Build string conversation variables from (id, name, default value) triples."""
        return [
            variable_factory.build_conversation_variable_from_mapping(
                {
                    "id": var_id,
                    "name": name,
                    "value_type": SegmentType.STRING,
                    "value": value,
                }
            )
            for var_id, name, value in specs
        ]

    @staticmethod
    def _make_mock_workflow(workflow_vars, app_id, workflow_id):
        """Create a Workflow mock declaring the given conversation variables."""
        mock_workflow = MagicMock(spec=Workflow)
        mock_workflow.conversation_variables = workflow_vars
        mock_workflow.tenant_id = str(uuid4())
        mock_workflow.app_id = app_id
        mock_workflow.id = workflow_id
        mock_workflow.type = "chat"
        mock_workflow.graph_dict = {}
        mock_workflow.environment_variables = []
        return mock_workflow

    @staticmethod
    def _make_existing_db_var(var, app_id, conversation_id):
        """Create a ConversationVariable mock mirroring an existing DB row for *var*."""
        db_var = MagicMock(spec=ConversationVariable)
        db_var.id = var.id
        db_var.app_id = app_id
        db_var.conversation_id = conversation_id
        db_var.to_variable = MagicMock(return_value=var)
        return db_var

    @staticmethod
    def _make_session(existing_db_vars):
        """Create a Session mock whose scalars().all() returns *existing_db_vars*."""
        mock_session = MagicMock(spec=Session)
        mock_scalars_result = MagicMock()
        mock_scalars_result.all.return_value = existing_db_vars
        mock_session.scalars.return_value = mock_scalars_result
        return mock_session

    @staticmethod
    def _make_runner(mock_workflow, app_id, workflow_id, conversation_id):
        """Create an AdvancedChatAppRunner wired to fully mocked collaborators."""
        mock_conversation = MagicMock()
        mock_conversation.app_id = app_id
        mock_conversation.id = conversation_id

        mock_message = MagicMock()
        mock_message.id = str(uuid4())

        mock_app_config = MagicMock()
        mock_app_config.app_id = app_id
        mock_app_config.workflow_id = workflow_id
        mock_app_config.tenant_id = str(uuid4())

        mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
        mock_app_generate_entity.app_config = mock_app_config
        mock_app_generate_entity.inputs = {}
        mock_app_generate_entity.query = "test query"
        mock_app_generate_entity.files = []
        mock_app_generate_entity.user_id = str(uuid4())
        mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API
        mock_app_generate_entity.workflow_run_id = str(uuid4())
        mock_app_generate_entity.task_id = str(uuid4())
        mock_app_generate_entity.call_depth = 0
        mock_app_generate_entity.single_iteration_run = None
        mock_app_generate_entity.single_loop_run = None
        mock_app_generate_entity.trace_manager = None

        return AdvancedChatAppRunner(
            application_generate_entity=mock_app_generate_entity,
            queue_manager=MagicMock(),
            conversation=mock_conversation,
            message=mock_message,
            dialogue_count=1,
            variable_loader=MagicMock(),
            workflow=mock_workflow,
            system_user_id=str(uuid4()),
            app=MagicMock(),
            workflow_execution_repository=MagicMock(),
            workflow_node_execution_repository=MagicMock(),
        )

    @staticmethod
    def _wire_common_mocks(
        mock_session_class,
        mock_session,
        mock_db,
        mock_graph_runtime_state_class,
        mock_init_graph,
        mock_workflow_entry_class,
    ):
        """Configure the patched classes with the standard happy-path behavior."""
        mock_session_class.return_value.__enter__.return_value = mock_session
        mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock()  # App exists
        mock_db.engine = MagicMock()
        # GraphRuntimeState accepts whatever variable pool it is given.
        mock_graph_runtime_state_class.return_value = MagicMock()
        mock_init_graph.return_value = MagicMock()
        # WorkflowEntry.run yields nothing so runner.run() completes immediately.
        mock_workflow_entry = MagicMock()
        mock_workflow_entry.run.return_value = iter([])  # Empty generator
        mock_workflow_entry_class.return_value = mock_workflow_entry

    def test_missing_conversation_variables_are_added(self):
        """Test that new conversation variables added to workflow are created for existing conversations."""
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_id = str(uuid4())

        # Workflow declares two variables, but only var1 already exists in the DB.
        workflow_vars = self._build_conversation_variables(
            [("var1", "existing_var", "default1"), ("var2", "new_var", "default2")]
        )
        mock_workflow = self._make_mock_workflow(workflow_vars, app_id, workflow_id)
        existing_db_var = self._make_existing_db_var(workflow_vars[0], app_id, conversation_id)

        runner = self._make_runner(mock_workflow, app_id, workflow_id, conversation_id)
        mock_session = self._make_session([existing_db_var])

        # Track what gets added to the session.
        added_items = []
        mock_session.add_all.side_effect = added_items.extend

        with (
            patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
            patch("core.app.apps.advanced_chat.app_runner.select"),
            patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
            patch.object(runner, "_init_graph") as mock_init_graph,
            patch.object(runner, "handle_input_moderation", return_value=False),
            patch.object(runner, "handle_annotation_reply", return_value=False),
            patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
            patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
            patch("core.app.apps.advanced_chat.app_runner.redis_client"),
            patch("core.app.apps.advanced_chat.app_runner.RedisChannel"),
        ):
            self._wire_common_mocks(
                mock_session_class,
                mock_session,
                mock_db,
                mock_graph_runtime_state_class,
                mock_init_graph,
                mock_workflow_entry_class,
            )

            # Run the method
            runner.run()

            # Verify that the missing variable was added
            assert len(added_items) == 1, "Should have added exactly one missing variable"
            added_var = added_items[0]
            assert hasattr(added_var, "id"), "Added item should be a ConversationVariable"
            # Note: Since we're mocking ConversationVariable.from_variable,
            # we can't directly check the id, but we can verify add_all was called
            assert mock_session.add_all.called, "Session add_all should have been called"
            assert mock_session.commit.called, "Session commit should have been called"

    def test_no_variables_creates_all(self):
        """Test that all conversation variables are created when none exist in DB."""
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_id = str(uuid4())

        workflow_vars = self._build_conversation_variables(
            [("var1", "var1", "default1"), ("var2", "var2", "default2")]
        )
        mock_workflow = self._make_mock_workflow(workflow_vars, app_id, workflow_id)

        runner = self._make_runner(mock_workflow, app_id, workflow_id, conversation_id)
        # Query returns empty list (no existing variables).
        mock_session = self._make_session([])

        # Track what gets added to the session.
        added_items = []
        mock_session.add_all.side_effect = added_items.extend

        with (
            patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
            patch("core.app.apps.advanced_chat.app_runner.select"),
            patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
            patch.object(runner, "_init_graph") as mock_init_graph,
            patch.object(runner, "handle_input_moderation", return_value=False),
            patch.object(runner, "handle_annotation_reply", return_value=False),
            patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
            patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
            patch("core.app.apps.advanced_chat.app_runner.ConversationVariable") as mock_conv_var_class,
            patch("core.app.apps.advanced_chat.app_runner.redis_client"),
            patch("core.app.apps.advanced_chat.app_runner.RedisChannel"),
        ):
            self._wire_common_mocks(
                mock_session_class,
                mock_session,
                mock_db,
                mock_graph_runtime_state_class,
                mock_init_graph,
                mock_workflow_entry_class,
            )

            # ConversationVariable.from_variable returns mocks mirroring the workflow vars.
            mock_conv_vars = []
            for var in workflow_vars:
                mock_cv = MagicMock()
                mock_cv.id = var.id
                mock_cv.to_variable.return_value = var
                mock_conv_vars.append(mock_cv)
            mock_conv_var_class.from_variable.side_effect = mock_conv_vars

            # Run the method
            runner.run()

            # Verify that all variables were created
            assert len(added_items) == 2, "Should have added both variables"
            assert mock_session.add_all.called, "Session add_all should have been called"
            assert mock_session.commit.called, "Session commit should have been called"

    def test_all_variables_exist_no_changes(self):
        """Test that no changes are made when all variables already exist in DB."""
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_id = str(uuid4())

        workflow_vars = self._build_conversation_variables(
            [("var1", "var1", "default1"), ("var2", "var2", "default2")]
        )
        mock_workflow = self._make_mock_workflow(workflow_vars, app_id, workflow_id)
        # Both variables already exist in the database.
        existing_db_vars = [
            self._make_existing_db_var(var, app_id, conversation_id) for var in workflow_vars
        ]

        runner = self._make_runner(mock_workflow, app_id, workflow_id, conversation_id)
        mock_session = self._make_session(existing_db_vars)

        with (
            patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
            patch("core.app.apps.advanced_chat.app_runner.select"),
            patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
            patch.object(runner, "_init_graph") as mock_init_graph,
            patch.object(runner, "handle_input_moderation", return_value=False),
            patch.object(runner, "handle_annotation_reply", return_value=False),
            patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
            patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
            patch("core.app.apps.advanced_chat.app_runner.redis_client"),
            patch("core.app.apps.advanced_chat.app_runner.RedisChannel"),
        ):
            self._wire_common_mocks(
                mock_session_class,
                mock_session,
                mock_db,
                mock_graph_runtime_state_class,
                mock_init_graph,
                mock_workflow_entry_class,
            )

            # Run the method
            runner.run()

            # Verify that no variables were added
            assert not mock_session.add_all.called, "Session add_all should not have been called"
            assert mock_session.commit.called, "Session commit should still be called"
|
||||
@@ -0,0 +1,63 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def _make_state(workflow_run_id: str | None) -> GraphRuntimeState:
    """Build a GraphRuntimeState whose system variables carry the given run id."""
    system_vars = SystemVariable(workflow_execution_id=workflow_run_id)
    pool = VariablePool(system_variables=system_vars)
    return GraphRuntimeState(variable_pool=pool, start_at=0.0)
|
||||
|
||||
|
||||
class _StubPipeline(GraphRuntimeStateSupport):
    """Minimal pipeline stub exposing only the attributes GraphRuntimeStateSupport reads."""

    def __init__(self, *, cached_state: GraphRuntimeState | None, queue_state: GraphRuntimeState | None):
        # The support mixin looks up the queue manager's state through _base_task_pipeline.
        queue_manager = SimpleNamespace(graph_runtime_state=queue_state)
        self._graph_runtime_state = cached_state
        self._base_task_pipeline = SimpleNamespace(queue_manager=queue_manager)
|
||||
|
||||
|
||||
def test_ensure_graph_runtime_initialized_caches_explicit_state():
    """An explicitly supplied state is returned as-is and cached on the pipeline."""
    state = _make_state("run-explicit")
    pipeline = _StubPipeline(cached_state=None, queue_state=None)

    result = pipeline._ensure_graph_runtime_initialized(state)

    assert result is state
    assert pipeline._graph_runtime_state is state
|
||||
|
||||
|
||||
def test_resolve_graph_runtime_state_reads_from_queue_when_cache_empty():
    """With no cached state, resolution falls back to the queue manager's state."""
    state_from_queue = _make_state("run-queue")
    pipeline = _StubPipeline(cached_state=None, queue_state=state_from_queue)

    result = pipeline._resolve_graph_runtime_state()

    assert result is state_from_queue
    # The queue-provided state should now be cached on the pipeline.
    assert pipeline._graph_runtime_state is state_from_queue
|
||||
|
||||
|
||||
def test_resolve_graph_runtime_state_raises_when_no_state_available():
    """Resolution fails loudly when neither the cache nor the queue has a state."""
    pipeline = _StubPipeline(cached_state=None, queue_state=None)

    with pytest.raises(ValueError):
        pipeline._resolve_graph_runtime_state()
|
||||
|
||||
|
||||
def test_extract_workflow_run_id_returns_value():
    """The run id stored in the state's system variables is extracted verbatim."""
    state = _make_state("run-identifier")
    pipeline = _StubPipeline(cached_state=state, queue_state=None)

    extracted = pipeline._extract_workflow_run_id(state)

    assert extracted == "run-identifier"
|
||||
|
||||
|
||||
def test_extract_workflow_run_id_raises_when_missing():
    """Extraction raises when the state carries no workflow execution id."""
    state_without_run_id = _make_state(None)
    pipeline = _StubPipeline(cached_state=state_without_run_id, queue_state=None)

    with pytest.raises(ValueError):
        pipeline._extract_workflow_run_id(state_without_run_id)
|
||||
@@ -0,0 +1,259 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
|
||||
class TestWorkflowResponseConverterFetchFilesFromVariableValue:
|
||||
"""Test class for WorkflowResponseConverter._fetch_files_from_variable_value method"""
|
||||
|
||||
def create_test_file(self, file_id: str = "test_file_1") -> File:
|
||||
"""Create a test File object"""
|
||||
return File(
|
||||
id=file_id,
|
||||
tenant_id="test_tenant",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related_123",
|
||||
filename=f"{file_id}.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=1024,
|
||||
storage_key="storage_key_123",
|
||||
)
|
||||
|
||||
def create_file_dict(self, file_id: str = "test_file_dict"):
|
||||
"""Create a file dictionary with correct dify_model_identity"""
|
||||
return {
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
"id": file_id,
|
||||
"tenant_id": "test_tenant",
|
||||
"type": "document",
|
||||
"transfer_method": "local_file",
|
||||
"related_id": "related_456",
|
||||
"filename": f"{file_id}.txt",
|
||||
"extension": ".txt",
|
||||
"mime_type": "text/plain",
|
||||
"size": 2048,
|
||||
"url": "http://example.com/file.txt",
|
||||
}
|
||||
|
||||
def test_fetch_files_from_variable_value_with_none(self):
|
||||
"""Test with None input"""
|
||||
# The method signature expects Union[dict, list, Segment], but implementation handles None
|
||||
# We'll test the actual behavior by passing an empty dict instead
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(None)
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_empty_dict(self):
|
||||
"""Test with empty dictionary"""
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value({})
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_empty_list(self):
|
||||
"""Test with empty list"""
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([])
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_file_segment(self):
|
||||
"""Test with valid FileSegment"""
|
||||
test_file = self.create_test_file("segment_file")
|
||||
file_segment = FileSegment(value=test_file)
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], dict)
|
||||
assert result[0]["id"] == "segment_file"
|
||||
assert result[0]["dify_model_identity"] == FILE_MODEL_IDENTITY
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_single(self):
|
||||
"""Test with ArrayFileSegment containing single file"""
|
||||
test_file = self.create_test_file("array_file_1")
|
||||
array_segment = ArrayFileSegment(value=[test_file])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], dict)
|
||||
assert result[0]["id"] == "array_file_1"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_multiple(self):
|
||||
"""Test with ArrayFileSegment containing multiple files"""
|
||||
test_file_1 = self.create_test_file("array_file_1")
|
||||
test_file_2 = self.create_test_file("array_file_2")
|
||||
array_segment = ArrayFileSegment(value=[test_file_1, test_file_2])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "array_file_1"
|
||||
assert result[1]["id"] == "array_file_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_array_file_segment_empty(self):
|
||||
"""Test with ArrayFileSegment containing empty array"""
|
||||
array_segment = ArrayFileSegment(value=[])
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(array_segment)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_of_file_dicts(self):
|
||||
"""Test with list containing file dictionaries"""
|
||||
file_dict_1 = self.create_file_dict("list_file_1")
|
||||
file_dict_2 = self.create_file_dict("list_file_2")
|
||||
test_list = [file_dict_1, file_dict_2]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "list_file_1"
|
||||
assert result[1]["id"] == "list_file_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_of_file_objects(self):
|
||||
"""Test with list containing File objects"""
|
||||
file_obj_1 = self.create_test_file("list_obj_1")
|
||||
file_obj_2 = self.create_test_file("list_obj_2")
|
||||
test_list = [file_obj_1, file_obj_2]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "list_obj_1"
|
||||
assert result[1]["id"] == "list_obj_2"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_mixed_valid_invalid(self):
|
||||
"""Test with list containing mix of valid files and invalid items"""
|
||||
file_dict = self.create_file_dict("mixed_file")
|
||||
invalid_dict = {"not_a_file": "value"}
|
||||
test_list = [file_dict, invalid_dict, "string_item", 123]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["id"] == "mixed_file"
|
||||
|
||||
def test_fetch_files_from_variable_value_with_list_nested_structures(self):
|
||||
"""Test with list containing nested structures"""
|
||||
file_dict = self.create_file_dict("nested_file")
|
||||
nested_list = [file_dict, ["inner_list"]]
|
||||
test_list = [nested_list, {"nested": "dict"}]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_list)
|
||||
|
||||
# Should not process nested structures in list items
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_incorrect_identity(self):
|
||||
"""Test with dictionary having incorrect dify_model_identity"""
|
||||
invalid_dict = {"dify_model_identity": "wrong_identity", "id": "invalid_file", "filename": "test.txt"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_missing_identity(self):
|
||||
"""Test with dictionary missing dify_model_identity"""
|
||||
invalid_dict = {"id": "no_identity_file", "filename": "test.txt"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_dict_file_object(self):
|
||||
"""Test with dictionary containing File object"""
|
||||
file_obj = self.create_test_file("dict_obj_file")
|
||||
test_dict = {"file_key": file_obj}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(test_dict)
|
||||
|
||||
# Should not extract File objects from dict values
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_mixed_data_types(self):
|
||||
"""Test with various mixed data types"""
|
||||
mixed_data = {"string": "text", "number": 42, "boolean": True, "null": None, "dify_model_identity": "wrong"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(mixed_data)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_invalid_objects(self):
|
||||
"""Test with invalid objects that are not supported types"""
|
||||
# Test with an invalid dict that doesn't match expected patterns
|
||||
invalid_dict = {"custom_key": "custom_value"}
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(invalid_dict)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_string_input(self):
|
||||
"""Test with string input (unsupported type)"""
|
||||
# Since method expects Union[dict, list, Segment], test with empty list instead
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_with_number_input(self):
|
||||
"""Test with number input (unsupported type)"""
|
||||
# Test with list containing numbers (should be ignored)
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value([42, "string", None])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_fetch_files_from_variable_value_return_type_is_sequence(self):
|
||||
"""Test that return type is Sequence[Mapping[str, Any]]"""
|
||||
file_dict = self.create_file_dict("type_test_file")
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_dict)
|
||||
|
||||
assert isinstance(result, Sequence)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], Mapping)
|
||||
assert all(isinstance(key, str) for key in result[0])
|
||||
|
||||
def test_fetch_files_from_variable_value_preserves_file_properties(self):
|
||||
"""Test that all file properties are preserved in the result"""
|
||||
original_file = self.create_test_file("property_test")
|
||||
file_segment = FileSegment(value=original_file)
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(file_segment)
|
||||
|
||||
assert len(result) == 1
|
||||
file_dict = result[0]
|
||||
assert file_dict["id"] == "property_test"
|
||||
assert file_dict["tenant_id"] == "test_tenant"
|
||||
assert file_dict["type"] == "document"
|
||||
assert file_dict["transfer_method"] == "local_file"
|
||||
assert file_dict["filename"] == "property_test.txt"
|
||||
assert file_dict["extension"] == ".txt"
|
||||
assert file_dict["mime_type"] == "text/plain"
|
||||
assert file_dict["size"] == 1024
|
||||
|
||||
def test_fetch_files_from_variable_value_with_complex_nested_scenario(self):
|
||||
"""Test complex scenario with nested valid and invalid data"""
|
||||
file_dict = self.create_file_dict("complex_file")
|
||||
file_obj = self.create_test_file("complex_obj")
|
||||
|
||||
# Complex nested structure
|
||||
complex_data = [
|
||||
file_dict, # Valid file dict
|
||||
file_obj, # Valid file object
|
||||
{ # Invalid dict
|
||||
"not_file": "data",
|
||||
"nested": {"deep": "value"},
|
||||
},
|
||||
[ # Nested list (should be ignored)
|
||||
self.create_file_dict("nested_file")
|
||||
],
|
||||
"string", # Invalid string
|
||||
None, # None value
|
||||
42, # Invalid number
|
||||
]
|
||||
|
||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(complex_data)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "complex_file"
|
||||
assert result[1]["id"] == "complex_obj"
|
||||
@@ -0,0 +1,810 @@
|
||||
"""
|
||||
Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestWorkflowResponseConverter:
    """Test truncation in WorkflowResponseConverter.

    These tests drive the converter through its real event order
    (workflow start -> node start -> node finish/retry) and stub the
    truncator to observe which data ends up in the stream responses.
    """

    def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity:
        """Create a mock WorkflowAppGenerateEntity."""
        mock_entity = Mock(spec=WorkflowAppGenerateEntity)
        mock_app_config = Mock()
        mock_app_config.tenant_id = "test-tenant-id"
        # WEB_APP invocations exercise the truncating code path.
        mock_entity.invoke_from = InvokeFrom.WEB_APP
        mock_entity.app_config = mock_app_config
        mock_entity.inputs = {}
        return mock_entity

    def create_workflow_response_converter(self) -> WorkflowResponseConverter:
        """Create a WorkflowResponseConverter for testing."""

        mock_entity = self.create_mock_generate_entity()
        mock_user = Mock(spec=Account)
        mock_user.id = "test-user-id"
        mock_user.name = "Test User"
        mock_user.email = "test@example.com"

        system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
        return WorkflowResponseConverter(
            application_generate_entity=mock_entity,
            user=mock_user,
            system_variables=system_variables,
        )

    def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
        """Create a QueueNodeStartedEvent for testing.

        A fresh UUID is generated when no execution id is supplied so each
        test gets an independent node execution.
        """
        return QueueNodeStartedEvent(
            node_execution_id=node_execution_id or str(uuid.uuid4()),
            node_id="test-node-id",
            node_title="Test Node",
            node_type=NodeType.CODE,
            start_at=naive_utc_now(),
            in_iteration_id=None,
            in_loop_id=None,
            provider_type="built-in",
            provider_id="code",
        )

    def create_node_succeeded_event(
        self,
        *,
        node_execution_id: str,
        process_data: Mapping[str, Any] | None = None,
    ) -> QueueNodeSucceededEvent:
        """Create a QueueNodeSucceededEvent for testing."""
        return QueueNodeSucceededEvent(
            node_id="test-node-id",
            node_type=NodeType.CODE,
            node_execution_id=node_execution_id,
            start_at=naive_utc_now(),
            in_iteration_id=None,
            in_loop_id=None,
            inputs={},
            # None is normalized to an empty mapping here.
            process_data=process_data or {},
            outputs={},
            execution_metadata={},
        )

    def create_node_retry_event(
        self,
        *,
        node_execution_id: str,
        process_data: Mapping[str, Any] | None = None,
    ) -> QueueNodeRetryEvent:
        """Create a QueueNodeRetryEvent for testing."""
        return QueueNodeRetryEvent(
            inputs={"data": "inputs"},
            outputs={"data": "outputs"},
            process_data=process_data or {},
            error="oops",
            retry_index=1,
            node_id="test-node-id",
            node_type=NodeType.CODE,
            node_title="test code",
            provider_type="built-in",
            provider_id="code",
            node_execution_id=node_execution_id,
            start_at=naive_utc_now(),
            in_iteration_id=None,
            in_loop_id=None,
        )

    def test_workflow_node_finish_response_uses_truncated_process_data(self):
        """Test that node finish response uses get_response_process_data()."""
        converter = self.create_workflow_response_converter()

        original_data = {"large_field": "x" * 10000, "metadata": "info"}
        truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}

        # Bootstrap the run so the converter has a workflow_run_id, then
        # register the node start before finishing it.
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )

        event = self.create_node_succeeded_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        # Deterministic stub: only the known payload is "truncated".
        def fake_truncate(mapping):
            if mapping == dict(original_data):
                return truncated_data, True
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
        )

        # Response should use truncated data, not original
        assert response is not None
        assert response.data.process_data == truncated_data
        assert response.data.process_data != original_data
        assert response.data.process_data_truncated is True

    def test_workflow_node_finish_response_without_truncation(self):
        """Test node finish response when no truncation is applied."""
        converter = self.create_workflow_response_converter()

        original_data = {"small": "data"}

        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )

        event = self.create_node_succeeded_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        # Stub that reports "nothing truncated" for any payload.
        def fake_truncate(mapping):
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
        )

        # Response should use original data
        assert response is not None
        assert response.data.process_data == original_data
        assert response.data.process_data_truncated is False

    def test_workflow_node_finish_response_with_none_process_data(self):
        """Test node finish response when process_data is None."""
        converter = self.create_workflow_response_converter()

        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )

        event = self.create_node_succeeded_event(
            node_execution_id=start_event.node_execution_id,
            process_data=None,
        )

        def fake_truncate(mapping):
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
        )

        # Response should normalize missing process_data to an empty mapping
        assert response is not None
        assert response.data.process_data == {}
        assert response.data.process_data_truncated is False

    def test_workflow_node_retry_response_uses_truncated_process_data(self):
        """Test that node retry response uses get_response_process_data()."""
        converter = self.create_workflow_response_converter()

        original_data = {"large_field": "x" * 10000, "metadata": "info"}
        truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}

        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )

        event = self.create_node_retry_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        def fake_truncate(mapping):
            if mapping == dict(original_data):
                return truncated_data, True
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]

        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test-task-id",
        )

        # Response should use truncated data, not original
        assert response is not None
        assert response.data.process_data == truncated_data
        assert response.data.process_data != original_data
        assert response.data.process_data_truncated is True

    def test_workflow_node_retry_response_without_truncation(self):
        """Test node retry response when no truncation is applied."""
        converter = self.create_workflow_response_converter()

        original_data = {"small": "data"}

        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )

        event = self.create_node_retry_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        def fake_truncate(mapping):
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]

        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test-task-id",
        )

        assert response is not None
        assert response.data.process_data == original_data
        assert response.data.process_data_truncated is False

    def test_iteration_and_loop_nodes_return_none(self):
        """Test that iteration and loop nodes return None (no streaming events)."""
        converter = self.create_workflow_response_converter()

        iteration_event = QueueNodeSucceededEvent(
            node_id="iteration-node",
            node_type=NodeType.ITERATION,
            node_execution_id=str(uuid.uuid4()),
            start_at=naive_utc_now(),
            in_iteration_id=None,
            in_loop_id=None,
            inputs={},
            process_data={},
            outputs={},
            execution_metadata={},
        )

        response = converter.workflow_node_finish_to_stream_response(
            event=iteration_event,
            task_id="test-task-id",
        )
        assert response is None

        # Same event, retyped as a loop node, must behave identically.
        loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
        response = converter.workflow_node_finish_to_stream_response(
            event=loop_event,
            task_id="test-task-id",
        )
        assert response is None

    def test_finish_without_start_raises(self):
        """Ensure finish responses require a prior workflow start."""
        converter = self.create_workflow_response_converter()
        event = self.create_node_succeeded_event(
            node_execution_id=str(uuid.uuid4()),
            process_data={},
        )

        # No workflow_start_to_stream_response() call: no run id exists yet.
        with pytest.raises(ValueError):
            converter.workflow_node_finish_to_stream_response(
                event=event,
                task_id="test-task-id",
            )
|
||||
|
||||
|
||||
@dataclass
class TestCase:
    """Test case data for table-driven tests."""

    # Prevent pytest from collecting this data holder: its name starts with
    # "Test" and the dataclass-generated __init__ would otherwise trigger a
    # PytestCollectionWarning ("cannot collect test class ... __init__").
    __test__ = False

    # Parametrize id (see ids=lambda x: x.name below).
    name: str
    # Invocation source under test.
    invoke_from: InvokeFrom
    # Whether truncation is expected to apply for this source.
    expected_truncation_enabled: bool
    # Human-readable intent of the case.
    description: str
|
||||
|
||||
|
||||
class TestWorkflowResponseConverterServiceApiTruncation:
|
||||
"""Test class for Service API truncation functionality in WorkflowResponseConverter."""
|
||||
|
||||
def create_test_app_generate_entity(self, invoke_from: InvokeFrom) -> WorkflowAppGenerateEntity:
|
||||
"""Create a test WorkflowAppGenerateEntity with specified invoke_from."""
|
||||
# Create a minimal WorkflowUIBasedAppConfig for testing
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="test_workflow_id",
|
||||
)
|
||||
|
||||
entity = WorkflowAppGenerateEntity(
|
||||
task_id="test_task_id",
|
||||
app_id="test_app_id",
|
||||
app_config=app_config,
|
||||
tenant_id="test_tenant",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
invoke_from=invoke_from,
|
||||
inputs={"test_input": "test_value"},
|
||||
user_id="test_user_id",
|
||||
stream=True,
|
||||
files=[],
|
||||
workflow_execution_id="test_workflow_exec_id",
|
||||
)
|
||||
return entity
|
||||
|
||||
def create_test_user(self) -> Account:
|
||||
"""Create a test user account."""
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
# Manually set the ID for testing purposes
|
||||
account.id = "test_user_id"
|
||||
return account
|
||||
|
||||
def create_test_system_variables(self) -> SystemVariable:
|
||||
"""Create test system variables."""
|
||||
return SystemVariable()
|
||||
|
||||
def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter:
|
||||
"""Create WorkflowResponseConverter with specified invoke_from."""
|
||||
entity = self.create_test_app_generate_entity(invoke_from)
|
||||
user = self.create_test_user()
|
||||
system_variables = self.create_test_system_variables()
|
||||
|
||||
converter = WorkflowResponseConverter(
|
||||
application_generate_entity=entity,
|
||||
user=user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
# ensure `workflow_run_id` is set.
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="test-task-id",
|
||||
workflow_run_id="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
)
|
||||
return converter
|
||||
|
||||
    @pytest.mark.parametrize(
        "test_case",
        [
            TestCase(
                name="service_api_truncation_disabled",
                invoke_from=InvokeFrom.SERVICE_API,
                expected_truncation_enabled=False,
                description="Service API calls should have truncation disabled",
            ),
            TestCase(
                name="web_app_truncation_enabled",
                invoke_from=InvokeFrom.WEB_APP,
                expected_truncation_enabled=True,
                description="Web app calls should have truncation enabled",
            ),
            TestCase(
                name="debugger_truncation_enabled",
                invoke_from=InvokeFrom.DEBUGGER,
                expected_truncation_enabled=True,
                description="Debugger calls should have truncation enabled",
            ),
            TestCase(
                name="explore_truncation_enabled",
                invoke_from=InvokeFrom.EXPLORE,
                expected_truncation_enabled=True,
                description="Explore calls should have truncation enabled",
            ),
            TestCase(
                name="published_truncation_enabled",
                invoke_from=InvokeFrom.PUBLISHED,
                expected_truncation_enabled=True,
                description="Published app calls should have truncation enabled",
            ),
        ],
        ids=lambda x: x.name,
    )
    def test_truncator_selection_based_on_invoke_from(self, test_case: TestCase):
        """Test that the correct truncator is selected based on invoke_from.

        The truncator choice is private, so the test observes it indirectly:
        feed an oversized payload and check whether the response flags it as
        truncated.
        """
        converter = self.create_test_converter(test_case.invoke_from)

        # Test truncation behavior instead of checking private attribute

        # Create a test event with large data
        large_value = {"key": ["x"] * 2000}  # Large data that would be truncated

        event = QueueNodeSucceededEvent(
            node_execution_id="test_node_exec_id",
            node_id="test_node",
            node_type=NodeType.LLM,
            start_at=naive_utc_now(),
            inputs=large_value,
            process_data=large_value,
            outputs=large_value,
            error=None,
            execution_metadata=None,
            in_iteration_id=None,
            in_loop_id=None,
        )

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test_task",
        )

        # Verify response is not None
        assert response is not None

        # Verify truncation behavior matches expectations
        if test_case.expected_truncation_enabled:
            # Truncation should be enabled for non-service-api calls
            assert response.data.inputs_truncated
            assert response.data.process_data_truncated
            assert response.data.outputs_truncated
        else:
            # SERVICE_API should not truncate
            assert not response.data.inputs_truncated
            assert not response.data.process_data_truncated
            assert not response.data.outputs_truncated
|
||||
|
||||
def test_service_api_truncator_no_op_mapping(self):
|
||||
"""Test that Service API truncator doesn't truncate variable mappings."""
|
||||
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||
|
||||
# Create a test event with large data
|
||||
large_value: dict[str, Any] = {
|
||||
"large_string": "x" * 10000, # Large string
|
||||
"large_list": list(range(2000)), # Large array
|
||||
"nested_data": {"deep_nested": {"very_deep": {"value": "x" * 5000}}},
|
||||
}
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_value,
|
||||
process_data=large_value,
|
||||
outputs=large_value,
|
||||
error=None,
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test_task",
|
||||
)
|
||||
|
||||
# Verify response is not None
|
||||
data = response.data
|
||||
assert data.inputs == large_value
|
||||
assert data.process_data == large_value
|
||||
assert data.outputs == large_value
|
||||
# Service API should not truncate
|
||||
assert data.inputs_truncated is False
|
||||
assert data.process_data_truncated is False
|
||||
assert data.outputs_truncated is False
|
||||
|
||||
def test_web_app_truncator_works_normally(self):
|
||||
"""Test that web app truncator still works normally."""
|
||||
converter = self.create_test_converter(InvokeFrom.WEB_APP)
|
||||
|
||||
# Create a test event with large data
|
||||
large_value = {
|
||||
"large_string": "x" * 10000, # Large string
|
||||
"large_list": list(range(2000)), # Large array
|
||||
}
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_value,
|
||||
process_data=large_value,
|
||||
outputs=large_value,
|
||||
error=None,
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test_task",
|
||||
)
|
||||
|
||||
# Verify response is not None
|
||||
assert response is not None
|
||||
|
||||
# Web app should truncate
|
||||
data = response.data
|
||||
assert data.inputs != large_value
|
||||
assert data.process_data != large_value
|
||||
assert data.outputs != large_value
|
||||
# The exact behavior depends on VariableTruncator implementation
|
||||
# Just verify that truncation flags are present
|
||||
assert data.inputs_truncated is True
|
||||
assert data.process_data_truncated is True
|
||||
assert data.outputs_truncated is True
|
||||
|
||||
@staticmethod
|
||||
def _create_event_by_type(
|
||||
type_: QueueEvent, inputs: Mapping[str, Any], process_data: Mapping[str, Any], outputs: Mapping[str, Any]
|
||||
) -> QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent:
|
||||
if type_ == QueueEvent.NODE_SUCCEEDED:
|
||||
return QueueNodeSucceededEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
error=None,
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
elif type_ == QueueEvent.NODE_FAILED:
|
||||
return QueueNodeFailedEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
error="oops",
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
elif type_ == QueueEvent.NODE_EXCEPTION:
|
||||
return QueueNodeExceptionEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
error="oops",
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
else:
|
||||
raise Exception("unknown type.")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"event_type",
|
||||
[
|
||||
QueueEvent.NODE_SUCCEEDED,
|
||||
QueueEvent.NODE_FAILED,
|
||||
QueueEvent.NODE_EXCEPTION,
|
||||
],
|
||||
)
|
||||
def test_service_api_node_finish_event_no_truncation(self, event_type: QueueEvent):
|
||||
"""Test that Service API doesn't truncate node finish events."""
|
||||
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||
# Create test event with large data
|
||||
large_inputs = {"input1": "x" * 5000, "input2": list(range(2000))}
|
||||
large_process_data = {"process1": "y" * 5000, "process2": {"nested": ["z"] * 2000}}
|
||||
large_outputs = {"output1": "result" * 1000, "output2": list(range(2000))}
|
||||
|
||||
event = TestWorkflowResponseConverterServiceApiTruncation._create_event_by_type(
|
||||
event_type, large_inputs, large_process_data, large_outputs
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test_task",
|
||||
)
|
||||
|
||||
# Verify response is not None
|
||||
assert response is not None
|
||||
|
||||
# Verify response contains full data (not truncated)
|
||||
assert response.data.inputs == large_inputs
|
||||
assert response.data.process_data == large_process_data
|
||||
assert response.data.outputs == large_outputs
|
||||
assert not response.data.inputs_truncated
|
||||
assert not response.data.process_data_truncated
|
||||
assert not response.data.outputs_truncated
|
||||
|
||||
    def test_service_api_node_retry_event_no_truncation(self):
        """Test that Service API doesn't truncate node retry events.

        The converter keeps a per-node snapshot keyed by node_execution_id,
        so a start event must be emitted before the retry event.
        """
        converter = self.create_test_converter(InvokeFrom.SERVICE_API)

        # Create test event with large data
        large_inputs = {"retry_input": "x" * 5000}
        large_process_data = {"retry_process": "y" * 5000}
        large_outputs = {"retry_output": "z" * 5000}

        # First, we need to store a snapshot by simulating a start event
        start_event = QueueNodeStartedEvent(
            node_execution_id="test_node_exec_id",
            node_id="test_node",
            node_type=NodeType.LLM,
            node_title="Test Node",
            node_run_index=1,
            start_at=naive_utc_now(),
            in_iteration_id=None,
            in_loop_id=None,
            agent_strategy=None,
            provider_type="plugin",
            provider_id="test/test_plugin",
        )
        converter.workflow_node_start_to_stream_response(event=start_event, task_id="test_task")

        # Now create retry event
        event = QueueNodeRetryEvent(
            node_execution_id="test_node_exec_id",
            node_id="test_node",
            node_type=NodeType.LLM,
            node_title="Test Node",
            node_run_index=1,
            start_at=naive_utc_now(),
            inputs=large_inputs,
            process_data=large_process_data,
            outputs=large_outputs,
            error="Retry error",
            execution_metadata=None,
            in_iteration_id=None,
            in_loop_id=None,
            retry_index=1,
            provider_type="plugin",
            provider_id="test/test_plugin",
        )

        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test_task",
        )

        # Verify response is not None
        assert response is not None

        # Verify response contains full data (not truncated)
        assert response.data.inputs == large_inputs
        assert response.data.process_data == large_process_data
        assert response.data.outputs == large_outputs
        assert not response.data.inputs_truncated
        assert not response.data.process_data_truncated
        assert not response.data.outputs_truncated
|
||||
|
||||
def test_service_api_iteration_events_no_truncation(self):
|
||||
"""Test that Service API doesn't truncate iteration events."""
|
||||
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||
|
||||
# Test iteration start event
|
||||
large_value = {"iteration_input": ["x"] * 2000}
|
||||
|
||||
start_event = QueueIterationStartEvent(
|
||||
node_execution_id="test_iter_exec_id",
|
||||
node_id="test_iteration",
|
||||
node_type=NodeType.ITERATION,
|
||||
node_title="Test Iteration",
|
||||
node_run_index=0,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_value,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
response = converter.workflow_iteration_start_to_stream_response(
|
||||
task_id="test_task",
|
||||
workflow_execution_id="test_workflow_exec_id",
|
||||
event=start_event,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.inputs == large_value
|
||||
assert not response.data.inputs_truncated
|
||||
|
||||
def test_service_api_loop_events_no_truncation(self):
|
||||
"""Test that Service API doesn't truncate loop events."""
|
||||
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||
|
||||
# Test loop start event
|
||||
large_inputs = {"loop_input": ["x"] * 2000}
|
||||
|
||||
start_event = QueueLoopStartEvent(
|
||||
node_execution_id="test_loop_exec_id",
|
||||
node_id="test_loop",
|
||||
node_type=NodeType.LOOP,
|
||||
node_title="Test Loop",
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_inputs,
|
||||
metadata={},
|
||||
node_run_index=0,
|
||||
)
|
||||
|
||||
response = converter.workflow_loop_start_to_stream_response(
|
||||
task_id="test_task",
|
||||
workflow_execution_id="test_workflow_exec_id",
|
||||
event=start_event,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.inputs == large_inputs
|
||||
assert not response.data.inputs_truncated
|
||||
|
||||
def test_web_app_node_finish_event_truncation_works(self):
|
||||
"""Test that web app still truncates node finish events."""
|
||||
converter = self.create_test_converter(InvokeFrom.WEB_APP)
|
||||
|
||||
# Create test event with large data that should be truncated
|
||||
large_inputs = {"input1": ["x"] * 2000}
|
||||
large_process_data = {"process1": ["y"] * 2000}
|
||||
large_outputs = {"output1": ["z"] * 2000}
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_inputs,
|
||||
process_data=large_process_data,
|
||||
outputs=large_outputs,
|
||||
error=None,
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test_task",
|
||||
)
|
||||
|
||||
# Verify response is not None
|
||||
assert response is not None
|
||||
|
||||
# Verify response contains truncated data
|
||||
# The exact behavior depends on VariableTruncator implementation
|
||||
# Just verify truncation flags are set correctly (may or may not be truncated depending on size)
|
||||
# At minimum, the truncation mechanism should work
|
||||
assert isinstance(response.data.inputs, dict)
|
||||
assert response.data.inputs_truncated
|
||||
assert isinstance(response.data.process_data, dict)
|
||||
assert response.data.process_data_truncated
|
||||
assert isinstance(response.data.outputs, dict)
|
||||
assert response.data.outputs_truncated
|
||||
@@ -0,0 +1,267 @@
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
|
||||
|
||||
def test_validate_inputs_with_zero():
    """Zero must survive validation for NUMBER variables, numeric or string-typed."""
    generator = BaseAppGenerator()

    number_var = VariableEntity(
        variable="test_var",
        label="test_var",
        type=VariableEntityType.NUMBER,
        required=True,
    )

    # 0 is falsy; neither the int nor the string form may be dropped or rejected.
    for raw_value in (0, "0"):
        assert generator._validate_inputs(variable_entity=number_var, value=raw_value) == 0
|
||||
|
||||
|
||||
def test_validate_input_with_none_for_required_variable():
    """A required variable of every type must reject a None value."""
    generator = BaseAppGenerator()

    for var_type in VariableEntityType:
        required_var = VariableEntity(
            variable="test_var",
            label="test_var",
            type=var_type,
            required=True,
        )

        with pytest.raises(ValueError) as exc_info:
            generator._validate_inputs(variable_entity=required_var, value=None)

        # The error message names the offending variable.
        assert str(exc_info.value) == "test_var is required in input form"
|
||||
|
||||
|
||||
def test_validate_inputs_with_default_value():
    """Test that default values are used when input is None for optional variables."""
    base_app_generator = BaseAppGenerator()

    def optional_var(name, var_type, default, **extra):
        # Every case below is an optional variable differing only in
        # type/default, so build them in one place.
        return VariableEntity(
            variable=name,
            label=name,
            type=var_type,
            required=False,
            default=default,
            **extra,
        )

    # (variable, expected result for a None input, required concrete type or None)
    default_cases = [
        (optional_var("test_var", VariableEntityType.TEXT_INPUT, "default_string"), "default_string", None),
        (
            optional_var("test_paragraph", VariableEntityType.PARAGRAPH, "default paragraph text"),
            "default paragraph text",
            None,
        ),
        (
            optional_var(
                "test_select", VariableEntityType.SELECT, "option1", options=["option1", "option2", "option3"]
            ),
            "option1",
            None,
        ),
        (optional_var("test_number_int", VariableEntityType.NUMBER, 42), 42, None),
        (optional_var("test_number_float", VariableEntityType.NUMBER, 3.14), 3.14, None),
        # The frontend may send numeric defaults as strings; they must be coerced.
        (optional_var("test_number_string", VariableEntityType.NUMBER, "123"), 123, int),
        (optional_var("test_number_float_string", VariableEntityType.NUMBER, "45.67"), 45.67, float),
        (optional_var("test_checkbox_true", VariableEntityType.CHECKBOX, True), True, None),
        (optional_var("test_checkbox_false", VariableEntityType.CHECKBOX, False), False, None),
        # An explicit None default yields None for an optional variable.
        (optional_var("test_none", VariableEntityType.TEXT_INPUT, None), None, None),
        # File defaults arrive as dicts (frontend format).
        (
            optional_var("test_file", VariableEntityType.FILE, {"id": "file123", "name": "default.pdf"}),
            {"id": "file123", "name": "default.pdf"},
            None,
        ),
        (
            optional_var(
                "test_file_list",
                VariableEntityType.FILE_LIST,
                [{"id": "file1", "name": "doc1.pdf"}, {"id": "file2", "name": "doc2.pdf"}],
            ),
            [{"id": "file1", "name": "doc1.pdf"}, {"id": "file2", "name": "doc2.pdf"}],
            None,
        ),
    ]

    for entity, expected, expected_type in default_cases:
        result = base_app_generator._validate_inputs(variable_entity=entity, value=None)
        if expected is None or isinstance(expected, bool):
            # None and booleans are compared by identity, as the original asserts did.
            assert result is expected
        else:
            assert result == expected
        if expected_type is not None:
            assert isinstance(result, expected_type)

    # An explicitly provided value always takes precedence over the default.
    text_var = optional_var("test_var", VariableEntityType.TEXT_INPUT, "default_string")
    assert base_app_generator._validate_inputs(variable_entity=text_var, value="actual_value") == "actual_value"

    number_var = optional_var("test_number_int", VariableEntityType.NUMBER, 42)
    assert base_app_generator._validate_inputs(variable_entity=number_var, value=999) == 999
|
||||
@@ -0,0 +1,19 @@
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
|
||||
|
||||
def test_should_prepare_user_inputs_defaults_to_true():
    """Without the skip flag, user-input preparation stays enabled."""
    assert WorkflowAppGenerator()._should_prepare_user_inputs({"inputs": {}})
|
||||
|
||||
|
||||
def test_should_prepare_user_inputs_skips_when_flag_truthy():
    """A truthy skip flag bypasses user-input preparation."""
    request_args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: True}
    assert not WorkflowAppGenerator()._should_prepare_user_inputs(request_args)
|
||||
|
||||
|
||||
def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
    """An explicit False flag still leaves preparation enabled."""
    request_args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}
    assert WorkflowAppGenerator()._should_prepare_user_inputs(request_args)
|
||||
@@ -0,0 +1,124 @@
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimit
|
||||
|
||||
|
||||
@pytest.fixture
def mock_redis():
    """Mock Redis client with realistic behavior for rate limiting tests.

    Simulates string keys (setex/get/exists/expire) and hash keys
    (hset/hgetall/hdel/hlen) over plain dicts, with TTL bookkeeping in
    ``mock_expiry``. The backing dicts are attached to the mock as
    ``_mock_data`` / ``_mock_hashes`` / ``_mock_expiry`` so tests can inspect
    state directly. Not thread-safe — intended for single-threaded tests.
    """
    mock_client = MagicMock()

    # Redis data storage for simulation
    mock_data = {}    # string key -> string value
    mock_hashes = {}  # hash key -> {field: bytes value}
    mock_expiry = {}  # key -> absolute expiry time (epoch seconds)

    def mock_setex(key, ttl, value):
        # ttl may be a timedelta or a plain number of seconds; the conditional
        # expression binds as (t + ttl.total_seconds()) if ... else (t + ttl).
        mock_data[key] = str(value)
        mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl
        return True

    def mock_get(key):
        # Honor expiry: a key past its deadline reads as missing (None).
        if key in mock_data and (key not in mock_expiry or time.time() < mock_expiry[key]):
            return mock_data[key].encode("utf-8")
        return None

    def mock_exists(key):
        # NOTE(review): ignores expiry, unlike mock_get — an expired key still
        # "exists" here. Acceptable for the current tests; confirm if reused.
        return key in mock_data or key in mock_hashes

    def mock_expire(key, ttl):
        if key in mock_data or key in mock_hashes:
            mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl
        # NOTE(review): returns True even for a missing key (real Redis EXPIRE
        # returns 0) — fine for these tests, which never check the return.
        return True

    def mock_hset(key, field, value):
        if key not in mock_hashes:
            mock_hashes[key] = {}
        # Store as bytes, matching what redis-py returns for hash values.
        mock_hashes[key][field] = str(value).encode("utf-8")
        return True

    def mock_hgetall(key):
        return mock_hashes.get(key, {})

    def mock_hdel(key, *fields):
        # Mirrors Redis HDEL: returns the number of fields actually removed.
        if key in mock_hashes:
            count = 0
            for field in fields:
                if field in mock_hashes[key]:
                    del mock_hashes[key][field]
                    count += 1
            return count
        return 0

    def mock_hlen(key):
        return len(mock_hashes.get(key, {}))

    # Configure mock methods
    mock_client.setex = mock_setex
    mock_client.get = mock_get
    mock_client.exists = mock_exists
    mock_client.expire = mock_expire
    mock_client.hset = mock_hset
    mock_client.hgetall = mock_hgetall
    mock_client.hdel = mock_hdel
    mock_client.hlen = mock_hlen

    # Store references for test verification
    mock_client._mock_data = mock_data
    mock_client._mock_hashes = mock_hashes
    mock_client._mock_expiry = mock_expiry

    return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_time():
    """Mock time.time() for deterministic tests.

    Yields the patch's MagicMock with an extra ``increment(seconds=1)`` helper
    that advances the simulated clock and returns the new value.

    Fix: the original patched with ``return_value=mock_time_val``, which
    captures 1000.0 once — ``increment`` mutated the closure variable but the
    patched ``time.time()`` kept returning 1000.0 forever. Using a
    ``side_effect`` that reads the closure on every call makes ``increment``
    actually take effect, while the initial reading (1000.0) is unchanged.
    """
    mock_time_val = 1000.0

    def increment_time(seconds=1):
        # Advance the simulated clock; subsequent time.time() calls see it.
        nonlocal mock_time_val
        mock_time_val += seconds
        return mock_time_val

    with patch("time.time", side_effect=lambda: mock_time_val) as mock:
        mock.increment = increment_time
        yield mock
|
||||
|
||||
|
||||
@pytest.fixture
def sample_generator():
    """Factory producing finite generators for RateLimitGenerator tests."""

    def _create_generator(items=None, raise_error=False):
        # A falsy `items` (None or empty) falls back to the three defaults,
        # matching the original `items or [...]` semantics.
        for current in items or ["item1", "item2", "item3"]:
            if raise_error and current == "item2":
                raise ValueError("Test error")
            yield current

    return _create_generator
|
||||
|
||||
|
||||
@pytest.fixture
def sample_mapping():
    """Static mapping used to exercise RateLimit.generate's non-generator path."""
    return dict(key1="value1", key2="value2")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_rate_limit_instances():
    """Clear RateLimit singleton instances between tests."""
    registry = RateLimit._instance_dict
    registry.clear()  # drop singletons leaked by a previous test
    yield
    registry.clear()  # leave a clean slate for the next test
|
||||
|
||||
|
||||
@pytest.fixture
def redis_patch():
    """Patch redis_client globally for rate limit tests."""
    patch_target = "core.app.features.rate_limiting.rate_limit.redis_client"
    with patch(patch_target) as mocked_redis:
        yield mocked_redis
|
||||
@@ -0,0 +1,569 @@
|
||||
import threading
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimit
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
|
||||
|
||||
class TestRateLimit:
    """Core rate limiting functionality tests.

    All tests rely on the autouse ``reset_rate_limit_instances`` fixture to
    clear the RateLimit singleton registry, and on ``redis_patch`` for a
    module-level Redis mock.
    """

    def test_should_return_same_instance_for_same_client_id(self, redis_patch):
        """Test singleton behavior for same client ID."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
            }
        )

        rate_limit1 = RateLimit("client1", 5)
        rate_limit2 = RateLimit("client1", 10)  # Second instance with different limit

        assert rate_limit1 is rate_limit2
        # Current implementation: last constructor call overwrites max_active_requests
        # This reflects the actual behavior where __init__ always sets max_active_requests
        assert rate_limit1.max_active_requests == 10

    def test_should_create_different_instances_for_different_client_ids(self, redis_patch):
        """Test different instances for different client IDs."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
            }
        )

        rate_limit1 = RateLimit("client1", 5)
        rate_limit2 = RateLimit("client2", 10)

        assert rate_limit1 is not rate_limit2
        assert rate_limit1.client_id == "client1"
        assert rate_limit2.client_id == "client2"

    def test_should_initialize_with_valid_parameters(self, redis_patch):
        """Test normal initialization."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
            }
        )

        rate_limit = RateLimit("test_client", 5)

        assert rate_limit.client_id == "test_client"
        assert rate_limit.max_active_requests == 5
        # The "initialized" attribute only exists once __init__ ran to completion.
        assert hasattr(rate_limit, "initialized")
        # One setex call: writing the max_active_requests key on first flush.
        redis_patch.setex.assert_called_once()

    def test_should_skip_initialization_if_disabled(self):
        """Test no initialization when rate limiting is disabled."""
        # max_active_requests == 0 disables rate limiting entirely; no Redis needed.
        rate_limit = RateLimit("test_client", 0)

        assert rate_limit.disabled()
        assert not hasattr(rate_limit, "initialized")

    def test_should_skip_reinitialization_of_existing_instance(self, redis_patch):
        """Test that existing instance doesn't reinitialize."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
            }
        )

        RateLimit("client1", 5)
        # Reset after the first construction so only the second one is measured.
        redis_patch.reset_mock()

        RateLimit("client1", 10)

        redis_patch.setex.assert_not_called()

    def test_should_be_disabled_when_max_requests_is_zero_or_negative(self):
        """Test disabled state for zero or negative limits."""
        rate_limit_zero = RateLimit("client1", 0)
        rate_limit_negative = RateLimit("client2", -5)

        assert rate_limit_zero.disabled()
        assert rate_limit_negative.disabled()

    def test_should_set_redis_keys_on_first_flush(self, redis_patch):
        """Test Redis keys are set correctly on initial flush."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
            }
        )

        rate_limit = RateLimit("test_client", 5)

        # The max-active-requests key is written with a 1-day TTL.
        expected_max_key = "dify:rate_limit:test_client:max_active_requests"
        redis_patch.setex.assert_called_with(expected_max_key, timedelta(days=1), 5)

    def test_should_sync_max_requests_from_redis_on_subsequent_flush(self, redis_patch):
        """Test max requests syncs from Redis when key exists."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": True,
                "get.return_value": b"10",  # Redis already holds a higher limit
                "expire.return_value": True,
            }
        )

        rate_limit = RateLimit("test_client", 5)
        rate_limit.flush_cache()

        # The Redis value (10) wins over the constructor argument (5).
        assert rate_limit.max_active_requests == 10

    @patch("time.time")
    def test_should_clean_timeout_requests_from_active_list(self, mock_time, redis_patch):
        """Test cleanup of timed-out requests."""
        current_time = 1000.0
        mock_time.return_value = current_time

        # Setup mock Redis with timed-out requests
        timeout_requests = {
            b"req1": str(current_time - 700).encode(),  # 700 seconds ago (timeout)
            b"req2": str(current_time - 100).encode(),  # 100 seconds ago (active)
        }

        redis_patch.configure_mock(
            **{
                "exists.return_value": True,
                "get.return_value": b"5",
                "expire.return_value": True,
                "hgetall.return_value": timeout_requests,
                "hdel.return_value": 1,
            }
        )

        rate_limit = RateLimit("test_client", 5)
        redis_patch.reset_mock()  # Reset to avoid counting initialization calls
        rate_limit.flush_cache()

        # Verify timeout request was cleaned up
        redis_patch.hdel.assert_called_once()
        call_args = redis_patch.hdel.call_args[0]
        assert call_args[0] == "dify:rate_limit:test_client:active_requests"
        assert b"req1" in call_args  # Timeout request should be removed
        assert b"req2" not in call_args  # Active request should remain
|
||||
|
||||
|
||||
class TestRateLimitEnterExit:
    """Rate limiting enter/exit logic tests.

    ``enter()`` registers an active request (hset) after checking the current
    count (hlen); ``exit()`` removes it (hdel). The ``redis_patch`` fixture
    supplies the mocked Redis client.
    """

    def test_should_allow_request_within_limit(self, redis_patch):
        """Test allowing requests within the rate limit."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 2,  # 2 active < limit of 5
                "hset.return_value": True,
            }
        )

        rate_limit = RateLimit("test_client", 5)
        request_id = rate_limit.enter()

        assert request_id != RateLimit._UNLIMITED_REQUEST_ID
        # The request was registered in the active-requests hash exactly once.
        redis_patch.hset.assert_called_once()

    def test_should_generate_request_id_if_not_provided(self, redis_patch):
        """Test auto-generation of request ID."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 0,
                "hset.return_value": True,
            }
        )

        rate_limit = RateLimit("test_client", 5)
        request_id = rate_limit.enter()

        assert len(request_id) == 36  # UUID format

    def test_should_use_provided_request_id(self, redis_patch):
        """Test using provided request ID."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 0,
                "hset.return_value": True,
            }
        )

        rate_limit = RateLimit("test_client", 5)
        custom_id = "custom_request_123"
        request_id = rate_limit.enter(custom_id)

        assert request_id == custom_id

    def test_should_remove_request_on_exit(self, redis_patch):
        """Test request removal on exit."""
        redis_patch.configure_mock(
            **{
                "hdel.return_value": 1,
            }
        )

        rate_limit = RateLimit("test_client", 5)
        rate_limit.exit("test_request_id")

        # exit() deletes exactly the given field from the active-requests hash.
        redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", "test_request_id")

    def test_should_raise_quota_exceeded_when_at_limit(self, redis_patch):
        """Test quota exceeded error when at limit."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 5,  # At limit
            }
        )

        rate_limit = RateLimit("test_client", 5)

        with pytest.raises(AppInvokeQuotaExceededError) as exc_info:
            rate_limit.enter()

        # The error message should identify both the condition and the client.
        assert "Too many requests" in str(exc_info.value)
        assert "test_client" in str(exc_info.value)

    def test_should_allow_request_after_previous_exit(self, redis_patch):
        """Test allowing new request after previous exit."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 4,  # Under limit after exit
                "hset.return_value": True,
                "hdel.return_value": 1,
            }
        )

        rate_limit = RateLimit("test_client", 5)

        request_id = rate_limit.enter()
        rate_limit.exit(request_id)

        new_request_id = rate_limit.enter()
        assert new_request_id is not None

    @patch("time.time")
    def test_should_flush_cache_when_interval_exceeded(self, mock_time, redis_patch):
        """Test cache flush when time interval exceeded."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.return_value": 0,
            }
        )

        mock_time.return_value = 1000.0
        rate_limit = RateLimit("test_client", 5)

        # Advance time beyond flush interval
        mock_time.return_value = 1400.0  # 400 seconds later
        redis_patch.reset_mock()

        rate_limit.enter()

        # Should have called setex again due to cache flush
        redis_patch.setex.assert_called()

    def test_should_return_unlimited_id_when_disabled(self):
        """Test unlimited ID return when rate limiting disabled."""
        rate_limit = RateLimit("test_client", 0)
        request_id = rate_limit.enter()

        assert request_id == RateLimit._UNLIMITED_REQUEST_ID

    def test_should_ignore_exit_for_unlimited_requests(self, redis_patch):
        """Test ignoring exit for unlimited requests."""
        rate_limit = RateLimit("test_client", 0)
        rate_limit.exit(RateLimit._UNLIMITED_REQUEST_ID)

        # A disabled limiter never touches Redis on exit.
        redis_patch.hdel.assert_not_called()
|
||||
|
||||
|
||||
class TestRateLimitGenerator:
    """Rate limit generator wrapper tests.

    ``RateLimit.generate`` wraps a generator so the active-request slot is
    released (hdel) on exhaustion, exception, or close; mappings pass through
    untouched.
    """

    def test_should_wrap_generator_and_iterate_normally(self, redis_patch, sample_generator):
        """Test normal generator iteration with rate limit wrapper."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hdel.return_value": 1,
            }
        )

        rate_limit = RateLimit("test_client", 5)
        generator = sample_generator()
        request_id = "test_request"

        wrapped_gen = rate_limit.generate(generator, request_id)
        result = list(wrapped_gen)

        assert result == ["item1", "item2", "item3"]
        # Exhausting the wrapper releases the request slot exactly once.
        redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id)

    def test_should_handle_mapping_input_directly(self, sample_mapping):
        """Test direct return of mapping input."""
        rate_limit = RateLimit("test_client", 0)  # Disabled
        result = rate_limit.generate(sample_mapping, "test_request")

        # Mappings are returned as-is (same object), not wrapped.
        assert result is sample_mapping

    def test_should_cleanup_on_exception_during_iteration(self, redis_patch, sample_generator):
        """Test cleanup when exception occurs during iteration."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hdel.return_value": 1,
            }
        )

        rate_limit = RateLimit("test_client", 5)
        generator = sample_generator(raise_error=True)  # raises on "item2"
        request_id = "test_request"

        wrapped_gen = rate_limit.generate(generator, request_id)

        with pytest.raises(ValueError):
            list(wrapped_gen)

        # Even on failure the request slot must be released.
        redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id)

    def test_should_cleanup_on_explicit_close(self, redis_patch, sample_generator):
        """Test cleanup on explicit generator close."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hdel.return_value": 1,
            }
        )

        rate_limit = RateLimit("test_client", 5)
        generator = sample_generator()
        request_id = "test_request"

        wrapped_gen = rate_limit.generate(generator, request_id)
        wrapped_gen.close()

        redis_patch.hdel.assert_called_once()

    def test_should_handle_generator_without_close_method(self, redis_patch):
        """Test handling generator without close method."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hdel.return_value": 1,
            }
        )

        # Create a generator-like object without close method
        class SimpleGenerator:
            def __init__(self):
                self.items = ["test"]
                self.index = 0

            def __iter__(self):
                return self

            def __next__(self):
                if self.index >= len(self.items):
                    raise StopIteration
                item = self.items[self.index]
                self.index += 1
                return item

        rate_limit = RateLimit("test_client", 5)
        generator = SimpleGenerator()

        wrapped_gen = rate_limit.generate(generator, "test_request")
        wrapped_gen.close()  # Should not raise error

        redis_patch.hdel.assert_called_once()

    def test_should_prevent_iteration_after_close(self, redis_patch, sample_generator):
        """Test StopIteration after generator is closed."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hdel.return_value": 1,
            }
        )

        rate_limit = RateLimit("test_client", 5)
        generator = sample_generator()

        wrapped_gen = rate_limit.generate(generator, "test_request")
        wrapped_gen.close()

        # A closed wrapper must behave like an exhausted iterator.
        with pytest.raises(StopIteration):
            next(wrapped_gen)
|
||||
|
||||
|
||||
class TestRateLimitConcurrency:
|
||||
"""Concurrent access safety tests."""
|
||||
|
||||
    def test_should_handle_concurrent_instance_creation(self, redis_patch):
        """Test thread-safe singleton instance creation."""
        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
            }
        )

        # list.append is atomic under the GIL, so plain lists are safe collectors here.
        instances = []
        errors = []

        def create_instance():
            try:
                instance = RateLimit("concurrent_client", 5)
                instances.append(instance)
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=create_instance) for _ in range(10)]

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert len(errors) == 0
        # Every thread must have received the same singleton object.
        assert len({id(inst) for inst in instances}) == 1  # All same instance
|
||||
|
||||
    def test_should_handle_concurrent_enter_requests(self, redis_patch):
        """Test concurrent enter requests handling."""
        # Setup mock to simulate realistic Redis behavior
        request_count = 0

        def mock_hlen(key):
            nonlocal request_count
            return request_count

        def mock_hset(key, field, value):
            nonlocal request_count
            request_count += 1
            return True

        redis_patch.configure_mock(
            **{
                "exists.return_value": False,
                "setex.return_value": True,
                "hlen.side_effect": mock_hlen,
                "hset.side_effect": mock_hset,
            }
        )

        rate_limit = RateLimit("concurrent_client", 3)
        results = []
        errors = []

        def try_enter():
            try:
                request_id = rate_limit.enter()
                results.append(request_id)
            except AppInvokeQuotaExceededError as e:
                errors.append(e)

        threads = [threading.Thread(target=try_enter) for _ in range(5)]

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Should have some successful requests and some quota exceeded
        assert len(results) + len(errors) == 5
        # NOTE(review): the hlen-then-hset sequence is not atomic, so in
        # principle all 5 threads could read a count below the limit before any
        # increment lands, making this assertion flaky — confirm whether
        # RateLimit.enter() serializes the check, or make the mock atomic.
        assert len(errors) > 0  # Some should be rejected
|
||||
|
||||
@patch("time.time")
def test_should_maintain_accurate_count_under_load(self, mock_time, redis_patch):
    """Every request admitted under load must also be cleaned up on exit."""
    mock_time.return_value = 1000.0

    # Back the mock with the thread-safe in-memory Redis simulation.
    redis_patch.configure_mock(**self._create_mock_redis())

    rate_limit = RateLimit("load_test_client", 10)
    in_flight = []  # request ids that entered but have not exited yet

    def worker():
        try:
            req_id = rate_limit.enter()
            in_flight.append(req_id)
            time.sleep(0.01)  # Simulate some work
            rate_limit.exit(req_id)
            in_flight.remove(req_id)
        except AppInvokeQuotaExceededError:
            pass  # Rejections are expected: 20 workers share 10 slots

    workers = [threading.Thread(target=worker) for _ in range(20)]

    for w in workers:
        w.start()
    for w in workers:
        w.join()

    # Anything still listed here leaked past its exit() call.
    assert len(in_flight) == 0
|
||||
|
||||
def _create_mock_redis(self):
|
||||
"""Create a thread-safe mock Redis for concurrency tests."""
|
||||
import threading
|
||||
|
||||
lock = threading.Lock()
|
||||
data = {}
|
||||
hashes = {}
|
||||
|
||||
def mock_hlen(key):
|
||||
with lock:
|
||||
return len(hashes.get(key, {}))
|
||||
|
||||
def mock_hset(key, field, value):
|
||||
with lock:
|
||||
if key not in hashes:
|
||||
hashes[key] = {}
|
||||
hashes[key][field] = str(value).encode("utf-8")
|
||||
return True
|
||||
|
||||
def mock_hdel(key, *fields):
|
||||
with lock:
|
||||
if key in hashes:
|
||||
count = 0
|
||||
for field in fields:
|
||||
if field in hashes[key]:
|
||||
del hashes[key][field]
|
||||
count += 1
|
||||
return count
|
||||
return 0
|
||||
|
||||
return {
|
||||
"exists.return_value": False,
|
||||
"setex.return_value": True,
|
||||
"hlen.side_effect": mock_hlen,
|
||||
"hset.side_effect": mock_hset,
|
||||
"hdel.side_effect": mock_hdel,
|
||||
}
|
||||
@@ -0,0 +1,410 @@
|
||||
import json
|
||||
from time import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import (
|
||||
PauseStatePersistenceLayer,
|
||||
WorkflowResumptionContext,
|
||||
_AdvancedChatAppGenerateEntityWrapper,
|
||||
_WorkflowGenerateEntityWrapper,
|
||||
)
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
class TestDataFactory:
    """Factory helpers for constructing graph events used in tests."""

    @staticmethod
    def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
        """Build a paused event carrying a canned SchedulingPause reason."""
        return GraphRunPausedEvent(
            reason=SchedulingPause(message="test pause"),
            outputs=outputs if outputs else {},
        )

    @staticmethod
    def create_graph_run_started_event() -> GraphRunStartedEvent:
        """Build a bare started event."""
        return GraphRunStartedEvent()

    @staticmethod
    def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
        """Build a succeeded event, defaulting to empty outputs."""
        return GraphRunSucceededEvent(outputs=outputs if outputs else {})

    @staticmethod
    def create_graph_run_failed_event(
        error: str = "Test error",
        exceptions_count: int = 1,
    ) -> GraphRunFailedEvent:
        """Build a failed event with a configurable error and exception count."""
        return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
|
||||
|
||||
|
||||
class MockSystemVariableReadOnlyView:
|
||||
"""Minimal read-only system variable view for testing."""
|
||||
|
||||
def __init__(self, workflow_execution_id: str | None = None) -> None:
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str | None:
|
||||
return self._workflow_execution_id
|
||||
|
||||
|
||||
class MockReadOnlyVariablePool:
    """Mock implementation of ReadOnlyVariablePool for testing."""

    def __init__(self, variables: dict[tuple[str, str], object] | None = None):
        # Keyed by (node_id, variable_key).
        self._variables = variables if variables else {}

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Return a Segment-shaped mock wrapping the stored value, or None."""
        stored = self._variables.get((node_id, variable_key))
        if stored is None:
            return None
        segment = Mock(spec=Segment)
        segment.value = stored
        return segment

    def get_all_by_node(self, node_id: str) -> dict[str, object]:
        """All variables registered under a single node id."""
        return {key: val for (owner, key), val in self._variables.items() if owner == node_id}

    def get_by_prefix(self, prefix: str) -> dict[str, object]:
        """Variables whose node id starts with `prefix`, keyed as 'node.key'."""
        return {f"{owner}.{key}": val for (owner, key), val in self._variables.items() if owner.startswith(prefix)}
|
||||
|
||||
|
||||
class MockReadOnlyGraphRuntimeState:
    """Mock implementation of ReadOnlyGraphRuntimeState for testing."""

    def __init__(
        self,
        start_at: float | None = None,
        total_tokens: int = 0,
        node_run_steps: int = 0,
        ready_queue_size: int = 0,
        exceptions_count: int = 0,
        outputs: dict[str, object] | None = None,
        variables: dict[tuple[str, str], object] | None = None,
        workflow_execution_id: str | None = None,
    ):
        # Fall back to "now" when no explicit start time is supplied.
        self._start_at = start_at if start_at else time()
        self._total_tokens = total_tokens
        self._node_run_steps = node_run_steps
        self._ready_queue_size = ready_queue_size
        self._exceptions_count = exceptions_count
        self._outputs = outputs if outputs else {}
        self._variable_pool = MockReadOnlyVariablePool(variables)
        self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)

    @property
    def system_variable(self) -> MockSystemVariableReadOnlyView:
        return self._system_variable

    @property
    def variable_pool(self) -> ReadOnlyVariablePool:
        return self._variable_pool

    @property
    def start_at(self) -> float:
        return self._start_at

    @property
    def total_tokens(self) -> int:
        return self._total_tokens

    @property
    def node_run_steps(self) -> int:
        return self._node_run_steps

    @property
    def ready_queue_size(self) -> int:
        return self._ready_queue_size

    @property
    def exceptions_count(self) -> int:
        return self._exceptions_count

    @property
    def outputs(self) -> dict[str, object]:
        # Hand out a copy so callers cannot mutate internal state.
        return dict(self._outputs)

    @property
    def llm_usage(self):
        # Canned usage figures; tests only need a stable object shape.
        usage = Mock()
        usage.prompt_tokens = 10
        usage.completion_tokens = 20
        usage.total_tokens = 30
        return usage

    def get_output(self, key: str, default: object = None) -> object:
        return self._outputs.get(key, default)

    def dumps(self) -> str:
        """Serialize the visible state to a JSON string."""
        payload = {
            "start_at": self._start_at,
            "total_tokens": self._total_tokens,
            "node_run_steps": self._node_run_steps,
            "ready_queue_size": self._ready_queue_size,
            "exceptions_count": self._exceptions_count,
            "outputs": self._outputs,
            "variables": {f"{node}.{key}": val for (node, key), val in self._variable_pool._variables.items()},
            "workflow_execution_id": self._system_variable.workflow_execution_id,
        }
        return json.dumps(payload)
|
||||
|
||||
|
||||
class MockCommandChannel:
    """Mock implementation of CommandChannel for testing."""

    def __init__(self):
        # Commands are retained in arrival order.
        self._commands: list[GraphEngineCommand] = []

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """Return a snapshot copy so callers cannot mutate the channel."""
        return list(self._commands)

    def send_command(self, command: GraphEngineCommand) -> None:
        """Record a command for later retrieval via fetch_commands()."""
        self._commands.append(command)
|
||||
|
||||
|
||||
class TestPauseStatePersistenceLayer:
    """Unit tests for PauseStatePersistenceLayer."""

    @staticmethod
    def _create_generate_entity(workflow_execution_id: str = "run-123") -> WorkflowAppGenerateEntity:
        # Helper: minimal WorkflowAppGenerateEntity for exercising the layer.
        app_config = WorkflowUIBasedAppConfig(
            tenant_id="tenant-123",
            app_id="app-123",
            app_mode=AppMode.WORKFLOW,
            workflow_id="workflow-123",
        )
        return WorkflowAppGenerateEntity(
            task_id="task-123",
            app_config=app_config,
            inputs={},
            files=[],
            user_id="user-123",
            stream=False,
            invoke_from=InvokeFrom.DEBUGGER,
            workflow_execution_id=workflow_execution_id,
        )

    def test_init_with_dependency_injection(self):
        """Constructor stores collaborators; runtime wiring waits for initialize()."""
        session_factory = Mock(name="session_factory")
        state_owner_user_id = "user-123"

        layer = PauseStatePersistenceLayer(
            session_factory=session_factory,
            state_owner_user_id=state_owner_user_id,
            generate_entity=self._create_generate_entity(),
        )

        assert layer._session_maker is session_factory
        assert layer._state_owner_user_id == state_owner_user_id
        # Runtime-provided attributes must not exist before initialize().
        assert not hasattr(layer, "graph_runtime_state")
        assert not hasattr(layer, "command_channel")

    def test_initialize_sets_dependencies(self):
        """initialize() attaches the runtime state and command channel."""
        session_factory = Mock(name="session_factory")
        layer = PauseStatePersistenceLayer(
            session_factory=session_factory,
            state_owner_user_id="owner",
            generate_entity=self._create_generate_entity(),
        )

        graph_runtime_state = MockReadOnlyGraphRuntimeState()
        command_channel = MockCommandChannel()

        layer.initialize(graph_runtime_state, command_channel)

        assert layer.graph_runtime_state is graph_runtime_state
        assert layer.command_channel is command_channel

    def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
        """A paused event persists the serialized state via the workflow-run repository."""
        session_factory = Mock(name="session_factory")
        generate_entity = self._create_generate_entity(workflow_execution_id="run-123")
        layer = PauseStatePersistenceLayer(
            session_factory=session_factory,
            state_owner_user_id="owner-123",
            generate_entity=generate_entity,
        )

        # Intercept repository creation so no database is touched.
        mock_repo = Mock()
        mock_factory = Mock(return_value=mock_repo)
        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)

        graph_runtime_state = MockReadOnlyGraphRuntimeState(
            outputs={"result": "test_output"},
            total_tokens=100,
            workflow_execution_id="run-123",
        )
        command_channel = MockCommandChannel()
        layer.initialize(graph_runtime_state, command_channel)

        event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
        # Capture the expected serialized runtime state before the event fires.
        expected_state = graph_runtime_state.dumps()

        layer.on_event(event)

        mock_factory.assert_called_once_with(session_factory)
        # The `state` kwarg is validated separately below, so the recorded
        # value is echoed back here to satisfy assert_called_once_with.
        mock_repo.create_workflow_pause.assert_called_once_with(
            workflow_run_id="run-123",
            state_owner_user_id="owner-123",
            state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
        )
        serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
        resumption_context = WorkflowResumptionContext.loads(serialized_state)
        assert resumption_context.serialized_graph_runtime_state == expected_state
        assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()

    def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
        """Started/succeeded/failed events must not trigger any persistence."""
        session_factory = Mock(name="session_factory")
        layer = PauseStatePersistenceLayer(
            session_factory=session_factory,
            state_owner_user_id="owner-123",
            generate_entity=self._create_generate_entity(),
        )

        mock_repo = Mock()
        mock_factory = Mock(return_value=mock_repo)
        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)

        graph_runtime_state = MockReadOnlyGraphRuntimeState()
        command_channel = MockCommandChannel()
        layer.initialize(graph_runtime_state, command_channel)

        events = [
            TestDataFactory.create_graph_run_started_event(),
            TestDataFactory.create_graph_run_succeeded_event(),
            TestDataFactory.create_graph_run_failed_event(),
        ]

        for event in events:
            layer.on_event(event)

        mock_factory.assert_not_called()
        mock_repo.create_workflow_pause.assert_not_called()

    def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
        """Handling a paused event before initialize() fails loudly."""
        session_factory = Mock(name="session_factory")
        layer = PauseStatePersistenceLayer(
            session_factory=session_factory,
            state_owner_user_id="owner-123",
            generate_entity=self._create_generate_entity(),
        )

        event = TestDataFactory.create_graph_run_paused_event()

        # graph_runtime_state was never attached, so attribute access raises.
        with pytest.raises(AttributeError):
            layer.on_event(event)

    def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
        """A paused event without a workflow execution id trips the layer's assertion."""
        session_factory = Mock(name="session_factory")
        layer = PauseStatePersistenceLayer(
            session_factory=session_factory,
            state_owner_user_id="owner-123",
            generate_entity=self._create_generate_entity(),
        )

        mock_repo = Mock()
        mock_factory = Mock(return_value=mock_repo)
        monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)

        graph_runtime_state = MockReadOnlyGraphRuntimeState(workflow_execution_id=None)
        command_channel = MockCommandChannel()
        layer.initialize(graph_runtime_state, command_channel)

        event = TestDataFactory.create_graph_run_paused_event()

        with pytest.raises(AssertionError):
            layer.on_event(event)

        # Nothing may be persisted when the assertion fires.
        mock_factory.assert_not_called()
        mock_repo.create_workflow_pause.assert_not_called()
|
||||
|
||||
|
||||
def _build_workflow_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
    """Create a WorkflowAppGenerateEntity with realistic data for WorkflowResumptionContext tests."""
    entity = WorkflowAppGenerateEntity(
        task_id="workflow-task",
        app_config=WorkflowUIBasedAppConfig(
            tenant_id="tenant-roundtrip",
            app_id="app-roundtrip",
            app_mode=AppMode.WORKFLOW,
            workflow_id="workflow-roundtrip",
        ),
        inputs={"input_key": "input_value"},
        files=[],
        user_id="user-roundtrip",
        stream=False,
        invoke_from=InvokeFrom.DEBUGGER,
        workflow_execution_id="workflow-exec-roundtrip",
    )
    return WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
        generate_entity=_WorkflowGenerateEntityWrapper(entity=entity),
    )
|
||||
|
||||
|
||||
def _build_advanced_chat_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
    """Create an AdvancedChatAppGenerateEntity with realistic data for WorkflowResumptionContext tests."""
    entity = AdvancedChatAppGenerateEntity(
        task_id="advanced-task",
        app_config=WorkflowUIBasedAppConfig(
            tenant_id="tenant-advanced",
            app_id="app-advanced",
            app_mode=AppMode.ADVANCED_CHAT,
            workflow_id="workflow-advanced",
        ),
        inputs={"topic": "roundtrip"},
        files=[],
        user_id="advanced-user",
        stream=False,
        invoke_from=InvokeFrom.DEBUGGER,
        workflow_run_id="advanced-run-id",
        query="Explain serialization behavior",
    )
    return WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
        generate_entity=_AdvancedChatAppGenerateEntityWrapper(entity=entity),
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "state",
    [
        pytest.param(_build_advanced_chat_generate_entity_for_roundtrip(), id="advanced_chat"),
        pytest.param(_build_workflow_generate_entity_for_roundtrip(), id="workflow"),
    ],
)
def test_workflow_resumption_context_dumps_loads_roundtrip(state: WorkflowResumptionContext):
    """WorkflowResumptionContext roundtrip preserves workflow generate entity metadata."""
    restored = WorkflowResumptionContext.loads(state.dumps())

    assert restored == state
    assert restored.serialized_graph_runtime_state == state.serialized_graph_runtime_state
    # The concrete generate-entity type must survive the roundtrip.
    restored_entity = restored.get_generate_entity()
    assert isinstance(restored_entity, type(state.generate_entity.entity))
|
||||
54
dify/api/tests/unit_tests/core/file/test_models.py
Normal file
54
dify/api/tests/unit_tests/core/file/test_models.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
|
||||
|
||||
def test_file():
    """A fully-populated File keeps every constructor argument it is given."""
    file = File(
        id="test-file",
        tenant_id="test-tenant-id",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.TOOL_FILE,
        related_id="test-related-id",
        filename="image.png",
        extension=".png",
        mime_type="image/png",
        size=67,
        storage_key="test-storage-key",
        url="https://example.com/image.png",
    )

    # Each attribute round-trips unchanged.
    assert file.tenant_id == "test-tenant-id"
    assert file.type == FileType.IMAGE
    assert file.transfer_method == FileTransferMethod.TOOL_FILE
    assert file.related_id == "test-related-id"
    assert file.filename == "image.png"
    assert file.extension == ".png"
    assert file.mime_type == "image/png"
    assert file.size == 67
|
||||
|
||||
|
||||
def test_file_model_validate_with_legacy_fields():
    """`File.model_validate` tolerates legacy compatibility fields in input data."""
    data = {
        "id": "test-file",
        "tenant_id": "test-tenant-id",
        "type": "image",
        "transfer_method": "tool_file",
        "related_id": "test-related-id",
        "filename": "image.png",
        "extension": ".png",
        "mime_type": "image/png",
        "size": 67,
        "storage_key": "test-storage-key",
        "url": "https://example.com/image.png",
        # Extra legacy fields
        "tool_file_id": "tool-file-123",
        "upload_file_id": "upload-file-456",
        "datasource_file_id": "datasource-file-789",
    }

    # Validation must succeed despite the unrecognized legacy keys...
    file = File.model_validate(data)

    # ...and those keys must not leak through as model attributes.
    for legacy_field in ("tool_file_id", "upload_file_id", "datasource_file_id"):
        assert not hasattr(file, legacy_field)
|
||||
0
dify/api/tests/unit_tests/core/helper/__init__.py
Normal file
0
dify/api/tests/unit_tests/core/helper/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
|
||||
|
||||
|
||||
def test_get_runner_script():
    """The assembled runner script must begin with the default code verbatim."""
    default_code = JavascriptCodeProvider.get_default_code()
    script = NodeJsTemplateTransformer.assemble_runner_script(
        default_code, {"arg1": "hello, ", "arg2": "world!"}
    )

    # The transformer only appends harness code after the user code.
    expected_prefix = default_code.splitlines()
    assert script.splitlines()[: len(expected_prefix)] == expected_prefix
|
||||
@@ -0,0 +1,12 @@
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
|
||||
|
||||
def test_get_runner_script():
    """The assembled runner script must begin with the default code verbatim."""
    default_code = Python3CodeProvider.get_default_code()
    script = Python3TemplateTransformer.assemble_runner_script(
        default_code, {"arg1": "hello, ", "arg2": "world!"}
    )

    # The transformer only appends harness code after the user code.
    expected_prefix = default_code.splitlines()
    assert script.splitlines()[: len(expected_prefix)] == expected_prefix
|
||||
280
dify/api/tests/unit_tests/core/helper/test_encrypter.py
Normal file
280
dify/api/tests/unit_tests/core/helper/test_encrypter.py
Normal file
@@ -0,0 +1,280 @@
|
||||
import base64
|
||||
import binascii
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.encrypter import (
|
||||
batch_decrypt_token,
|
||||
decrypt_token,
|
||||
encrypt_token,
|
||||
get_decrypt_decoding,
|
||||
obfuscated_token,
|
||||
)
|
||||
from libs.rsa import PrivkeyNotFoundError
|
||||
|
||||
|
||||
class TestObfuscatedToken:
    @pytest.mark.parametrize(
        ("token", "expected"),
        [
            ("", ""),  # Empty token
            ("1234567", "*" * 20),  # Short token (<8 chars)
            ("12345678", "*" * 20),  # Boundary case (8 chars)
            ("123456789abcdef", "123456" + "*" * 12 + "ef"),  # Long token
            ("abc!@#$%^&*()def", "abc!@#" + "*" * 12 + "ef"),  # Special chars
        ],
    )
    def test_obfuscation_logic(self, token, expected):
        """Core obfuscation behavior across token lengths."""
        assert obfuscated_token(token) == expected

    def test_sensitive_data_protection(self):
        """The full plaintext must never survive obfuscation."""
        secret = "api_key_secret_12345"
        masked = obfuscated_token(secret)
        assert secret not in masked
        assert "*" * 12 in masked
|
||||
|
||||
|
||||
class TestEncryptToken:
    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_successful_encryption(self, mock_encrypt, mock_query):
        """Encryption returns the RSA ciphertext, base64-encoded."""
        tenant = MagicMock(encrypt_public_key="mock_public_key")
        mock_query.return_value.where.return_value.first.return_value = tenant
        mock_encrypt.return_value = b"encrypted_data"

        encrypted = encrypt_token("tenant-123", "test_token")

        assert encrypted == base64.b64encode(b"encrypted_data").decode()
        mock_encrypt.assert_called_with("test_token", "mock_public_key")

    @patch("models.engine.db.session.query")
    def test_tenant_not_found(self, mock_query):
        """A missing tenant raises ValueError naming the tenant id."""
        mock_query.return_value.where.return_value.first.return_value = None

        with pytest.raises(ValueError) as exc_info:
            encrypt_token("invalid-tenant", "test_token")

        assert "Tenant with id invalid-tenant not found" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestDecryptToken:
    @patch("libs.rsa.decrypt")
    def test_successful_decryption(self, mock_decrypt):
        """decrypt_token base64-decodes its input before calling rsa.decrypt."""
        mock_decrypt.return_value = "decrypted_token"

        decrypted = decrypt_token("tenant-123", base64.b64encode(b"encrypted_data").decode())

        assert decrypted == "decrypted_token"
        mock_decrypt.assert_called_once_with(b"encrypted_data", "tenant-123")

    def test_invalid_base64(self):
        """Malformed base64 input surfaces as binascii.Error."""
        with pytest.raises(binascii.Error):
            decrypt_token("tenant-123", "invalid_base64!!!")
|
||||
|
||||
|
||||
class TestBatchDecryptToken:
    @patch("libs.rsa.get_decrypt_decoding")
    @patch("libs.rsa.decrypt_token_with_decoding")
    def test_batch_decryption(self, mock_decrypt_with_decoding, mock_get_decoding):
        """Batch decryption decrypts every token while loading the key once."""
        mock_get_decoding.return_value = (MagicMock(), MagicMock())
        mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3"]

        encoded = [base64.b64encode(f"encrypted{i}".encode()).decode() for i in range(1, 4)]
        decrypted = batch_decrypt_token("tenant-123", encoded)

        assert decrypted == ["token1", "token2", "token3"]
        # Key material must be loaded exactly once for the whole batch.
        mock_get_decoding.assert_called_once_with("tenant-123")
|
||||
|
||||
|
||||
class TestGetDecryptDecoding:
    @patch("extensions.ext_redis.redis_client.get")
    @patch("extensions.ext_storage.storage.load")
    def test_private_key_not_found(self, mock_storage_load, mock_redis_get):
        """A cold cache plus a missing key file raises PrivkeyNotFoundError."""
        mock_redis_get.return_value = None  # cache miss
        mock_storage_load.side_effect = FileNotFoundError()

        with pytest.raises(PrivkeyNotFoundError) as exc_info:
            get_decrypt_decoding("tenant-123")

        assert "Private key not found, tenant_id: tenant-123" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEncryptDecryptIntegration:
    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    @patch("libs.rsa.decrypt")
    def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
        """A token survives an encrypt -> decrypt round trip."""
        # Mocked tenant lookup and RSA primitives.
        tenant = MagicMock(encrypt_public_key="mock_public_key")
        mock_query.return_value.where.return_value.first.return_value = tenant

        original_token = "test_token_123"
        mock_encrypt.return_value = b"encrypted_data"
        mock_decrypt.return_value = original_token

        encrypted = encrypt_token("tenant-123", original_token)
        round_tripped = decrypt_token("tenant-123", encrypted)

        assert round_tripped == original_token
|
||||
|
||||
|
||||
class TestSecurity:
    """Critical security tests for encryption system"""

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
        """A token encrypted for one tenant must not decrypt for another."""
        tenant = MagicMock(encrypt_public_key="tenant1_public_key")
        mock_query.return_value.where.return_value.first.return_value = tenant
        mock_encrypt.return_value = b"encrypted_for_tenant1"

        ciphertext = encrypt_token("tenant-123", "sensitive_data")

        # Decrypting with a different tenant's key must fail.
        with patch("libs.rsa.decrypt") as mock_decrypt:
            mock_decrypt.side_effect = Exception("Invalid tenant key")

            with pytest.raises(Exception, match="Invalid tenant key"):
                decrypt_token("different-tenant", ciphertext)

    @patch("libs.rsa.decrypt")
    def test_tampered_ciphertext_rejection(self, mock_decrypt):
        """Modified ciphertext must be rejected by decryption."""
        valid_encrypted = base64.b64encode(b"valid_data").decode()

        # Flip the first byte of the decoded payload, then re-encode.
        corrupted = bytearray(base64.b64decode(valid_encrypted))
        corrupted[0] ^= 0xFF
        tampered = base64.b64encode(bytes(corrupted)).decode()

        mock_decrypt.side_effect = Exception("Decryption error")

        with pytest.raises(Exception, match="Decryption error"):
            decrypt_token("tenant-123", tampered)

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_encryption_randomness(self, mock_encrypt, mock_query):
        """Equal plaintexts should yield distinct ciphertexts (mocked here)."""
        mock_query.return_value.where.return_value.first.return_value = MagicMock(encrypt_public_key="key")
        mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]

        ciphertexts = [encrypt_token("tenant-123", "token") for _ in range(3)]

        # No two ciphertexts may repeat.
        assert len(set(ciphertexts)) == 3
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Additional security-focused edge case tests"""

    def test_should_handle_empty_string_in_obfuscation(self):
        """An empty token obfuscates to an empty string."""
        assert obfuscated_token("") == ""

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
        """Encrypting the empty string still produces base64 ciphertext."""
        tenant = MagicMock(encrypt_public_key="mock_public_key")
        mock_query.return_value.where.return_value.first.return_value = tenant
        mock_encrypt.return_value = b"encrypted_empty"

        encrypted = encrypt_token("tenant-123", "")

        assert encrypted == base64.b64encode(b"encrypted_empty").decode()
        mock_encrypt.assert_called_with("", "mock_public_key")

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
        """Tokens with null bytes, unicode, and control chars pass through intact."""
        tenant = MagicMock(encrypt_public_key="mock_public_key")
        mock_query.return_value.where.return_value.first.return_value = tenant
        mock_encrypt.return_value = b"encrypted_special"

        special_tokens = [
            "token\x00with\x00null",  # Null bytes
            "token_with_emoji_😀🎉",  # Unicode emoji
            "token\nwith\nnewlines",  # Newlines
            "token\twith\ttabs",  # Tabs
            "token_with_中文字符",  # Chinese characters
        ]

        for token in special_tokens:
            encrypted = encrypt_token("tenant-123", token)
            assert encrypted == base64.b64encode(b"encrypted_special").decode()
            # assert_called_with inspects the most recent call, i.e. this token.
            mock_encrypt.assert_called_with(token, "mock_public_key")

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
        """Oversized plaintext propagates the RSA 'message too long' error."""
        tenant = MagicMock(encrypt_public_key="mock_public_key")
        mock_query.return_value.where.return_value.first.return_value = tenant

        # RSA 2048-bit can only encrypt ~245 bytes; the exact limit
        # depends on the padding scheme.
        mock_encrypt.side_effect = ValueError("Message too long for RSA key size")

        long_token = "x" * 300

        with pytest.raises(ValueError, match="Message too long for RSA key size"):
            encrypt_token("tenant-123", long_token)

    @patch("libs.rsa.get_decrypt_decoding")
    @patch("libs.rsa.decrypt_token_with_decoding")
    def test_batch_decrypt_loads_key_only_once(self, mock_decrypt_with_decoding, mock_get_decoding):
        """Batch decryption loads key material exactly once, however many tokens."""
        mock_get_decoding.return_value = (MagicMock(), MagicMock())
        mock_decrypt_with_decoding.side_effect = [f"token{i}" for i in range(1, 6)]

        encoded = [base64.b64encode(f"encrypted{i}".encode()).decode() for i in range(5)]
        decrypted = batch_decrypt_token("tenant-123", encoded)

        assert decrypted == [f"token{i}" for i in range(1, 6)]
        mock_get_decoding.assert_called_once_with("tenant-123")
        assert mock_decrypt_with_decoding.call_count == 5
|
||||
52
dify/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
Normal file
52
dify/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import secrets
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
||||
|
||||
|
||||
@patch("httpx.Client.request")
|
||||
def test_successful_request(mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
response = make_request("GET", "http://example.com")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_exceed_max_retries(mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
|
||||
side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES
|
||||
mock_request.side_effect = side_effects
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
|
||||
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
|
||||
|
||||
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_logic_success(mock_request):
|
||||
side_effects = []
|
||||
|
||||
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
|
||||
status_code = secrets.choice(STATUS_FORCELIST)
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
side_effects.append(mock_response)
|
||||
|
||||
mock_response_200 = MagicMock()
|
||||
mock_response_200.status_code = 200
|
||||
side_effects.append(mock_response_200)
|
||||
|
||||
mock_request.side_effect = side_effects
|
||||
|
||||
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
|
||||
assert mock_request.call_args_list[0][1].get("method") == "GET"
|
||||
@@ -0,0 +1,86 @@
|
||||
import pytest
|
||||
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args, get_external_trace_id, is_valid_trace_id
|
||||
|
||||
|
||||
class DummyRequest:
    """Minimal stand-in for a Flask-style request.

    Exposes only the attributes trace-id extraction reads:
    ``headers``, ``args``, ``json`` and ``is_json``.
    """

    def __init__(self, headers=None, args=None, json=None, is_json=False):
        # Falsy headers/args collapse to fresh empty dicts (same as `x or {}`).
        self.headers = {} if not headers else headers
        self.args = {} if not args else args
        self.json = json
        self.is_json = is_json
|
||||
|
||||
|
||||
class TestTraceIdHelper:
    """Unit tests for core.helper.trace_id_helper."""

    @pytest.mark.parametrize(
        ("trace_id", "expected"),
        [
            ("abc123", True),  # plain alphanumeric
            ("A-B_C-123", True),  # hyphens and underscores allowed
            ("a" * 128, True),  # exactly at the length ceiling
            ("", False),  # empty string rejected
            ("a" * 129, False),  # one past the ceiling
            ("abc!@#", False),  # punctuation rejected
            ("空格", False),  # non-ASCII rejected
            ("with space", False),  # whitespace rejected
        ],
    )
    def test_is_valid_trace_id(self, trace_id, expected):
        """Validation accepts 1-128 chars of [A-Za-z0-9_-] and nothing else."""
        assert is_valid_trace_id(trace_id) is expected

    def test_get_external_trace_id_from_header(self):
        """The X-Trace-Id header alone is enough."""
        request = DummyRequest(headers={"X-Trace-Id": "abc123"})
        assert get_external_trace_id(request) == "abc123"

    def test_get_external_trace_id_from_args(self):
        """Falls back to the trace_id query argument when the header is missing."""
        request = DummyRequest(args={"trace_id": "abc123"})
        assert get_external_trace_id(request) == "abc123"

    def test_get_external_trace_id_from_json(self):
        """Falls back to the JSON body when header and args are both absent."""
        request = DummyRequest(is_json=True, json={"trace_id": "abc123"})
        assert get_external_trace_id(request) == "abc123"

    def test_get_external_trace_id_priority(self):
        """Sources are consulted in order: header, then args, then JSON body."""
        all_three = DummyRequest(
            headers={"X-Trace-Id": "header_id"},
            args={"trace_id": "args_id"},
            is_json=True,
            json={"trace_id": "json_id"},
        )
        assert get_external_trace_id(all_three) == "header_id"

        args_and_json = DummyRequest(args={"trace_id": "args_id"}, is_json=True, json={"trace_id": "json_id"})
        assert get_external_trace_id(args_and_json) == "args_id"

        json_only = DummyRequest(is_json=True, json={"trace_id": "json_id"})
        assert get_external_trace_id(json_only) == "json_id"

    @pytest.mark.parametrize(
        "req",
        [
            DummyRequest(headers={"X-Trace-Id": "!!!"}),  # invalid in header
            DummyRequest(args={"trace_id": "!!!"}),  # invalid in args
            DummyRequest(is_json=True, json={"trace_id": "!!!"}),  # invalid in body
            DummyRequest(),  # absent everywhere
        ],
    )
    def test_get_external_trace_id_invalid(self, req):
        """Invalid or absent trace ids resolve to None from every source."""
        assert get_external_trace_id(req) is None

    @pytest.mark.parametrize(
        ("args", "expected"),
        [
            ({"external_trace_id": "abc123"}, {"external_trace_id": "abc123"}),  # passed through
            ({"other": "value"}, {}),  # unrelated keys dropped
            ({}, {}),  # empty in, empty out
        ],
    )
    def test_extract_external_trace_id_from_args(self, args, expected):
        """Only the external_trace_id key survives extraction."""
        assert extract_external_trace_id_from_args(args) == expected
|
||||
0
dify/api/tests/unit_tests/core/mcp/__init__.py
Normal file
0
dify/api/tests/unit_tests/core/mcp/__init__.py
Normal file
0
dify/api/tests/unit_tests/core/mcp/auth/__init__.py
Normal file
0
dify/api/tests/unit_tests/core/mcp/auth/__init__.py
Normal file
766
dify/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py
Normal file
766
dify/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py
Normal file
@@ -0,0 +1,766 @@
|
||||
"""Unit tests for MCP OAuth authentication flow."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.auth.auth_flow import (
|
||||
OAUTH_STATE_EXPIRY_SECONDS,
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX,
|
||||
OAuthCallbackState,
|
||||
_create_secure_redis_state,
|
||||
_retrieve_redis_state,
|
||||
auth,
|
||||
check_support_resource_discovery,
|
||||
discover_oauth_metadata,
|
||||
exchange_authorization,
|
||||
generate_pkce_challenge,
|
||||
handle_callback,
|
||||
refresh_authorization,
|
||||
register_client,
|
||||
start_authorization,
|
||||
)
|
||||
from core.mcp.entities import AuthActionType, AuthResult
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
|
||||
|
||||
class TestPKCEGeneration:
    """Tests for the PKCE verifier/challenge generator."""

    def test_generate_pkce_challenge(self):
        """Both values are unpadded URL-safe base64 of adequate length."""
        code_verifier, code_challenge = generate_pkce_challenge()

        # URL-safe base64 without padding never contains '=', '+' or '/'.
        for forbidden in ("=", "+", "/"):
            assert forbidden not in code_verifier
            assert forbidden not in code_challenge

        # Rough length sanity: verifier ~54 chars, challenge ~43 chars.
        assert len(code_verifier) > 40
        assert len(code_challenge) > 40

    def test_generate_pkce_challenge_uniqueness(self):
        """Ten consecutive generations never repeat a (verifier, challenge) pair."""
        pairs = {generate_pkce_challenge() for _ in range(10)}
        assert len(pairs) == 10
|
||||
|
||||
|
||||
class TestRedisStateManagement:
    """Tests for the Redis-backed OAuth callback-state helpers."""

    @staticmethod
    def _sample_state():
        # One canonical callback-state payload shared by the tests below.
        return OAuthCallbackState(
            provider_id="test-provider",
            tenant_id="test-tenant",
            server_url="https://example.com",
            metadata=None,
            client_information=OAuthClientInformation(client_id="test-client"),
            code_verifier="test-verifier",
            redirect_uri="https://redirect.example.com",
        )

    @patch("core.mcp.auth.auth_flow.redis_client")
    def test_create_secure_redis_state(self, mock_redis):
        """Storing state returns a random key and writes JSON to Redis with a TTL."""
        state_data = self._sample_state()

        state_key = _create_secure_redis_state(state_data)

        # The key itself should be a non-trivial random token.
        assert len(state_key) > 20

        # Exactly one setex(<prefixed key>, <expiry>, <json payload>) call.
        mock_redis.setex.assert_called_once()
        call_args = mock_redis.setex.call_args
        assert call_args[0][0].startswith(OAUTH_STATE_REDIS_KEY_PREFIX)
        assert call_args[0][1] == OAUTH_STATE_EXPIRY_SECONDS
        assert state_data.model_dump_json() in call_args[0][2]

    @patch("core.mcp.auth.auth_flow.redis_client")
    def test_retrieve_redis_state_success(self, mock_redis):
        """Retrieval parses the stored JSON and deletes the key (one-shot state)."""
        state_data = self._sample_state()
        mock_redis.get.return_value = state_data.model_dump_json()

        result = _retrieve_redis_state("test-state-key")

        assert result.provider_id == "test-provider"
        assert result.tenant_id == "test-tenant"
        assert result.server_url == "https://example.com"

        # Lookup and deletion both use the prefixed key.
        mock_redis.get.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")
        mock_redis.delete.assert_called_once_with(f"{OAUTH_STATE_REDIS_KEY_PREFIX}test-state-key")

    @patch("core.mcp.auth.auth_flow.redis_client")
    def test_retrieve_redis_state_not_found(self, mock_redis):
        """A missing key raises a ValueError describing the expired state."""
        mock_redis.get.return_value = None

        with pytest.raises(ValueError) as exc_info:
            _retrieve_redis_state("nonexistent-key")

        assert "State parameter has expired or does not exist" in str(exc_info.value)

    @patch("core.mcp.auth.auth_flow.redis_client")
    def test_retrieve_redis_state_invalid_json(self, mock_redis):
        """Malformed JSON raises, and the poisoned key is still removed."""
        mock_redis.get.return_value = '{"invalid": json}'

        with pytest.raises(ValueError) as exc_info:
            _retrieve_redis_state("test-key")

        assert "Invalid state parameter" in str(exc_info.value)
        # State should still be deleted even though parsing failed.
        mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
class TestOAuthDiscovery:
    """Test OAuth discovery functions."""

    @patch("core.helper.ssrf_proxy.get")
    def test_check_support_resource_discovery_success(self, mock_get):
        """Test successful resource discovery check."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
        mock_get.return_value = mock_response

        supported, auth_url = check_support_resource_discovery("https://api.example.com/endpoint")

        assert supported is True
        assert auth_url == "https://auth.example.com"
        # The probe targets the well-known path at the host root (the original
        # path component is replaced) and advertises the MCP protocol version.
        mock_get.assert_called_once_with(
            "https://api.example.com/.well-known/oauth-protected-resource",
            headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
        )

    @patch("core.helper.ssrf_proxy.get")
    def test_check_support_resource_discovery_not_supported(self, mock_get):
        """Test resource discovery not supported."""
        mock_response = Mock()
        mock_response.status_code = 404
        mock_get.return_value = mock_response

        supported, auth_url = check_support_resource_discovery("https://api.example.com")

        # A 404 on the well-known endpoint means no discovery support.
        assert supported is False
        assert auth_url == ""

    @patch("core.helper.ssrf_proxy.get")
    def test_check_support_resource_discovery_with_query_fragment(self, mock_get):
        """Test resource discovery with query and fragment."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"authorization_server_url": ["https://auth.example.com"]}
        mock_get.return_value = mock_response

        supported, auth_url = check_support_resource_discovery("https://api.example.com/path?query=1#fragment")

        assert supported is True
        assert auth_url == "https://auth.example.com"
        # Query string and fragment from the server URL are preserved on the probe.
        mock_get.assert_called_once_with(
            "https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
            headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
        )

    def test_discover_oauth_metadata_with_resource_discovery(self):
        """Test OAuth metadata discovery with resource discovery support."""
        with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
            with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
                # Mock protected resource metadata with auth server URL
                mock_prm.return_value = ProtectedResourceMetadata(
                    resource="https://api.example.com",
                    authorization_servers=["https://auth.example.com"],
                )

                # Mock OAuth authorization server metadata
                mock_asm.return_value = OAuthMetadata(
                    authorization_endpoint="https://auth.example.com/authorize",
                    token_endpoint="https://auth.example.com/token",
                    response_types_supported=["code"],
                )

                # discover_oauth_metadata returns (oauth metadata,
                # protected-resource metadata, scope).
                oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")

                # Endpoints come from the authorization server announced by the
                # protected-resource metadata, not from the API host itself.
                assert oauth_metadata is not None
                assert oauth_metadata.authorization_endpoint == "https://auth.example.com/authorize"
                assert oauth_metadata.token_endpoint == "https://auth.example.com/token"
                assert prm is not None
                assert prm.authorization_servers == ["https://auth.example.com"]

                # Verify the discovery functions were called
                mock_prm.assert_called_once()
                mock_asm.assert_called_once()

    def test_discover_oauth_metadata_without_resource_discovery(self):
        """Test OAuth metadata discovery without resource discovery."""
        with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
            with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
                # Mock no protected resource metadata
                mock_prm.return_value = None

                # Mock OAuth authorization server metadata
                mock_asm.return_value = OAuthMetadata(
                    authorization_endpoint="https://api.example.com/oauth/authorize",
                    token_endpoint="https://api.example.com/oauth/token",
                    response_types_supported=["code"],
                )

                oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")

                # Without resource discovery the endpoints fall back to the API host.
                assert oauth_metadata is not None
                assert oauth_metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
                assert prm is None

                # Verify the discovery functions were called
                mock_prm.assert_called_once()
                mock_asm.assert_called_once()

    @patch("core.helper.ssrf_proxy.get")
    def test_discover_oauth_metadata_not_found(self, mock_get):
        """Test OAuth metadata discovery when not found."""
        with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
            mock_check.return_value = (False, "")

            mock_response = Mock()
            mock_response.status_code = 404
            mock_get.return_value = mock_response

            oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")

            # No discovery support and a 404 metadata endpoint yields no metadata.
            assert oauth_metadata is None
|
||||
|
||||
|
||||
class TestAuthorizationFlow:
    """Test authorization flow functions."""

    @patch("core.mcp.auth.auth_flow._create_secure_redis_state")
    def test_start_authorization_with_metadata(self, mock_create_state):
        """Test starting authorization with metadata."""
        mock_create_state.return_value = "secure-state-key"

        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["code"],
            code_challenge_methods_supported=["S256"],
        )
        client_info = OAuthClientInformation(client_id="test-client-id")

        auth_url, code_verifier = start_authorization(
            "https://api.example.com",
            metadata,
            client_info,
            "https://redirect.example.com",
            "provider-id",
            "tenant-id",
        )

        # Verify URL format
        assert auth_url.startswith("https://auth.example.com/authorize?")
        assert "response_type=code" in auth_url
        assert "client_id=test-client-id" in auth_url
        assert "code_challenge=" in auth_url
        assert "code_challenge_method=S256" in auth_url
        assert "redirect_uri=https%3A%2F%2Fredirect.example.com" in auth_url
        assert "state=secure-state-key" in auth_url

        # Verify code verifier
        assert len(code_verifier) > 40

        # Verify state was stored: the PKCE verifier paired with the challenge
        # in the URL must be captured in the Redis-bound state payload.
        mock_create_state.assert_called_once()
        state_data = mock_create_state.call_args[0][0]
        assert state_data.provider_id == "provider-id"
        assert state_data.tenant_id == "tenant-id"
        assert state_data.code_verifier == code_verifier

    def test_start_authorization_without_metadata(self):
        """Test starting authorization without metadata."""
        with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create_state:
            mock_create_state.return_value = "secure-state-key"

            client_info = OAuthClientInformation(client_id="test-client-id")

            auth_url, code_verifier = start_authorization(
                "https://api.example.com",
                None,
                client_info,
                "https://redirect.example.com",
                "provider-id",
                "tenant-id",
            )

            # Should use default authorization endpoint (/authorize on the server URL)
            assert auth_url.startswith("https://api.example.com/authorize?")

    def test_start_authorization_invalid_metadata(self):
        """Test starting authorization with invalid metadata."""
        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["token"],  # No "code" support
            code_challenge_methods_supported=["plain"],  # No "S256" support
        )
        client_info = OAuthClientInformation(client_id="test-client-id")

        # A server that cannot do the code flow must be rejected up front.
        with pytest.raises(ValueError) as exc_info:
            start_authorization(
                "https://api.example.com",
                metadata,
                client_info,
                "https://redirect.example.com",
                "provider-id",
                "tenant-id",
            )

        assert "does not support response type code" in str(exc_info.value)

    @patch("core.helper.ssrf_proxy.post")
    def test_exchange_authorization_success(self, mock_post):
        """Test successful authorization code exchange."""
        mock_response = Mock()
        mock_response.is_success = True
        mock_response.headers = {"content-type": "application/json"}
        mock_response.json.return_value = {
            "access_token": "new-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token",
        }
        mock_post.return_value = mock_response

        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["code"],
            grant_types_supported=["authorization_code"],
        )
        client_info = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")

        tokens = exchange_authorization(
            "https://api.example.com",
            metadata,
            client_info,
            "auth-code-123",
            "code-verifier-xyz",
            "https://redirect.example.com",
        )

        assert tokens.access_token == "new-access-token"
        assert tokens.token_type == "Bearer"
        assert tokens.expires_in == 3600
        assert tokens.refresh_token == "new-refresh-token"

        # Verify request: the form body must carry the code, PKCE verifier,
        # redirect URI, and the client credentials.
        mock_post.assert_called_once_with(
            "https://auth.example.com/token",
            data={
                "grant_type": "authorization_code",
                "client_id": "test-client-id",
                "client_secret": "test-secret",
                "code": "auth-code-123",
                "code_verifier": "code-verifier-xyz",
                "redirect_uri": "https://redirect.example.com",
            },
        )

    @patch("core.helper.ssrf_proxy.post")
    def test_exchange_authorization_failure(self, mock_post):
        """Test failed authorization code exchange."""
        mock_response = Mock()
        mock_response.is_success = False
        mock_response.status_code = 400
        mock_post.return_value = mock_response

        client_info = OAuthClientInformation(client_id="test-client-id")

        with pytest.raises(ValueError) as exc_info:
            exchange_authorization(
                "https://api.example.com",
                None,
                client_info,
                "invalid-code",
                "code-verifier",
                "https://redirect.example.com",
            )

        assert "Token exchange failed: HTTP 400" in str(exc_info.value)

    @patch("core.helper.ssrf_proxy.post")
    def test_refresh_authorization_success(self, mock_post):
        """Test successful token refresh."""
        mock_response = Mock()
        mock_response.is_success = True
        mock_response.headers = {"content-type": "application/json"}
        mock_response.json.return_value = {
            "access_token": "refreshed-access-token",
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token",
        }
        mock_post.return_value = mock_response

        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            response_types_supported=["code"],
            grant_types_supported=["refresh_token"],
        )
        client_info = OAuthClientInformation(client_id="test-client-id")

        tokens = refresh_authorization("https://api.example.com", metadata, client_info, "old-refresh-token")

        assert tokens.access_token == "refreshed-access-token"
        assert tokens.refresh_token == "new-refresh-token"

        # Verify request: refresh uses the refresh_token grant with no secret
        # (this client_info has no client_secret).
        mock_post.assert_called_once_with(
            "https://auth.example.com/token",
            data={
                "grant_type": "refresh_token",
                "client_id": "test-client-id",
                "refresh_token": "old-refresh-token",
            },
        )

    @patch("core.helper.ssrf_proxy.post")
    def test_register_client_success(self, mock_post):
        """Test successful client registration."""
        mock_response = Mock()
        mock_response.is_success = True
        mock_response.json.return_value = {
            "client_id": "new-client-id",
            "client_secret": "new-client-secret",
            "client_name": "Dify",
            "redirect_uris": ["https://redirect.example.com"],
        }
        mock_post.return_value = mock_response

        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            registration_endpoint="https://auth.example.com/register",
            response_types_supported=["code"],
        )
        client_metadata = OAuthClientMetadata(
            client_name="Dify",
            redirect_uris=["https://redirect.example.com"],
            grant_types=["authorization_code"],
            response_types=["code"],
        )

        client_info = register_client("https://api.example.com", metadata, client_metadata)

        assert isinstance(client_info, OAuthClientInformationFull)
        assert client_info.client_id == "new-client-id"
        assert client_info.client_secret == "new-client-secret"

        # Verify request: registration posts the full client metadata as JSON.
        mock_post.assert_called_once_with(
            "https://auth.example.com/register",
            json=client_metadata.model_dump(),
            headers={"Content-Type": "application/json"},
        )

    def test_register_client_no_endpoint(self):
        """Test client registration when no endpoint available."""
        metadata = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            registration_endpoint=None,
            response_types_supported=["code"],
        )
        client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://redirect.example.com"])

        with pytest.raises(ValueError) as exc_info:
            register_client("https://api.example.com", metadata, client_metadata)

        assert "does not support dynamic client registration" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestCallbackHandling:
    """Test OAuth callback handling."""

    @patch("core.mcp.auth.auth_flow._retrieve_redis_state")
    @patch("core.mcp.auth.auth_flow.exchange_authorization")
    def test_handle_callback_success(self, mock_exchange, mock_retrieve_state):
        """handle_callback resolves the state key, exchanges the code, and returns both.

        handle_callback no longer saves tokens directly; the caller (e.g. the
        controller) is responsible for persisting them via execute_auth_actions,
        so this test asserts only on the returned values and the collaborator calls.
        (The previously-present, unused ``mock_service = Mock()`` has been removed.)
        """
        # State previously stashed in Redis under "state-key".
        state_data = OAuthCallbackState(
            provider_id="test-provider",
            tenant_id="test-tenant",
            server_url="https://api.example.com",
            metadata=None,
            client_information=OAuthClientInformation(client_id="test-client"),
            code_verifier="test-verifier",
            redirect_uri="https://redirect.example.com",
        )
        mock_retrieve_state.return_value = state_data

        # Token-endpoint response for the exchanged authorization code.
        tokens = OAuthTokens(
            access_token="new-token",
            token_type="Bearer",
            expires_in=3600,
        )
        mock_exchange.return_value = tokens

        state_result, tokens_result = handle_callback("state-key", "auth-code")

        assert state_result == state_data
        assert tokens_result == tokens

        # The state key is consumed exactly once, and the exchange receives the
        # verifier and redirect URI captured when authorization started.
        mock_retrieve_state.assert_called_once_with("state-key")
        mock_exchange.assert_called_once_with(
            "https://api.example.com",
            None,
            state_data.client_information,
            "auth-code",
            "test-verifier",
            "https://redirect.example.com",
        )
|
||||
|
||||
|
||||
class TestAuthOrchestration:
|
||||
"""Test the main auth orchestration function."""
|
||||
|
||||
    @pytest.fixture
    def mock_provider(self):
        """Create a mock provider entity."""
        provider = Mock(spec=MCPProviderEntity)
        provider.id = "provider-id"
        provider.tenant_id = "tenant-id"
        provider.decrypt_server_url.return_value = "https://api.example.com"
        provider.client_metadata = OAuthClientMetadata(
            client_name="Dify",
            redirect_uris=["https://redirect.example.com"],
        )
        provider.redirect_url = "https://redirect.example.com"
        # No stored client registration or tokens: the provider starts "fresh",
        # so auth() must register a client and begin authorization from scratch.
        provider.retrieve_client_information.return_value = None
        provider.retrieve_tokens.return_value = None
        return provider
|
||||
|
||||
    @pytest.fixture
    def mock_service(self):
        """Create a mock MCP service."""
        # A bare Mock suffices — the tests below never assert on it; presumably
        # persistence now happens via execute_auth_actions (TODO confirm).
        return Mock()
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
@patch("core.mcp.auth.auth_flow.register_client")
|
||||
@patch("core.mcp.auth.auth_flow.start_authorization")
|
||||
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth flow for new client registration."""
|
||||
# Setup
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
mock_register.return_value = OAuthClientInformationFull(
|
||||
client_id="new-client-id",
|
||||
client_name="Dify",
|
||||
redirect_uris=["https://redirect.example.com"],
|
||||
)
|
||||
mock_start_auth.return_value = ("https://auth.example.com/authorize?...", "code-verifier")
|
||||
|
||||
result = auth(mock_provider)
|
||||
|
||||
# auth() now returns AuthResult
|
||||
assert isinstance(result, AuthResult)
|
||||
assert result.response == {"authorization_url": "https://auth.example.com/authorize?..."}
|
||||
|
||||
# Verify that the result contains the correct actions
|
||||
assert len(result.actions) == 2
|
||||
# Check for SAVE_CLIENT_INFO action
|
||||
client_info_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CLIENT_INFO)
|
||||
assert client_info_action.data == {"client_information": mock_register.return_value.model_dump()}
|
||||
assert client_info_action.provider_id == "provider-id"
|
||||
assert client_info_action.tenant_id == "tenant-id"
|
||||
|
||||
# Check for SAVE_CODE_VERIFIER action
|
||||
verifier_action = next(a for a in result.actions if a.action_type == AuthActionType.SAVE_CODE_VERIFIER)
|
||||
assert verifier_action.data == {"code_verifier": "code-verifier"}
|
||||
assert verifier_action.provider_id == "provider-id"
|
||||
assert verifier_action.tenant_id == "tenant-id"
|
||||
|
||||
# Verify calls
|
||||
mock_register.assert_called_once()
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
@patch("core.mcp.auth.auth_flow._retrieve_redis_state")
|
||||
@patch("core.mcp.auth.auth_flow.exchange_authorization")
|
||||
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth flow for exchanging authorization code."""
|
||||
# Setup metadata discovery
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
# Setup existing client
|
||||
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
||||
|
||||
# Setup state retrieval
|
||||
state_data = OAuthCallbackState(
|
||||
provider_id="provider-id",
|
||||
tenant_id="tenant-id",
|
||||
server_url="https://api.example.com",
|
||||
metadata=None,
|
||||
client_information=OAuthClientInformation(client_id="existing-client"),
|
||||
code_verifier="test-verifier",
|
||||
redirect_uri="https://redirect.example.com",
|
||||
)
|
||||
mock_retrieve_state.return_value = state_data
|
||||
|
||||
# Setup token exchange
|
||||
tokens = OAuthTokens(access_token="new-token", token_type="Bearer", expires_in=3600)
|
||||
mock_exchange.return_value = tokens
|
||||
|
||||
result = auth(mock_provider, authorization_code="auth-code", state_param="state-key")
|
||||
|
||||
# auth() now returns AuthResult, not a dict
|
||||
assert isinstance(result, AuthResult)
|
||||
assert result.response == {"result": "success"}
|
||||
|
||||
# Verify that the result contains the correct action
|
||||
assert len(result.actions) == 1
|
||||
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
|
||||
assert result.actions[0].data == tokens.model_dump()
|
||||
assert result.actions[0].provider_id == "provider-id"
|
||||
assert result.actions[0].tenant_id == "tenant-id"
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth flow fails when exchanging code without state."""
|
||||
# Setup metadata discovery
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
auth(mock_provider, authorization_code="auth-code")
|
||||
|
||||
assert "State parameter is required" in str(exc_info.value)
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.refresh_authorization")
|
||||
def test_auth_refresh_token(self, mock_refresh, mock_provider, mock_service):
|
||||
"""Test auth flow for refreshing tokens."""
|
||||
# Setup existing client and tokens
|
||||
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
||||
mock_provider.retrieve_tokens.return_value = OAuthTokens(
|
||||
access_token="old-token",
|
||||
token_type="Bearer",
|
||||
expires_in=0,
|
||||
refresh_token="refresh-token",
|
||||
)
|
||||
|
||||
# Setup refresh
|
||||
new_tokens = OAuthTokens(
|
||||
access_token="refreshed-token",
|
||||
token_type="Bearer",
|
||||
expires_in=3600,
|
||||
refresh_token="new-refresh-token",
|
||||
)
|
||||
mock_refresh.return_value = new_tokens
|
||||
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
result = auth(mock_provider)
|
||||
|
||||
# auth() now returns AuthResult
|
||||
assert isinstance(result, AuthResult)
|
||||
assert result.response == {"result": "success"}
|
||||
|
||||
# Verify that the result contains the correct action
|
||||
assert len(result.actions) == 1
|
||||
assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS
|
||||
assert result.actions[0].data == new_tokens.model_dump()
|
||||
assert result.actions[0].provider_id == "provider-id"
|
||||
assert result.actions[0].tenant_id == "tenant-id"
|
||||
|
||||
# Verify refresh was called
|
||||
mock_refresh.assert_called_once()
|
||||
|
||||
@patch("core.mcp.auth.auth_flow.discover_oauth_metadata")
|
||||
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth fails when no client info exists but code is provided."""
|
||||
# Setup metadata discovery
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
mock_provider.retrieve_client_information.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
auth(mock_provider, authorization_code="auth-code")
|
||||
|
||||
assert "Existing OAuth client information is required" in str(exc_info.value)
|
||||
468
dify/api/tests/unit_tests/core/mcp/client/test_session.py
Normal file
468
dify/api/tests/unit_tests/core/mcp/client/test_session.py
Normal file
@@ -0,0 +1,468 @@
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.entities import RequestContext
|
||||
from core.mcp.session.base_session import RequestResponder
|
||||
from core.mcp.session.client_session import DEFAULT_CLIENT_INFO, ClientSession
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
Implementation,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ServerCapabilities,
|
||||
ServerResult,
|
||||
SessionMessage,
|
||||
)
|
||||
|
||||
|
||||
def test_client_session_initialize():
    """Happy-path initialize handshake against an in-process mock server.

    The mock server runs in a daemon thread and talks to the session over two
    plain queues standing in for the async transport streams.
    """
    # Create synchronous queues to replace async streams
    client_to_server: queue.Queue[SessionMessage] = queue.Queue()
    server_to_client: queue.Queue[SessionMessage] = queue.Queue()

    initialized_notification = None

    def mock_server():
        nonlocal initialized_notification

        # Receive initialization request
        session_message = client_to_server.get(timeout=5.0)
        jsonrpc_request = session_message.message
        assert isinstance(jsonrpc_request.root, JSONRPCRequest)
        # Round-trip through a dict to re-validate as a typed ClientRequest.
        request = ClientRequest.model_validate(
            jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(request.root, InitializeRequest)

        # Create response
        result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(
                    logging=None,
                    resources=None,
                    tools=None,
                    experimental=None,
                    prompts=None,
                ),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
                instructions="The server instructions.",
            )
        )

        # Send response, echoing the request id per JSON-RPC.
        server_to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=jsonrpc_request.root.id,
                        result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )

        # Receive initialized notification
        session_notification = client_to_server.get(timeout=5.0)
        jsonrpc_notification = session_notification.message
        assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
        initialized_notification = ClientNotification.model_validate(
            jsonrpc_notification.root.model_dump(by_alias=True, mode="json", exclude_none=True)
        )

    # Create message handler that surfaces transport errors in the test thread
    def message_handler(
        message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
    ):
        if isinstance(message, Exception):
            raise message

    # Start mock server thread
    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    # Create and use client session
    with ClientSession(
        server_to_client,
        client_to_server,
        message_handler=message_handler,
    ) as session:
        result = session.initialize()

    # Wait for server thread to complete
    server_thread.join(timeout=10.0)

    # Assert results
    assert isinstance(result, InitializeResult)
    assert result.protocolVersion == LATEST_PROTOCOL_VERSION
    assert isinstance(result.capabilities, ServerCapabilities)
    assert result.serverInfo == Implementation(name="mock-server", version="0.1.0")
    assert result.instructions == "The server instructions."

    # Check that client sent initialized notification
    assert initialized_notification
    assert isinstance(initialized_notification.root, InitializedNotification)
|
||||
|
||||
def test_client_session_custom_client_info():
    """A caller-supplied Implementation must be sent as clientInfo."""
    # Create synchronous queues to replace async streams
    client_to_server: queue.Queue[SessionMessage] = queue.Queue()
    server_to_client: queue.Queue[SessionMessage] = queue.Queue()

    custom_client_info = Implementation(name="test-client", version="1.2.3")
    received_client_info = None

    def mock_server():
        nonlocal received_client_info

        session_message = client_to_server.get(timeout=5.0)
        jsonrpc_request = session_message.message
        assert isinstance(jsonrpc_request.root, JSONRPCRequest)
        request = ClientRequest.model_validate(
            jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(request.root, InitializeRequest)
        # Capture what the client claimed about itself.
        received_client_info = request.root.params.clientInfo

        result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )

        # Reply to the initialize request, echoing its id.
        server_to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=jsonrpc_request.root.id,
                        result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Receive initialized notification (drained so the handshake completes)
        client_to_server.get(timeout=5.0)

    # Start mock server thread
    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        server_to_client,
        client_to_server,
        client_info=custom_client_info,
    ) as session:
        session.initialize()

    # Wait for server thread to complete
    server_thread.join(timeout=10.0)

    # Assert that custom client info was sent
    assert received_client_info == custom_client_info
|
||||
|
||||
def test_client_session_default_client_info():
    """When no client_info is given, DEFAULT_CLIENT_INFO must be sent."""
    # Create synchronous queues to replace async streams
    client_to_server: queue.Queue[SessionMessage] = queue.Queue()
    server_to_client: queue.Queue[SessionMessage] = queue.Queue()

    received_client_info = None

    def mock_server():
        nonlocal received_client_info

        session_message = client_to_server.get(timeout=5.0)
        jsonrpc_request = session_message.message
        assert isinstance(jsonrpc_request.root, JSONRPCRequest)
        request = ClientRequest.model_validate(
            jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(request.root, InitializeRequest)
        # Capture the clientInfo the session advertised.
        received_client_info = request.root.params.clientInfo

        result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )

        # Reply to the initialize request, echoing its id.
        server_to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=jsonrpc_request.root.id,
                        result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Receive initialized notification (drained so the handshake completes)
        client_to_server.get(timeout=5.0)

    # Start mock server thread
    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        server_to_client,
        client_to_server,
    ) as session:
        session.initialize()

    # Wait for server thread to complete
    server_thread.join(timeout=10.0)

    # Assert that default client info was used
    assert received_client_info == DEFAULT_CLIENT_INFO
|
||||
|
||||
def test_client_session_version_negotiation_success():
    """Initialization succeeds when the server offers a supported protocol version."""
    # Create synchronous queues to replace async streams
    client_to_server: queue.Queue[SessionMessage] = queue.Queue()
    server_to_client: queue.Queue[SessionMessage] = queue.Queue()

    def mock_server():
        session_message = client_to_server.get(timeout=5.0)
        jsonrpc_request = session_message.message
        assert isinstance(jsonrpc_request.root, JSONRPCRequest)
        request = ClientRequest.model_validate(
            jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(request.root, InitializeRequest)

        # Send supported protocol version
        result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )

        server_to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=jsonrpc_request.root.id,
                        result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Receive initialized notification (drained so the handshake completes)
        client_to_server.get(timeout=5.0)

    # Start mock server thread
    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        server_to_client,
        client_to_server,
    ) as session:
        result = session.initialize()

    # Wait for server thread to complete
    server_thread.join(timeout=10.0)

    # Should successfully initialize
    assert isinstance(result, InitializeResult)
    assert result.protocolVersion == LATEST_PROTOCOL_VERSION
|
||||
|
||||
def test_client_session_version_negotiation_failure():
    """Initialization must raise when the server replies with an unsupported protocol version."""
    # Hoisted from inside the `with ClientSession(...)` block: imports belong
    # at the top of the scope, not buried mid-control-flow.
    import pytest

    # Create synchronous queues to replace async streams
    client_to_server: queue.Queue[SessionMessage] = queue.Queue()
    server_to_client: queue.Queue[SessionMessage] = queue.Queue()

    def mock_server():
        session_message = client_to_server.get(timeout=5.0)
        jsonrpc_request = session_message.message
        assert isinstance(jsonrpc_request.root, JSONRPCRequest)
        request = ClientRequest.model_validate(
            jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(request.root, InitializeRequest)

        # Send unsupported protocol version
        result = ServerResult(
            InitializeResult(
                protocolVersion="99.99.99",  # Unsupported version
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )

        server_to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=jsonrpc_request.root.id,
                        result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # No initialized notification is expected: the client aborts on the
        # unsupported version before completing the handshake.

    # Start mock server thread
    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        server_to_client,
        client_to_server,
    ) as session:
        with pytest.raises(RuntimeError, match="Unsupported protocol version"):
            session.initialize()

    # Wait for server thread to complete
    server_thread.join(timeout=10.0)
|
||||
|
||||
def test_client_capabilities_default():
    """The session must advertise some capabilities object by default."""
    # Create synchronous queues to replace async streams
    client_to_server: queue.Queue[SessionMessage] = queue.Queue()
    server_to_client: queue.Queue[SessionMessage] = queue.Queue()

    received_capabilities = None

    def mock_server():
        nonlocal received_capabilities

        session_message = client_to_server.get(timeout=5.0)
        jsonrpc_request = session_message.message
        assert isinstance(jsonrpc_request.root, JSONRPCRequest)
        request = ClientRequest.model_validate(
            jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(request.root, InitializeRequest)
        # Capture the capabilities the client advertised.
        received_capabilities = request.root.params.capabilities

        result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )

        server_to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=jsonrpc_request.root.id,
                        result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Receive initialized notification (drained so the handshake completes)
        client_to_server.get(timeout=5.0)

    # Start mock server thread
    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        server_to_client,
        client_to_server,
    ) as session:
        session.initialize()

    # Wait for server thread to complete
    server_thread.join(timeout=10.0)

    # Assert default capabilities
    assert received_capabilities is not None
|
||||
|
||||
def test_client_capabilities_with_custom_callbacks():
    """Initialization still succeeds when sampling/list-roots callbacks are supplied."""
    # Create synchronous queues to replace async streams
    client_to_server: queue.Queue[SessionMessage] = queue.Queue()
    server_to_client: queue.Queue[SessionMessage] = queue.Queue()

    # Canned sampling callback; never invoked during this handshake-only test.
    def custom_sampling_callback(
        context: RequestContext["ClientSession", Any],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult | types.ErrorData:
        return types.CreateMessageResult(
            model="test-model",
            role="assistant",
            content=types.TextContent(type="text", text="Custom response"),
        )

    # Canned list-roots callback; also unused during initialize.
    def custom_list_roots_callback(
        context: RequestContext["ClientSession", Any],
    ) -> types.ListRootsResult | types.ErrorData:
        return types.ListRootsResult(roots=[])

    def mock_server():
        session_message = client_to_server.get(timeout=5.0)
        jsonrpc_request = session_message.message
        assert isinstance(jsonrpc_request.root, JSONRPCRequest)
        request = ClientRequest.model_validate(
            jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
        )
        assert isinstance(request.root, InitializeRequest)

        result = ServerResult(
            InitializeResult(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ServerCapabilities(),
                serverInfo=Implementation(name="mock-server", version="0.1.0"),
            )
        )

        server_to_client.put(
            SessionMessage(
                message=JSONRPCMessage(
                    JSONRPCResponse(
                        jsonrpc="2.0",
                        id=jsonrpc_request.root.id,
                        result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
                    )
                )
            )
        )
        # Receive initialized notification (drained so the handshake completes)
        client_to_server.get(timeout=5.0)

    # Start mock server thread
    server_thread = threading.Thread(target=mock_server, daemon=True)
    server_thread.start()

    with ClientSession(
        server_to_client,
        client_to_server,
        sampling_callback=custom_sampling_callback,
        list_roots_callback=custom_list_roots_callback,
    ) as session:
        result = session.initialize()

    # Wait for server thread to complete
    server_thread.join(timeout=10.0)

    # Verify initialization succeeded
    assert isinstance(result, InitializeResult)
    assert result.protocolVersion == LATEST_PROTOCOL_VERSION
||||
324
dify/api/tests/unit_tests/core/mcp/client/test_sse.py
Normal file
324
dify/api/tests/unit_tests/core/mcp/client/test_sse.py
Normal file
@@ -0,0 +1,324 @@
|
||||
import contextlib
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
|
||||
SERVER_NAME = "test_server_for_SSE"
|
||||
|
||||
|
||||
def test_sse_message_id_coercion():
    """Test that string message IDs that look like integers are parsed as integers.

    See <https://github.com/modelcontextprotocol/python-sdk/pull/851> for more details.
    """
    raw = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
    parsed = types.JSONRPCMessage.model_validate_json(raw)
    reference = types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123))

    # Both roots must be requests before their fields can be compared.
    assert isinstance(parsed.root, types.JSONRPCRequest)
    assert isinstance(reference.root, types.JSONRPCRequest)

    # The string "123" must have been coerced to the integer id.
    assert parsed.root.id == reference.root.id
    assert parsed.root.method == reference.root.method
    assert parsed.root.jsonrpc == reference.root.jsonrpc
|
||||
|
||||
class MockSSEClient:
    """Mock SSE client for testing.

    Simulates the transport side of an SSE connection with two plain queues:
    events the "server" emits go on ``read_queue`` as ``(event_type, data)``
    tuples; outgoing client messages would go on ``write_queue``.
    """

    def __init__(self, url: str, headers: dict[str, Any] | None = None):
        self.url = url
        self.headers = headers or {}
        # Flipped to True by connect().
        self.connected = False
        # Server -> client events, as (event_type, data) tuples.
        self.read_queue: queue.Queue = queue.Queue()
        # Client -> server messages (unused by the current tests).
        self.write_queue: queue.Queue = queue.Queue()

    def connect(self):
        """Simulate connection establishment.

        Enqueues the initial "endpoint" event (as a real SSE server would)
        and returns the (read, write) queue pair.
        """
        self.connected = True

        # Send endpoint event
        endpoint_data = "/messages/?session_id=test-session-123"
        self.read_queue.put(("endpoint", endpoint_data))

        return self.read_queue, self.write_queue

    def send_initialize_response(self):
        """Send a mock initialize response as a JSON-encoded "message" event."""
        response = {
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "protocolVersion": types.LATEST_PROTOCOL_VERSION,
                "capabilities": {
                    "logging": None,
                    "resources": None,
                    "tools": None,
                    "experimental": None,
                    "prompts": None,
                },
                "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
                "instructions": "Test server instructions.",
            },
        }
        self.read_queue.put(("message", json.dumps(response)))
|
||||
|
||||
def test_sse_client_message_id_handling():
    """Test SSE client properly handles message ID coercion."""
    mock_client = MockSSEClient("http://test.example/sse")
    read_queue, write_queue = mock_client.connect()

    # Send a message with string ID that should be coerced to int
    message_data = {
        "jsonrpc": "2.0",
        "id": "456",  # String ID
        "result": {"test": "data"},
    }
    read_queue.put(("message", json.dumps(message_data)))
    # Discard the "endpoint" event that connect() enqueued first.
    read_queue.get(timeout=1.0)
    # Get the message from queue
    event_type, data = read_queue.get(timeout=1.0)
    assert event_type == "message"

    # Parse the message
    parsed_message = types.JSONRPCMessage.model_validate_json(data)
    # Check that it's a JSONRPCResponse and verify the ID
    assert isinstance(parsed_message.root, types.JSONRPCResponse)
    assert parsed_message.root.id == 456  # Should be converted to int
|
||||
|
||||
def test_sse_client_connection_validation():
    """Test SSE client validates endpoint URLs properly."""
    test_url = "http://test.example/sse"

    with (
        patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory,
        patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect,
    ):
        # Stub the HTTP client the factory would construct.
        http_client = Mock()
        mock_client_factory.return_value.__enter__.return_value = http_client

        # Stub the SSE connection so raise_for_status is a no-op.
        event_source = Mock()
        event_source.response.raise_for_status.return_value = None
        mock_sse_connect.return_value.__enter__.return_value = event_source

        # Minimal stand-in for an SSE event.
        class FakeSSEEvent:
            def __init__(self, event_type: str, data: str):
                self.event = event_type
                self.data = data

        # The transport first expects an "endpoint" event from the server.
        event_source.iter_sse.return_value = [FakeSSEEvent("endpoint", "/messages/?session_id=test-123")]

        # Connection setup may still raise during teardown; only the queue
        # handshake is under test here.
        with contextlib.suppress(Exception):
            with sse_client(test_url) as (read_queue, write_queue):
                assert read_queue is not None
                assert write_queue is not None
|
||||
|
||||
def test_sse_client_error_handling():
    """Test SSE client properly handles various error conditions."""
    test_url = "http://test.example/sse"

    def make_http_error(status_code: int, headers: dict, message: str) -> httpx.HTTPStatusError:
        # Build an HTTPStatusError around a mocked response.
        response = Mock(status_code=status_code)
        response.headers = headers
        return httpx.HTTPStatusError(message, request=Mock(), response=response)

    # A 401 response must surface as an authentication error.
    with (
        patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client"),
        patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect,
    ):
        mock_sse_connect.side_effect = make_http_error(
            401, {"WWW-Authenticate": 'Bearer realm="example"'}, "Unauthorized"
        )

        with pytest.raises(MCPAuthError):
            with sse_client(test_url):
                pass

    # Any other HTTP error becomes a generic connection error.
    with (
        patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client"),
        patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect,
    ):
        mock_sse_connect.side_effect = make_http_error(500, {}, "Server Error")

        with pytest.raises(MCPConnectionError):
            with sse_client(test_url):
                pass
|
||||
|
||||
def test_sse_client_timeout_configuration():
    """Test SSE client timeout configuration."""
    test_url = "http://test.example/sse"
    custom_timeout = 10.0
    custom_sse_timeout = 300.0
    custom_headers = {"Authorization": "Bearer test-token"}

    with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
        with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
            # Mock successful connection
            mock_client = Mock()
            mock_client_factory.return_value.__enter__.return_value = mock_client

            # No SSE events: the empty iterator ends the connection early,
            # hence the suppress() below.
            mock_event_source = Mock()
            mock_event_source.response.raise_for_status.return_value = None
            mock_event_source.iter_sse.return_value = []
            mock_sse_connect.return_value.__enter__.return_value = mock_event_source

            with contextlib.suppress(Exception):
                with sse_client(
                    test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout
                ) as (read_queue, write_queue):
                    # Verify the configuration was passed correctly
                    mock_client_factory.assert_called_with(headers=custom_headers)

                    # Check that timeout was configured: the read timeout of the
                    # httpx.Timeout passed to the SSE connect call must be the
                    # sse_read_timeout, not the generic timeout.
                    call_args = mock_sse_connect.call_args
                    assert call_args is not None
                    timeout_arg = call_args[1]["timeout"]
                    assert timeout_arg.read == custom_sse_timeout
|
||||
|
||||
def test_sse_transport_endpoint_validation():
    """Test SSE transport validates endpoint URLs correctly.

    Only same-scheme, same-origin endpoints may be accepted — anything else
    would let a malicious "endpoint" event redirect client messages.
    """
    from core.mcp.client.sse_client import SSETransport

    transport = SSETransport("http://example.com/sse")

    # Valid endpoint (same origin). Plain truthiness assertions instead of
    # `== True` / `== False` (PEP 8 / ruff E712): comparing identity to the
    # bool singletons is both unidiomatic and brittle.
    valid_endpoint = "http://example.com/messages/session123"
    assert transport._validate_endpoint_url(valid_endpoint)

    # Invalid endpoint (different origin)
    invalid_endpoint = "http://malicious.com/messages/session123"
    assert not transport._validate_endpoint_url(invalid_endpoint)

    # Invalid endpoint (different scheme)
    invalid_scheme = "https://example.com/messages/session123"
    assert not transport._validate_endpoint_url(invalid_scheme)
|
||||
|
||||
def test_sse_transport_message_parsing():
    """Test SSE transport properly parses different message types."""
    from core.mcp.client.sse_client import SSETransport

    transport = SSETransport("http://example.com/sse")
    read_queue: queue.Queue = queue.Queue()

    # Test valid JSON-RPC message
    valid_message = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
    transport._handle_message_event(valid_message, read_queue)

    # Should have a SessionMessage in the queue
    message = read_queue.get(timeout=1.0)
    assert message is not None
    assert hasattr(message, "message")

    # Test invalid JSON — parse failures must be reported through the queue,
    # not raised in the transport thread.
    invalid_json = '{"invalid": json}'
    transport._handle_message_event(invalid_json, read_queue)

    # Should have an exception in the queue
    error = read_queue.get(timeout=1.0)
    assert isinstance(error, Exception)
|
||||
|
||||
def test_sse_client_queue_cleanup():
    """Test that SSE client properly cleans up queues on exit."""
    test_url = "http://test.example/sse"

    read_queue = None
    write_queue = None

    with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
        with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
            # Mock connection that raises an exception
            mock_sse_connect.side_effect = Exception("Connection failed")

            # The context manager may never yield; suppress so the test can
            # observe post-exit state instead of failing on the raise.
            with contextlib.suppress(Exception):
                with sse_client(test_url) as (rq, wq):
                    read_queue = rq
                    write_queue = wq

            # Queues should be cleaned up even on exception
            # Note: In real implementation, cleanup should put None to signal shutdown
||||
def test_sse_client_headers_propagation():
    """Test that custom headers are properly propagated in SSE client."""
    test_url = "http://test.example/sse"
    custom_headers = {
        "Authorization": "Bearer test-token",
        "X-Custom-Header": "test-value",
        "User-Agent": "test-client/1.0",
    }

    with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as factory:
        with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as sse_connect:
            # HTTP client returned by the factory; it only needs to exist.
            factory.return_value.__enter__.return_value = Mock()

            # SSE connection that opens cleanly and yields no events.
            event_source = Mock()
            event_source.response.raise_for_status.return_value = None
            event_source.iter_sse.return_value = []
            sse_connect.return_value.__enter__.return_value = event_source

            with contextlib.suppress(Exception):
                with sse_client(test_url, headers=custom_headers):
                    pass

            # The factory must receive exactly the headers supplied by the caller.
            factory.assert_called_with(headers=custom_headers)
|
||||
|
||||
|
||||
def test_sse_client_concurrent_access():
    """Test SSE client behavior with concurrent queue access."""
    shared_queue: queue.Queue = queue.Queue()
    total = 10

    def produce():
        # Feed messages gradually so the consumer genuinely waits on the queue.
        for idx in range(total):
            shared_queue.put(f"message_{idx}")
            time.sleep(0.01)

    worker = threading.Thread(target=produce, daemon=True)
    worker.start()

    # Consume on the main thread, tolerating a timeout if the producer stalls.
    collected = []
    while len(collected) < total:
        try:
            collected.append(shared_queue.get(timeout=2.0))
        except queue.Empty:
            break

    worker.join(timeout=5.0)

    # Every produced message must have arrived, none lost or duplicated.
    assert len(collected) == total
    assert all(f"message_{idx}" in collected for idx in range(total))
|
||||
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Tests for the StreamableHTTP client transport.
|
||||
|
||||
Contains tests for only the client side of the StreamableHTTP transport.
|
||||
"""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.client.streamable_client import streamablehttp_client
|
||||
|
||||
# Test constants
SERVER_NAME = "test_streamable_http_server"  # serverInfo.name used in mock responses
TEST_SESSION_ID = "test-session-id-12345"  # session id handed out by the mock transport
# Minimal JSON-RPC "initialize" request payload shared across tests.
INIT_REQUEST = {
    "jsonrpc": "2.0",
    "method": "initialize",
    "params": {
        "clientInfo": {"name": "test-client", "version": "1.0"},
        "protocolVersion": "2025-03-26",
        "capabilities": {},
    },
    "id": "init-1",
}
|
||||
|
||||
|
||||
class MockStreamableHTTPClient:
    """Mock StreamableHTTP client for testing.

    Stands in for the real transport: exposes the (read, write, session-id)
    triple that the real client context manager yields, and can push canned
    JSON-RPC responses onto the read queue so tests can exercise message
    handling without any network I/O.
    """

    def __init__(self, url: str, headers: dict[str, Any] | None = None):
        self.url = url
        self.headers = headers or {}
        self.connected = False
        # Mirror the queue pair the real streamablehttp_client hands out.
        self.read_queue: queue.Queue = queue.Queue()
        self.write_queue: queue.Queue = queue.Queue()
        self.session_id = TEST_SESSION_ID

    def connect(self):
        """Simulate connection establishment."""
        self.connected = True
        # Same triple shape as the real client: (read_q, write_q, get_session_id).
        return self.read_queue, self.write_queue, lambda: self.session_id

    def send_initialize_response(self):
        """Send a mock initialize response."""
        session_message = types.SessionMessage(
            message=types.JSONRPCMessage(
                root=types.JSONRPCResponse(
                    jsonrpc="2.0",
                    id="init-1",
                    result={
                        "protocolVersion": types.LATEST_PROTOCOL_VERSION,
                        # All capability groups present but empty.
                        "capabilities": {
                            "logging": None,
                            "resources": None,
                            "tools": None,
                            "experimental": None,
                            "prompts": None,
                        },
                        "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
                        "instructions": "Test server instructions.",
                    },
                )
            )
        )
        self.read_queue.put(session_message)

    def send_tools_response(self):
        """Send a mock tools list response."""
        session_message = types.SessionMessage(
            message=types.JSONRPCMessage(
                root=types.JSONRPCResponse(
                    jsonrpc="2.0",
                    id="tools-1",
                    result={
                        "tools": [
                            {
                                "name": "test_tool",
                                "description": "A test tool",
                                "inputSchema": {"type": "object", "properties": {}},
                            }
                        ],
                    },
                )
            )
        )
        self.read_queue.put(session_message)
|
||||
|
||||
|
||||
def test_streamablehttp_client_message_id_handling():
    """Test StreamableHTTP client properly handles message ID coercion."""
    mock_client = MockStreamableHTTPClient("http://test.example/mcp")
    read_queue, write_queue, get_session_id = mock_client.connect()

    # Send a message with string ID that should be coerced to int
    # (coercion happens when the pydantic model validates the payload).
    response_message = types.SessionMessage(
        message=types.JSONRPCMessage(root=types.JSONRPCResponse(jsonrpc="2.0", id="789", result={"test": "data"}))
    )
    read_queue.put(response_message)

    # Get the message from queue
    message = read_queue.get(timeout=1.0)
    assert message is not None
    assert isinstance(message, types.SessionMessage)

    # Check that the ID was properly handled
    assert isinstance(message.message.root, types.JSONRPCResponse)
    assert message.message.root.id == 789  # ID should be coerced to int due to union_mode="left_to_right"
|
||||
|
||||
|
||||
def test_streamablehttp_client_connection_validation():
    """Test StreamableHTTP client validates connections properly."""
    test_url = "http://test.example/mcp"

    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        # Stub HTTP client whose POSTs succeed with a JSON content type.
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client

        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.headers = {"content-type": "application/json"}
        ok_response.raise_for_status.return_value = None
        http_client.post.return_value = ok_response

        # Connection may still fail because of the mocking; when it does open,
        # the context manager must hand out non-None queues and a getter.
        try:
            with streamablehttp_client(test_url) as (rq, wq, session_id_getter):
                assert rq is not None
                assert wq is not None
                assert session_id_getter is not None
        except Exception:
            pass
|
||||
|
||||
|
||||
def test_streamablehttp_client_timeout_configuration():
    """Test StreamableHTTP client timeout configuration."""
    test_url = "http://test.example/mcp"
    custom_headers = {"Authorization": "Bearer test-token"}

    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        # Stub a client whose POSTs succeed so the connection can proceed.
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client

        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.headers = {"content-type": "application/json"}
        ok_response.raise_for_status.return_value = None
        http_client.post.return_value = ok_response

        try:
            with streamablehttp_client(test_url, headers=custom_headers) as (rq, wq, get_sid):
                # Verify the configuration was passed correctly
                factory.assert_called_with(headers=custom_headers)
        except Exception:
            # Connection might fail due to mocking, but we tested the configuration
            pass
|
||||
|
||||
|
||||
def test_streamablehttp_client_session_id_handling():
    """Test StreamableHTTP client properly handles session IDs."""
    transport = MockStreamableHTTPClient("http://test.example/mcp")
    _read_q, _write_q, get_session_id = transport.connect()

    # The getter must hand back the transport's session id...
    sid = get_session_id()
    assert sid == TEST_SESSION_ID

    # ...and that id must be a usable, non-empty value for follow-up requests.
    assert sid is not None
    assert len(sid) > 0
|
||||
|
||||
|
||||
def test_streamablehttp_client_message_parsing():
    """Test StreamableHTTP client properly parses different message types."""
    mock_client = MockStreamableHTTPClient("http://test.example/mcp")
    read_queue, write_queue, get_session_id = mock_client.connect()

    # Test valid initialization response
    mock_client.send_initialize_response()

    # Should have a SessionMessage in the queue
    message = read_queue.get(timeout=1.0)
    assert message is not None
    assert isinstance(message, types.SessionMessage)
    assert isinstance(message.message.root, types.JSONRPCResponse)

    # Test tools response (second canned payload from the mock transport)
    mock_client.send_tools_response()

    tools_message = read_queue.get(timeout=1.0)
    assert tools_message is not None
    assert isinstance(tools_message, types.SessionMessage)
|
||||
|
||||
|
||||
def test_streamablehttp_client_queue_cleanup():
    """Test that StreamableHTTP client propagates connection failures without yielding queues.

    The HTTP client factory is patched to raise, so ``streamablehttp_client``
    must fail before its context-manager body runs and the local queue
    references stay ``None``.  (In the real implementation, cleanup should
    also put ``None`` on its internal queues to signal shutdown.)
    """
    test_url = "http://test.example/mcp"

    read_queue = None
    write_queue = None

    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
        # Mock connection that raises an exception
        mock_client_factory.side_effect = Exception("Connection failed")

        try:
            with streamablehttp_client(test_url) as (rq, wq, get_session_id):
                read_queue = rq
                write_queue = wq
        except Exception:
            pass  # Expected to fail

    # The with-body never executed, so no queues were handed out.
    # (Previously this test made no assertion at all.)
    assert read_queue is None
    assert write_queue is None
|
||||
|
||||
|
||||
def test_streamablehttp_client_headers_propagation():
    """Test that custom headers are properly propagated in StreamableHTTP client."""
    test_url = "http://test.example/mcp"
    custom_headers = {
        "Authorization": "Bearer test-token",
        "X-Custom-Header": "test-value",
        "User-Agent": "test-client/1.0",
    }

    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
        # Mock the client factory to capture headers
        mock_client = Mock()
        mock_client_factory.return_value.__enter__.return_value = mock_client

        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.headers = {"content-type": "application/json"}
        mock_response.raise_for_status.return_value = None
        mock_client.post.return_value = mock_response

        try:
            with streamablehttp_client(test_url, headers=custom_headers):
                pass
        except Exception:
            pass  # Expected due to mocking

        # Verify headers were passed to client factory
        # Check that the call was made with headers that include our custom headers
        mock_client_factory.assert_called_once()
        call_args = mock_client_factory.call_args
        assert "headers" in call_args.kwargs
        passed_headers = call_args.kwargs["headers"]

        # Verify all custom headers are present
        for key, value in custom_headers.items():
            assert key in passed_headers
            assert passed_headers[key] == value
|
||||
|
||||
|
||||
def test_streamablehttp_client_concurrent_access():
    """Test StreamableHTTP client behavior with concurrent queue access.

    A producer thread feeds ten messages into a shared queue while the main
    thread consumes them, verifying that no message is lost or duplicated.
    """
    # NOTE: the unused `test_write_queue` local was removed — the test only
    # ever exercised the read side.
    test_read_queue: queue.Queue = queue.Queue()

    # Simulate concurrent producers and consumers
    def producer():
        for i in range(10):
            test_read_queue.put(f"message_{i}")
            time.sleep(0.01)  # Small delay to simulate real conditions

    def consumer():
        received = []
        for _ in range(10):
            try:
                msg = test_read_queue.get(timeout=2.0)
                received.append(msg)
            except queue.Empty:
                break
        return received

    # Start producer in separate thread
    producer_thread = threading.Thread(target=producer, daemon=True)
    producer_thread.start()

    # Consume messages
    received_messages = consumer()

    # Wait for producer to finish
    producer_thread.join(timeout=5.0)

    # Verify all messages were received
    assert len(received_messages) == 10
    for i in range(10):
        assert f"message_{i}" in received_messages
|
||||
|
||||
|
||||
def test_streamablehttp_client_json_vs_sse_mode():
    """Test StreamableHTTP client handling of JSON vs SSE response modes."""
    test_url = "http://test.example/mcp"

    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client

        # Mock JSON response
        json_response = Mock()
        json_response.status_code = 200
        json_response.headers = {"content-type": "application/json"}
        json_response.json.return_value = {"result": "json_mode"}
        json_response.raise_for_status.return_value = None

        # Mock SSE response
        sse_response = Mock()
        sse_response.status_code = 200
        sse_response.headers = {"content-type": "text/event-stream"}
        sse_response.raise_for_status.return_value = None

        # The client must hand out live queues regardless of response mode:
        # first JSON, then SSE.
        for mocked_response in (json_response, sse_response):
            http_client.post.return_value = mocked_response
            try:
                with streamablehttp_client(test_url) as (rq, wq, get_session_id):
                    assert rq is not None
                    assert wq is not None
            except Exception:
                pass  # Expected due to mocking
|
||||
|
||||
|
||||
def test_streamablehttp_client_terminate_on_close():
    """Test StreamableHTTP client terminate_on_close parameter."""
    test_url = "http://test.example/mcp"

    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client

        response = Mock()
        response.status_code = 200
        response.headers = {"content-type": "application/json"}
        response.raise_for_status.return_value = None
        http_client.post.return_value = response
        # DELETE is stubbed too, since closing with termination may issue one.
        http_client.delete.return_value = response

        # Both settings must be accepted: True (the default) and False.
        for terminate in (True, False):
            try:
                with streamablehttp_client(test_url, terminate_on_close=terminate) as (rq, wq, get_sid):
                    pass
            except Exception:
                pass  # Expected due to mocking
|
||||
|
||||
|
||||
def test_streamablehttp_client_protocol_version_handling():
    """Test StreamableHTTP client protocol version handling."""
    mock_client = MockStreamableHTTPClient("http://test.example/mcp")
    read_queue, write_queue, get_session_id = mock_client.connect()

    # Send initialize response with specific protocol version
    # (an older version than the mock transport's default).

    session_message = types.SessionMessage(
        message=types.JSONRPCMessage(
            root=types.JSONRPCResponse(
                jsonrpc="2.0",
                id="init-1",
                result={
                    "protocolVersion": "2024-11-05",
                    "capabilities": {},
                    "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
                },
            )
        )
    )
    read_queue.put(session_message)

    # Get the message and verify protocol version
    message = read_queue.get(timeout=1.0)
    assert message is not None
    assert isinstance(message.message.root, types.JSONRPCResponse)
    result = message.message.root.result
    assert result["protocolVersion"] == "2024-11-05"
|
||||
|
||||
|
||||
def test_streamablehttp_client_error_response_handling():
    """Test StreamableHTTP client handling of error responses."""
    mock_client = MockStreamableHTTPClient("http://test.example/mcp")
    read_queue, write_queue, get_session_id = mock_client.connect()

    # Send an error response (-32601 is JSON-RPC's "method not found" code)
    session_message = types.SessionMessage(
        message=types.JSONRPCMessage(
            root=types.JSONRPCError(
                jsonrpc="2.0",
                id="test-1",
                error=types.ErrorData(code=-32601, message="Method not found", data=None),
            )
        )
    )
    read_queue.put(session_message)

    # Get the error message
    message = read_queue.get(timeout=1.0)
    assert message is not None
    assert isinstance(message.message.root, types.JSONRPCError)
    assert message.message.root.error.code == -32601
    assert message.message.root.error.message == "Method not found"
|
||||
|
||||
|
||||
def test_streamablehttp_client_resumption_token_handling():
    """Test StreamableHTTP client resumption token functionality."""
    test_url = "http://test.example/mcp"
    test_resumption_token = "resume-token-123"

    with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as factory:
        http_client = Mock()
        factory.return_value.__enter__.return_value = http_client

        # Successful response carrying a resumption token in its headers.
        tokened_response = Mock()
        tokened_response.status_code = 200
        tokened_response.headers = {"content-type": "application/json", "last-event-id": test_resumption_token}
        tokened_response.raise_for_status.return_value = None
        http_client.post.return_value = tokened_response

        try:
            with streamablehttp_client(test_url) as (rq, wq, get_sid):
                # Queues must be live while the token is exposed via headers.
                assert rq is not None
                assert wq is not None
        except Exception:
            pass  # Expected due to mocking
|
||||
1
dify/api/tests/unit_tests/core/mcp/server/__init__.py
Normal file
1
dify/api/tests/unit_tests/core/mcp/server/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# MCP server tests
|
||||
@@ -0,0 +1,512 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import jsonschema
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types
|
||||
from core.mcp.server.streamable_http import (
|
||||
build_parameter_schema,
|
||||
convert_input_form_to_parameters,
|
||||
extract_answer_from_response,
|
||||
handle_call_tool,
|
||||
handle_initialize,
|
||||
handle_list_tools,
|
||||
handle_mcp_request,
|
||||
handle_ping,
|
||||
prepare_tool_arguments,
|
||||
process_mapping_response,
|
||||
)
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
|
||||
|
||||
class TestHandleMCPRequest:
    """Test handle_mcp_request function.

    NOTE(review): these tests patch ``core.mcp.server.streamable_http.type``
    so that the handler's request-type dispatch resolves to the desired
    request class without constructing real request objects — confirm this
    matches how handle_mcp_request dispatches internally.
    """

    def setup_method(self):
        """Setup test fixtures"""
        # Minimal chat app; only the attributes the handler reads are set.
        self.app = Mock(spec=App)
        self.app.name = "test_app"
        self.app.mode = AppMode.CHAT

        self.mcp_server = Mock(spec=AppMCPServer)
        self.mcp_server.description = "Test server"
        self.mcp_server.parameters_dict = {}

        self.end_user = Mock(spec=EndUser)
        self.user_input_form = []

        # Create mock request
        self.mock_request = Mock()
        self.mock_request.root = Mock()
        self.mock_request.root.id = 123

    def test_handle_ping_request(self):
        """Test handling ping request"""
        # Setup ping request
        self.mock_request.root = Mock(spec=types.PingRequest)
        self.mock_request.root.id = 123
        request_type = Mock(return_value=types.PingRequest)

        with patch("core.mcp.server.streamable_http.type", request_type):
            result = handle_mcp_request(
                self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
            )

        # Ping succeeds: JSON-RPC response echoing the request id.
        assert isinstance(result, types.JSONRPCResponse)
        assert result.jsonrpc == "2.0"
        assert result.id == 123

    def test_handle_initialize_request(self):
        """Test handling initialize request"""
        # Setup initialize request
        self.mock_request.root = Mock(spec=types.InitializeRequest)
        self.mock_request.root.id = 123
        request_type = Mock(return_value=types.InitializeRequest)

        with patch("core.mcp.server.streamable_http.type", request_type):
            result = handle_mcp_request(
                self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
            )

        assert isinstance(result, types.JSONRPCResponse)
        assert result.jsonrpc == "2.0"
        assert result.id == 123

    def test_handle_list_tools_request(self):
        """Test handling list tools request"""
        # Setup list tools request
        self.mock_request.root = Mock(spec=types.ListToolsRequest)
        self.mock_request.root.id = 123
        request_type = Mock(return_value=types.ListToolsRequest)

        with patch("core.mcp.server.streamable_http.type", request_type):
            result = handle_mcp_request(
                self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
            )

        assert isinstance(result, types.JSONRPCResponse)
        assert result.jsonrpc == "2.0"
        assert result.id == 123

    @patch("core.mcp.server.streamable_http.AppGenerateService")
    def test_handle_call_tool_request(self, mock_app_generate):
        """Test handling call tool request"""
        # Setup call tool request
        mock_call_request = Mock(spec=types.CallToolRequest)
        mock_call_request.params = Mock()
        mock_call_request.params.arguments = {"query": "test question"}
        mock_call_request.id = 123

        self.mock_request.root = mock_call_request
        request_type = Mock(return_value=types.CallToolRequest)

        # Mock app generate service response
        mock_response = {"answer": "test answer"}
        mock_app_generate.generate.return_value = mock_response

        with patch("core.mcp.server.streamable_http.type", request_type):
            result = handle_mcp_request(
                self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
            )

        assert isinstance(result, types.JSONRPCResponse)
        assert result.jsonrpc == "2.0"
        assert result.id == 123

        # Verify AppGenerateService was called
        mock_app_generate.generate.assert_called_once()

    def test_handle_unknown_request_type(self):
        """Test handling unknown request type"""

        # Setup unknown request (a type the dispatcher has no branch for)
        class UnknownRequest:
            pass

        self.mock_request.root = Mock(spec=UnknownRequest)
        self.mock_request.root.id = 123
        request_type = Mock(return_value=UnknownRequest)

        with patch("core.mcp.server.streamable_http.type", request_type):
            result = handle_mcp_request(
                self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
            )

        # Unknown types map to a METHOD_NOT_FOUND JSON-RPC error.
        assert isinstance(result, types.JSONRPCError)
        assert result.jsonrpc == "2.0"
        assert result.id == 123
        assert result.error.code == types.METHOD_NOT_FOUND

    def test_handle_value_error(self):
        """Test handling ValueError"""
        # Setup request that will cause ValueError
        self.mock_request.root = Mock(spec=types.CallToolRequest)
        self.mock_request.root.params = Mock()
        self.mock_request.root.params.arguments = {}

        request_type = Mock(return_value=types.CallToolRequest)

        # Don't provide end_user to cause ValueError
        with patch("core.mcp.server.streamable_http.type", request_type):
            result = handle_mcp_request(self.app, self.mock_request, self.user_input_form, self.mcp_server, None, 123)

        # ValueError from the handler surfaces as INVALID_PARAMS.
        assert isinstance(result, types.JSONRPCError)
        assert result.error.code == types.INVALID_PARAMS

    def test_handle_generic_exception(self):
        """Test handling generic exception"""
        # Setup request that will cause generic exception
        self.mock_request.root = Mock(spec=types.PingRequest)
        self.mock_request.root.id = 123

        # Patch handle_ping to raise exception instead of type
        with patch("core.mcp.server.streamable_http.handle_ping", side_effect=Exception("Test error")):
            with patch("core.mcp.server.streamable_http.type", return_value=types.PingRequest):
                result = handle_mcp_request(
                    self.app, self.mock_request, self.user_input_form, self.mcp_server, self.end_user, 123
                )

        # Unexpected exceptions surface as INTERNAL_ERROR, never propagate.
        assert isinstance(result, types.JSONRPCError)
        assert result.error.code == types.INTERNAL_ERROR
|
||||
|
||||
|
||||
class TestIndividualHandlers:
    """Test individual handler functions"""

    def test_handle_ping(self):
        """Test ping handler"""
        result = handle_ping()
        # Ping carries no payload; an EmptyResult is the whole contract.
        assert isinstance(result, types.EmptyResult)

    def test_handle_initialize(self):
        """Test initialize handler"""
        description = "Test server"

        with patch("core.mcp.server.streamable_http.dify_config") as mock_config:
            # handle_initialize reads the project version from dify_config.
            mock_config.project.version = "1.0.0"
            result = handle_initialize(description)

        assert isinstance(result, types.InitializeResult)
        assert result.protocolVersion == types.SERVER_LATEST_PROTOCOL_VERSION
        # The server description is surfaced verbatim as instructions.
        assert result.instructions == "Test server"

    def test_handle_list_tools(self):
        """Test list tools handler"""
        app_name = "test_app"
        app_mode = AppMode.CHAT
        description = "Test server"
        parameters_dict: dict[str, str] = {}
        user_input_form: list[VariableEntity] = []

        result = handle_list_tools(app_name, app_mode, user_input_form, description, parameters_dict)

        # An app is exposed as exactly one tool named after the app.
        assert isinstance(result, types.ListToolsResult)
        assert len(result.tools) == 1
        assert result.tools[0].name == "test_app"
        assert result.tools[0].description == "Test server"

    @patch("core.mcp.server.streamable_http.AppGenerateService")
    def test_handle_call_tool(self, mock_app_generate):
        """Test call tool handler"""
        app = Mock(spec=App)
        app.mode = AppMode.CHAT

        # Create mock request
        mock_request = Mock()
        mock_call_request = Mock(spec=types.CallToolRequest)
        mock_call_request.params = Mock()
        mock_call_request.params.arguments = {"query": "test question"}
        mock_request.root = mock_call_request

        user_input_form: list[VariableEntity] = []
        end_user = Mock(spec=EndUser)

        # Mock app generate service response
        mock_response = {"answer": "test answer"}
        mock_app_generate.generate.return_value = mock_response

        result = handle_call_tool(app, mock_request, user_input_form, end_user)

        assert isinstance(result, types.CallToolResult)
        assert len(result.content) == 1
        # Type assertion needed due to union type
        text_content = result.content[0]
        assert hasattr(text_content, "text")
        assert text_content.text == "test answer"

    def test_handle_call_tool_no_end_user(self):
        """Test call tool handler without end user"""
        app = Mock(spec=App)
        mock_request = Mock()
        user_input_form: list[VariableEntity] = []

        # A missing end user must be rejected up front, before generation.
        with pytest.raises(ValueError, match="End user not found"):
            handle_call_tool(app, mock_request, user_input_form, None)
|
||||
|
||||
|
||||
class TestUtilityFunctions:
|
||||
"""Test utility functions"""
|
||||
|
||||
    def test_build_parameter_schema_chat_mode(self):
        """Test building parameter schema for chat mode"""
        app_mode = AppMode.CHAT
        parameters_dict: dict[str, str] = {"name": "Enter your name"}

        user_input_form = [
            VariableEntity(
                type=VariableEntityType.TEXT_INPUT,
                variable="name",
                description="User name",
                label="Name",
                required=True,
            )
        ]

        schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)

        # Chat apps get an implicit required "query" property alongside the
        # declared form variables.
        assert schema["type"] == "object"
        assert "query" in schema["properties"]
        assert "name" in schema["properties"]
        assert "query" in schema["required"]
        assert "name" in schema["required"]
|
||||
|
||||
def test_build_parameter_schema_workflow_mode(self):
|
||||
"""Test building parameter schema for workflow mode"""
|
||||
app_mode = AppMode.WORKFLOW
|
||||
parameters_dict: dict[str, str] = {"input_text": "Enter text"}
|
||||
|
||||
user_input_form = [
|
||||
VariableEntity(
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
variable="input_text",
|
||||
description="Input text",
|
||||
label="Input",
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
schema = build_parameter_schema(app_mode, user_input_form, parameters_dict)
|
||||
|
||||
assert schema["type"] == "object"
|
||||
assert "query" not in schema["properties"]
|
||||
assert "input_text" in schema["properties"]
|
||||
assert "input_text" in schema["required"]
|
||||
|
||||
def test_prepare_tool_arguments_chat_mode(self):
|
||||
"""Test preparing tool arguments for chat mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.CHAT
|
||||
|
||||
arguments = {"query": "test question", "name": "John"}
|
||||
|
||||
result = prepare_tool_arguments(app, arguments)
|
||||
|
||||
assert result["query"] == "test question"
|
||||
assert result["inputs"]["name"] == "John"
|
||||
# Original arguments should not be modified
|
||||
assert arguments["query"] == "test question"
|
||||
|
||||
    def test_prepare_tool_arguments_workflow_mode(self):
        """Test preparing tool arguments for workflow mode"""
        app = Mock(spec=App)
        app.mode = AppMode.WORKFLOW

        arguments = {"input_text": "test input"}

        result = prepare_tool_arguments(app, arguments)

        # Workflow mode nests every argument under "inputs" with no "query".
        assert "inputs" in result
        assert result["inputs"]["input_text"] == "test input"
|
||||
|
||||
def test_prepare_tool_arguments_completion_mode(self):
|
||||
"""Test preparing tool arguments for completion mode"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.COMPLETION
|
||||
|
||||
arguments = {"name": "John"}
|
||||
|
||||
result = prepare_tool_arguments(app, arguments)
|
||||
|
||||
assert result["query"] == ""
|
||||
assert result["inputs"]["name"] == "John"
|
||||
|
||||
def test_extract_answer_from_mapping_response_chat(self):
    """Chat mode: the 'answer' field of the mapping is returned verbatim."""
    chat_app = Mock(spec=App)
    chat_app.mode = AppMode.CHAT

    payload = {"answer": "test answer", "other": "data"}
    assert extract_answer_from_response(chat_app, payload) == "test answer"
|
||||
|
||||
def test_extract_answer_from_mapping_response_workflow(self):
    """Workflow mode: data.outputs is serialized to JSON (non-ASCII preserved)."""
    wf_app = Mock(spec=App)
    wf_app.mode = AppMode.WORKFLOW

    payload = {"data": {"outputs": {"result": "test result"}}}
    answer = extract_answer_from_response(wf_app, payload)

    assert answer == json.dumps({"result": "test result"}, ensure_ascii=False)
|
||||
|
||||
def test_extract_answer_from_streaming_response(self):
    """Streaming: only 'agent_thought' events contribute to the answer text."""
    app = Mock(spec=App)

    # Simulated SSE stream: two thought events, one unrelated event,
    # and one malformed line that is not in "data: ..." format.
    stream = Mock(spec=RateLimitGenerator)
    stream.generator = [
        'data: {"event": "agent_thought", "thought": "thinking..."}',
        'data: {"event": "agent_thought", "thought": "more thinking"}',
        'data: {"event": "other", "content": "ignore this"}',
        "not data format",
    ]

    # Non-thought events and malformed lines are silently skipped.
    assert extract_answer_from_response(app, stream) == "thinking...more thinking"
|
||||
|
||||
def test_process_mapping_response_invalid_mode(self):
    """An unrecognized app mode is rejected with a ValueError."""
    bad_app = Mock(spec=App)
    bad_app.mode = "invalid_mode"

    with pytest.raises(ValueError, match="Invalid app mode"):
        process_mapping_response(bad_app, {"answer": "test"})
|
||||
|
||||
def test_convert_input_form_to_parameters(self):
    """Form entries map to JSON-schema-style parameters; FILE entries yield nothing useful."""
    form = [
        VariableEntity(
            type=VariableEntityType.TEXT_INPUT,
            variable="name",
            description="User name",
            label="Name",
            required=True,
        ),
        VariableEntity(
            type=VariableEntityType.SELECT,
            variable="category",
            description="Category",
            label="Category",
            required=False,
            options=["A", "B", "C"],
        ),
        VariableEntity(
            type=VariableEntityType.NUMBER,
            variable="count",
            description="Count",
            label="Count",
            required=True,
        ),
        VariableEntity(
            type=VariableEntityType.FILE,
            variable="upload",
            description="File upload",
            label="Upload",
            required=False,
        ),
    ]
    descriptions: dict[str, str] = {
        "name": "Enter your name",
        "category": "Select category",
        "count": "Enter count",
    }

    parameters, required = convert_input_form_to_parameters(form, descriptions)

    # Text input -> string parameter carrying the provided description.
    assert "name" in parameters
    assert parameters["name"]["type"] == "string"
    assert parameters["name"]["description"] == "Enter your name"

    # Select -> string parameter constrained by an enum of its options.
    assert "category" in parameters
    assert parameters["category"]["type"] == "string"
    assert parameters["category"]["enum"] == ["A", "B", "C"]

    # Number -> JSON Schema "number".
    assert "count" in parameters
    assert parameters["count"]["type"] == "number"

    # FILE type should be skipped - it creates empty dict but gets filtered later
    if "upload" in parameters:
        assert parameters["upload"] == {}

    # Only required, non-file fields appear in the required list.
    assert "name" in required
    assert "count" in required
    assert "category" not in required
|
||||
|
||||
# Note: _get_request_id function has been removed as request_id is now passed as parameter
|
||||
|
||||
def test_convert_input_form_to_parameters_jsonschema_validation_ok(self):
    """Current schema uses 'number' for numeric fields; it should be a valid JSON Schema."""
    form = [
        VariableEntity(
            type=VariableEntityType.NUMBER,
            variable="count",
            description="Count",
            label="Count",
            required=True,
        ),
        VariableEntity(
            type=VariableEntityType.TEXT_INPUT,
            variable="name",
            description="User name",
            label="Name",
            required=False,
        ),
    ]
    descriptions = {
        "count": "Enter count",
        "name": "Enter your name",
    }

    properties, required = convert_input_form_to_parameters(form, descriptions)

    # Assemble a complete object schema from the converted pieces.
    schema = {
        "type": "object",
        "properties": properties,
        "required": required,
    }

    # 1) The schema itself must be well-formed.
    jsonschema.Draft202012Validator.check_schema(schema)

    # 2) JSON Schema 'number' accepts both floats and integers.
    jsonschema.validate(instance={"count": 3.14, "name": "alice"}, schema=schema)
    jsonschema.validate(instance={"count": 2, "name": "bob"}, schema=schema)
|
||||
|
||||
def test_legacy_float_type_schema_is_invalid(self):
    """Legacy/buggy behavior: using 'float' should produce an invalid JSON Schema."""
    # Hand-built schema reproducing the old, incorrect output.
    legacy_schema = {
        "type": "object",
        "properties": {
            "count": {
                "type": "float",  # Invalid type: JSON Schema does not support 'float'
                "description": "Enter count",
            }
        },
        "required": ["count"],
    }

    # Checking the schema directly must raise a SchemaError...
    with pytest.raises(jsonschema.exceptions.SchemaError):
        jsonschema.Draft202012Validator.check_schema(legacy_schema)

    # ...and validating any instance against it must fail the same way.
    with pytest.raises(jsonschema.exceptions.SchemaError):
        jsonschema.validate(instance={"count": 1.23}, schema=legacy_schema)
|
||||
239
dify/api/tests/unit_tests/core/mcp/test_entities.py
Normal file
239
dify/api/tests/unit_tests/core/mcp/test_entities.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Unit tests for MCP entities module."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.mcp.entities import (
|
||||
SUPPORTED_PROTOCOL_VERSIONS,
|
||||
LifespanContextT,
|
||||
RequestContext,
|
||||
SessionT,
|
||||
)
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
|
||||
|
||||
|
||||
class TestProtocolVersions:
    """Tests for the protocol version constants."""

    def test_supported_protocol_versions(self):
        """The supported-versions list contains every known protocol revision."""
        assert isinstance(SUPPORTED_PROTOCOL_VERSIONS, list)
        assert len(SUPPORTED_PROTOCOL_VERSIONS) >= 3
        # Known historical revisions must remain supported.
        assert "2024-11-05" in SUPPORTED_PROTOCOL_VERSIONS
        assert "2025-03-26" in SUPPORTED_PROTOCOL_VERSIONS
        assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS

    def test_latest_protocol_version_is_supported(self):
        """The latest revision is always part of the supported set."""
        assert LATEST_PROTOCOL_VERSION in SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
|
||||
class TestRequestContext:
    """Tests for the RequestContext dataclass."""

    def test_request_context_creation(self):
        """All constructor arguments are stored as attributes."""
        session = Mock(spec=BaseSession)
        lifespan = {"key": "value"}
        meta = RequestParams.Meta(progressToken="test-token")

        ctx = RequestContext(
            request_id="test-request-123",
            meta=meta,
            session=session,
            lifespan_context=lifespan,
        )

        assert ctx.request_id == "test-request-123"
        assert ctx.meta == meta
        assert ctx.session == session
        assert ctx.lifespan_context == lifespan

    def test_request_context_with_none_meta(self):
        """meta and lifespan_context may be None; request_id may be an int."""
        session = Mock(spec=BaseSession)

        ctx = RequestContext(
            request_id=42,  # Can be int or string
            meta=None,
            session=session,
            lifespan_context=None,
        )

        assert ctx.request_id == 42
        assert ctx.meta is None
        assert ctx.session == session
        assert ctx.lifespan_context is None

    def test_request_context_attributes(self):
        """The dataclass exposes the expected attribute names and values."""
        session = Mock(spec=BaseSession)

        ctx = RequestContext(
            request_id="test-123",
            meta=None,
            session=session,
            lifespan_context=None,
        )

        # Every declared field is reachable as an attribute.
        for attr in ("request_id", "meta", "session", "lifespan_context"):
            assert hasattr(ctx, attr)

        assert ctx.request_id == "test-123"
        assert ctx.meta is None
        assert ctx.session == session
        assert ctx.lifespan_context is None

    def test_request_context_generic_typing(self):
        """RequestContext can be parameterized with different lifespan types."""
        session = Mock(spec=BaseSession)

        # str-typed lifespan context.
        ctx_str = RequestContext[BaseSession, str](
            request_id="test-1",
            meta=None,
            session=session,
            lifespan_context="string-context",
        )
        assert isinstance(ctx_str.lifespan_context, str)

        # dict-typed lifespan context.
        ctx_dict = RequestContext[BaseSession, dict](
            request_id="test-2",
            meta=None,
            session=session,
            lifespan_context={"key": "value"},
        )
        assert isinstance(ctx_dict.lifespan_context, dict)

        # User-defined lifespan context type.
        class CustomLifespan:
            def __init__(self, data):
                self.data = data

        lifespan = CustomLifespan("test-data")
        ctx_custom = RequestContext[BaseSession, CustomLifespan](
            request_id="test-3",
            meta=None,
            session=session,
            lifespan_context=lifespan,
        )
        assert isinstance(ctx_custom.lifespan_context, CustomLifespan)
        assert ctx_custom.lifespan_context.data == "test-data"

    def test_request_context_with_progress_meta(self):
        """Progress metadata round-trips through the context."""
        session = Mock(spec=BaseSession)
        meta = RequestParams.Meta(progressToken="progress-123")

        ctx = RequestContext(
            request_id="req-456",
            meta=meta,
            session=session,
            lifespan_context=None,
        )

        assert ctx.meta is not None
        assert ctx.meta.progressToken == "progress-123"

    def test_request_context_equality(self):
        """Equality is field-wise, as generated by the dataclass."""
        session_a = Mock(spec=BaseSession)
        session_b = Mock(spec=BaseSession)

        def make(rid, sess):
            # Helper keeping meta/lifespan constant so only rid/sess vary.
            return RequestContext(
                request_id=rid,
                meta=None,
                session=sess,
                lifespan_context="context",
            )

        # Identical field values compare equal.
        assert make("test-123", session_a) == make("test-123", session_a)

        # A different request_id breaks equality.
        assert make("test-123", session_a) != make("test-456", session_a)

        # A different session breaks equality.
        assert make("test-123", session_a) != make("test-123", session_b)

    def test_request_context_repr(self):
        """repr includes the class name, request id, and the session's repr."""
        session = Mock(spec=BaseSession)
        session.__repr__ = Mock(return_value="<MockSession>")

        ctx = RequestContext(
            request_id="test-123",
            meta=None,
            session=session,
            lifespan_context={"data": "test"},
        )

        text = repr(ctx)
        assert "RequestContext" in text
        assert "test-123" in text
        assert "MockSession" in text
|
||||
|
||||
|
||||
class TestTypeVariables:
    """Tests for the module's generic type variables."""

    def test_session_type_var(self):
        """SessionT can stand in for any BaseSession subclass."""

        class CustomSession(BaseSession):
            pass

        # Identity function annotated with the type variable.
        def passthrough(session: SessionT) -> SessionT:
            return session

        fake_session = Mock(spec=CustomSession)
        assert passthrough(fake_session) == fake_session

    def test_lifespan_context_type_var(self):
        """LifespanContextT accepts arbitrary context types."""

        # Identity function annotated with the type variable.
        def passthrough(context: LifespanContextT) -> LifespanContextT:
            return context

        # Works for builtin types...
        assert passthrough("string-context") == "string-context"
        assert passthrough({"key": "value"}) == {"key": "value"}

        # ...and for user-defined classes.
        class CustomContext:
            pass

        ctx = CustomContext()
        assert passthrough(ctx) == ctx
|
||||
205
dify/api/tests/unit_tests/core/mcp/test_error.py
Normal file
205
dify/api/tests/unit_tests/core/mcp/test_error.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Unit tests for MCP error classes."""
|
||||
|
||||
import pytest
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError, MCPError
|
||||
|
||||
|
||||
class TestMCPError:
    """Tests for the MCPError base exception class."""

    def test_mcp_error_creation(self):
        """The message is stored and the class is an Exception."""
        err = MCPError("Test error message")
        assert str(err) == "Test error message"
        assert isinstance(err, Exception)

    def test_mcp_error_inheritance(self):
        """MCPError derives directly from Exception."""
        err = MCPError()
        assert isinstance(err, Exception)
        assert type(err).__name__ == "MCPError"

    def test_mcp_error_with_empty_message(self):
        """An argument-less MCPError stringifies to the empty string."""
        assert str(MCPError()) == ""

    def test_mcp_error_raise(self):
        """Raising MCPError carries the message through to the handler."""
        with pytest.raises(MCPError) as exc_info:
            raise MCPError("Something went wrong")

        assert str(exc_info.value) == "Something went wrong"
|
||||
|
||||
|
||||
class TestMCPConnectionError:
    """Tests for the MCPConnectionError exception class."""

    def test_mcp_connection_error_creation(self):
        """The message is stored; the class is both an MCPError and an Exception."""
        err = MCPConnectionError("Connection failed")
        assert str(err) == "Connection failed"
        assert isinstance(err, MCPError)
        assert isinstance(err, Exception)

    def test_mcp_connection_error_inheritance(self):
        """Full inheritance chain: MCPConnectionError -> MCPError -> Exception."""
        err = MCPConnectionError()
        assert isinstance(err, MCPConnectionError)
        assert isinstance(err, MCPError)
        assert isinstance(err, Exception)

    def test_mcp_connection_error_raise(self):
        """Raising MCPConnectionError carries the message to the handler."""
        with pytest.raises(MCPConnectionError) as exc_info:
            raise MCPConnectionError("Unable to connect to server")

        assert str(exc_info.value) == "Unable to connect to server"

    def test_mcp_connection_error_catch_as_mcp_error(self):
        """MCPConnectionError is catchable via its MCPError base."""
        with pytest.raises(MCPError) as exc_info:
            raise MCPConnectionError("Connection issue")

        # The caught object keeps its concrete subclass and message.
        assert isinstance(exc_info.value, MCPConnectionError)
        assert str(exc_info.value) == "Connection issue"
|
||||
|
||||
|
||||
class TestMCPAuthError:
    """Tests for the MCPAuthError exception class."""

    def test_mcp_auth_error_creation(self):
        """The message is stored; the class sits at the bottom of the hierarchy."""
        err = MCPAuthError("Authentication failed")
        assert str(err) == "Authentication failed"
        assert isinstance(err, MCPConnectionError)
        assert isinstance(err, MCPError)
        assert isinstance(err, Exception)

    def test_mcp_auth_error_inheritance(self):
        """Full chain: MCPAuthError -> MCPConnectionError -> MCPError -> Exception."""
        err = MCPAuthError()
        assert isinstance(err, MCPAuthError)
        assert isinstance(err, MCPConnectionError)
        assert isinstance(err, MCPError)
        assert isinstance(err, Exception)

    def test_mcp_auth_error_raise(self):
        """Raising MCPAuthError carries the message to the handler."""
        with pytest.raises(MCPAuthError) as exc_info:
            raise MCPAuthError("Invalid credentials")

        assert str(exc_info.value) == "Invalid credentials"

    def test_mcp_auth_error_catch_hierarchy(self):
        """MCPAuthError is catchable at every level of the hierarchy."""
        # Catch at the most specific level.
        with pytest.raises(MCPAuthError) as exc_info:
            raise MCPAuthError("Auth specific error")
        assert str(exc_info.value) == "Auth specific error"

        # Catch via the connection-error base.
        with pytest.raises(MCPConnectionError) as exc_info:
            raise MCPAuthError("Auth connection error")
        assert isinstance(exc_info.value, MCPAuthError)
        assert str(exc_info.value) == "Auth connection error"

        # Catch via the root MCPError base.
        with pytest.raises(MCPError) as exc_info:
            raise MCPAuthError("Auth base error")
        assert isinstance(exc_info.value, MCPAuthError)
        assert str(exc_info.value) == "Auth base error"
|
||||
|
||||
|
||||
class TestErrorHierarchy:
    """Tests covering the complete MCP error hierarchy end to end."""

    def test_exception_hierarchy(self):
        """Instances relate exactly as MCPError <- MCPConnectionError <- MCPAuthError."""
        base_error = MCPError("base")
        connection_error = MCPConnectionError("connection")
        auth_error = MCPAuthError("auth")

        # The base class is never an instance of its subclasses.
        assert not isinstance(base_error, MCPConnectionError)
        assert not isinstance(base_error, MCPAuthError)

        assert isinstance(connection_error, MCPError)
        assert not isinstance(connection_error, MCPAuthError)

        assert isinstance(auth_error, MCPError)
        assert isinstance(auth_error, MCPConnectionError)

    def test_error_handling_patterns(self):
        """Common catch orderings behave as expected."""

        def raise_auth_error():
            raise MCPAuthError("401 Unauthorized")

        def raise_connection_error():
            raise MCPConnectionError("Connection timeout")

        def raise_base_error():
            raise MCPError("Generic error")

        # Pattern 1: most-specific except clauses first — each error lands
        # in its own handler, never in a broader one.
        errors_caught = []
        for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
            try:
                error_func()
            except MCPAuthError:
                errors_caught.append("auth")
            except MCPConnectionError:
                errors_caught.append("connection")
            except MCPError:
                errors_caught.append("base")

        assert errors_caught == ["auth", "connection", "base"]

        # Pattern 2: a single MCPError handler catches all three.
        for error_func in [raise_auth_error, raise_connection_error, raise_base_error]:
            with pytest.raises(MCPError) as exc_info:
                error_func()
            assert isinstance(exc_info.value, MCPError)

    def test_error_with_cause(self):
        """`raise ... from e` preserves the original exception as __cause__."""
        original_error = ValueError("Original error")

        def raise_chained_error():
            try:
                raise original_error
            except ValueError as e:
                raise MCPConnectionError("Connection failed") from e

        with pytest.raises(MCPConnectionError) as exc_info:
            raise_chained_error()

        assert str(exc_info.value) == "Connection failed"
        assert exc_info.value.__cause__ == original_error

    def test_error_comparison(self):
        """Exceptions compare by identity, not by message."""
        error1 = MCPError("Test message")
        error2 = MCPError("Test message")
        error3 = MCPError("Different message")

        # Errors are not equal even with same message (different instances).
        assert error1 != error2
        assert error1 != error3

        # But they share the exact same type.
        # Fixed: use `is` for type-identity comparison instead of `==` (ruff E721);
        # types are singletons, so identity is the correct check.
        assert type(error1) is type(error2) is type(error3)

    def test_error_representation(self):
        """repr shows the class name and the quoted message."""
        base_error = MCPError("Base error message")
        connection_error = MCPConnectionError("Connection error message")
        auth_error = MCPAuthError("Auth error message")

        assert repr(base_error) == "MCPError('Base error message')"
        assert repr(connection_error) == "MCPConnectionError('Connection error message')"
        assert repr(auth_error) == "MCPAuthError('Auth error message')"
|
||||
382
dify/api/tests/unit_tests/core/mcp/test_mcp_client.py
Normal file
382
dify/api/tests/unit_tests/core/mcp/test_mcp_client.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""Unit tests for MCP client."""
|
||||
|
||||
from contextlib import ExitStack
|
||||
from types import TracebackType
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations
|
||||
|
||||
|
||||
class TestMCPClient:
|
||||
"""Test suite for MCPClient."""
|
||||
|
||||
def test_init(self):
    """All constructor arguments are stored; connection state starts empty."""
    client = MCPClient(
        server_url="http://test.example.com/mcp",
        headers={"Authorization": "Bearer test"},
        timeout=30.0,
        sse_read_timeout=60.0,
    )

    assert client.server_url == "http://test.example.com/mcp"
    assert client.headers == {"Authorization": "Bearer test"}
    assert client.timeout == 30.0
    assert client.sse_read_timeout == 60.0
    # No connection yet: session unset, exit stack fresh, not initialized.
    assert client._session is None
    assert isinstance(client._exit_stack, ExitStack)
    assert client._initialized is False
|
||||
|
||||
def test_init_defaults(self):
    """Omitted arguments default to empty headers and no timeouts."""
    client = MCPClient(server_url="http://test.example.com")

    assert client.server_url == "http://test.example.com"
    assert client.headers == {}
    assert client.timeout is None
    assert client.sse_read_timeout is None
|
||||
|
||||
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_mcp_url(self, mock_client_session, mock_streamable_client):
    """A '/mcp' URL connects via the streamable-HTTP transport."""
    # The streamable transport yields (read, write, context).
    read_stream, write_stream, http_ctx = Mock(), Mock(), Mock()
    mock_streamable_client.return_value.__enter__.return_value = (
        read_stream,
        write_stream,
        http_ctx,
    )

    session = Mock()
    mock_client_session.return_value.__enter__.return_value = session

    client = MCPClient(server_url="http://test.example.com/mcp")
    client._initialize()

    # Transport opened with the client's connection settings.
    mock_streamable_client.assert_called_once_with(
        url="http://test.example.com/mcp",
        headers={},
        timeout=None,
        sse_read_timeout=None,
    )

    # Session built on the transport streams and initialized once.
    mock_client_session.assert_called_once_with(read_stream, write_stream)
    session.initialize.assert_called_once()
    assert client._session == session
|
||||
|
||||
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_sse_url(self, mock_client_session, mock_sse_client):
    """A '/sse' URL connects via the SSE transport."""
    # The SSE transport yields only (read, write).
    read_stream, write_stream = Mock(), Mock()
    mock_sse_client.return_value.__enter__.return_value = (read_stream, write_stream)

    session = Mock()
    mock_client_session.return_value.__enter__.return_value = session

    client = MCPClient(server_url="http://test.example.com/sse")
    client._initialize()

    # Transport opened with the client's connection settings.
    mock_sse_client.assert_called_once_with(
        url="http://test.example.com/sse",
        headers={},
        timeout=None,
        sse_read_timeout=None,
    )

    # Session built on the transport streams and initialized once.
    mock_client_session.assert_called_once_with(read_stream, write_stream)
    session.initialize.assert_called_once()
    assert client._session == session
|
||||
|
||||
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_with_unknown_method_fallback_to_sse(
    self, mock_client_session, mock_streamable_client, mock_sse_client
):
    """An unrecognized URL suffix is tried over SSE first."""
    read_stream, write_stream = Mock(), Mock()
    mock_sse_client.return_value.__enter__.return_value = (read_stream, write_stream)

    session = Mock()
    mock_client_session.return_value.__enter__.return_value = session

    client = MCPClient(server_url="http://test.example.com/unknown")
    client._initialize()

    # SSE is attempted; the streamable transport is never touched.
    mock_sse_client.assert_called_once()
    mock_streamable_client.assert_not_called()

    assert client._session == session
|
||||
|
||||
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_initialize_fallback_from_sse_to_mcp(self, mock_client_session, mock_streamable_client, mock_sse_client):
    """If SSE fails with a connection error, the client retries over MCP."""
    # SSE transport raises immediately.
    mock_sse_client.side_effect = MCPConnectionError("SSE connection failed")

    # Streamable-HTTP transport succeeds with its (read, write, context) tuple.
    read_stream, write_stream, http_ctx = Mock(), Mock(), Mock()
    mock_streamable_client.return_value.__enter__.return_value = (
        read_stream,
        write_stream,
        http_ctx,
    )

    session = Mock()
    mock_client_session.return_value.__enter__.return_value = session

    client = MCPClient(server_url="http://test.example.com/unknown")
    client._initialize()

    # Both transports were attempted, in order.
    mock_sse_client.assert_called_once()
    mock_streamable_client.assert_called_once()

    # The session came from the MCP fallback.
    assert client._session == session
|
||||
|
||||
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_connect_server_mcp(self, mock_client_session, mock_streamable_client):
    """connect_server unpacks the 3-tuple returned by the MCP transport."""
    read_stream, write_stream, http_ctx = Mock(), Mock(), Mock()
    mock_streamable_client.return_value.__enter__.return_value = (
        read_stream,
        write_stream,
        http_ctx,
    )

    session = Mock()
    mock_client_session.return_value.__enter__.return_value = session

    client = MCPClient(server_url="http://test.example.com")
    client.connect_server(mock_streamable_client, "mcp")

    # Only the read/write streams reach the session; the context is dropped.
    mock_client_session.assert_called_once_with(read_stream, write_stream)
    session.initialize.assert_called_once()
|
||||
|
||||
@patch("core.mcp.mcp_client.sse_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_connect_server_sse(self, mock_client_session, mock_sse_client):
    """connect_server unpacks the 2-tuple returned by the SSE transport."""
    read_stream, write_stream = Mock(), Mock()
    mock_sse_client.return_value.__enter__.return_value = (read_stream, write_stream)

    session = Mock()
    mock_client_session.return_value.__enter__.return_value = session

    client = MCPClient(server_url="http://test.example.com")
    client.connect_server(mock_sse_client, "sse")

    # Both streams are handed to the session, which is then initialized.
    mock_client_session.assert_called_once_with(read_stream, write_stream)
    session.initialize.assert_called_once()
|
||||
|
||||
def test_context_manager_enter(self):
    """__enter__ initializes the client, marks it initialized, and returns it."""
    client = MCPClient(server_url="http://test.example.com")

    with patch.object(client, "_initialize") as mock_initialize:
        entered = client.__enter__()

    assert entered == client
    assert client._initialized is True
    mock_initialize.assert_called_once()
|
||||
|
||||
def test_context_manager_exit(self):
    """__exit__ delegates to cleanup() regardless of exception info."""
    client = MCPClient(server_url="http://test.example.com")

    with patch.object(client, "cleanup") as mock_cleanup:
        # Simulate a clean exit: no exception in flight.
        exc_type: type[BaseException] | None = None
        exc_val: BaseException | None = None
        exc_tb: TracebackType | None = None
        client.__exit__(exc_type, exc_val, exc_tb)

    mock_cleanup.assert_called_once()
|
||||
|
||||
def test_list_tools_not_initialized(self):
    """Calling list_tools without an active session raises ValueError."""
    client = MCPClient(server_url="http://test.example.com")

    with pytest.raises(ValueError) as exc_info:
        client.list_tools()

    assert "Session not initialized" in str(exc_info.value)
|
||||
|
||||
def test_list_tools_success(self):
    """list_tools returns exactly the tools reported by the session."""
    client = MCPClient(server_url="http://test.example.com")

    # Inject a fake session that reports a single tool.
    tools = [
        Tool(
            name="test-tool",
            description="A test tool",
            inputSchema={"type": "object", "properties": {}},
            annotations=ToolAnnotations(title="Test Tool"),
        )
    ]
    session = Mock()
    session.list_tools.return_value = ListToolsResult(tools=tools)
    client._session = session

    assert client.list_tools() == tools
    session.list_tools.assert_called_once()
|
||||
|
||||
def test_invoke_tool_not_initialized(self):
    """invoke_tool must fail fast when no session has been established."""
    mcp_client = MCPClient(server_url="http://test.example.com")

    with pytest.raises(ValueError) as raised:
        mcp_client.invoke_tool("test-tool", {"arg": "value"})

    # The error must clearly name the missing-session condition.
    assert "Session not initialized" in str(raised.value)
|
||||
|
||||
def test_invoke_tool_success(self):
    """invoke_tool should forward name/arguments and return the session's result."""
    mcp_client = MCPClient(server_url="http://test.example.com")

    # Fake session returning a successful text result.
    call_result = CallToolResult(
        content=[TextContent(type="text", text="Tool executed successfully")],
        isError=False,
    )
    session = Mock()
    session.call_tool.return_value = call_result
    mcp_client._session = session

    returned = mcp_client.invoke_tool("test-tool", {"arg": "value"})

    assert returned == call_result
    session.call_tool.assert_called_once_with("test-tool", {"arg": "value"})
|
||||
|
||||
def test_cleanup(self):
    """cleanup should close the exit stack and reset all client state."""
    mcp_client = MCPClient(server_url="http://test.example.com")
    exit_stack = Mock(spec=ExitStack)
    mcp_client._exit_stack = exit_stack
    mcp_client._session = Mock()
    mcp_client._initialized = True

    mcp_client.cleanup()

    # Stack closed once; session and initialized flag reset.
    exit_stack.close.assert_called_once()
    assert mcp_client._session is None
    assert mcp_client._initialized is False
|
||||
|
||||
def test_cleanup_with_error(self):
    """A failing exit-stack close surfaces as ValueError but still resets state."""
    mcp_client = MCPClient(server_url="http://test.example.com")
    exit_stack = Mock(spec=ExitStack)
    exit_stack.close.side_effect = Exception("Cleanup error")
    mcp_client._exit_stack = exit_stack
    mcp_client._session = Mock()
    mcp_client._initialized = True

    with pytest.raises(ValueError) as raised:
        mcp_client.cleanup()

    # Error is wrapped with context, and state is reset despite the failure.
    assert "Error during cleanup: Cleanup error" in str(raised.value)
    assert mcp_client._session is None
    assert mcp_client._initialized is False
|
||||
|
||||
@patch("core.mcp.mcp_client.streamablehttp_client")
@patch("core.mcp.mcp_client.ClientSession")
def test_full_context_manager_flow(self, mock_client_session, mock_streamable_client):
    """A full enter/use/exit cycle should manage session state end to end."""
    # Transport mock yields (read_stream, write_stream, client_context).
    read_stream, write_stream, client_ctx = Mock(), Mock(), Mock()
    mock_streamable_client.return_value.__enter__.return_value = (
        read_stream,
        write_stream,
        client_ctx,
    )

    session = Mock()
    mock_client_session.return_value.__enter__.return_value = session

    tools = [Tool(name="test-tool", description="Test", inputSchema={})]
    session.list_tools.return_value = ListToolsResult(tools=tools)

    with MCPClient(server_url="http://test.example.com/mcp") as mcp_client:
        # Inside the context the session is live and usable.
        assert mcp_client._initialized is True
        assert mcp_client._session == session
        assert mcp_client.list_tools() == tools

    # Leaving the context tears everything down.
    assert mcp_client._initialized is False
    assert mcp_client._session is None
|
||||
|
||||
def test_headers_passed_to_clients(self):
    """Custom headers and timeouts must reach the underlying HTTP transport."""
    custom_headers = {
        "Authorization": "Bearer test-token",
        "X-Custom-Header": "test-value",
    }

    with (
        patch("core.mcp.mcp_client.streamablehttp_client") as mock_streamable_client,
        patch("core.mcp.mcp_client.ClientSession") as mock_client_session,
    ):
        # Transport mock yields (read_stream, write_stream, client_context).
        mock_streamable_client.return_value.__enter__.return_value = (
            Mock(),  # read stream
            Mock(),  # write stream
            Mock(),  # client context
        )
        mock_client_session.return_value.__enter__.return_value = Mock()

        mcp_client = MCPClient(
            server_url="http://test.example.com/mcp",
            headers=custom_headers,
            timeout=30.0,
            sse_read_timeout=60.0,
        )
        mcp_client._initialize()

        # The streamable HTTP transport must receive the exact settings.
        mock_streamable_client.assert_called_once_with(
            url="http://test.example.com/mcp",
            headers=custom_headers,
            timeout=30.0,
            sse_read_timeout=60.0,
        )
|
||||
492
dify/api/tests/unit_tests/core/mcp/test_types.py
Normal file
492
dify/api/tests/unit_tests/core/mcp/test_types.py
Normal file
@@ -0,0 +1,492 @@
|
||||
"""Unit tests for MCP types module."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.mcp.types import (
|
||||
INTERNAL_ERROR,
|
||||
INVALID_PARAMS,
|
||||
INVALID_REQUEST,
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
METHOD_NOT_FOUND,
|
||||
PARSE_ERROR,
|
||||
SERVER_LATEST_PROTOCOL_VERSION,
|
||||
Annotations,
|
||||
CallToolRequest,
|
||||
CallToolRequestParams,
|
||||
CallToolResult,
|
||||
ClientCapabilities,
|
||||
CompleteRequest,
|
||||
CompleteRequestParams,
|
||||
CompleteResult,
|
||||
Completion,
|
||||
CompletionArgument,
|
||||
CompletionContext,
|
||||
ErrorData,
|
||||
ImageContent,
|
||||
Implementation,
|
||||
InitializeRequest,
|
||||
InitializeRequestParams,
|
||||
InitializeResult,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ListToolsRequest,
|
||||
ListToolsResult,
|
||||
OAuthClientInformation,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
PingRequest,
|
||||
ProgressNotification,
|
||||
ProgressNotificationParams,
|
||||
PromptReference,
|
||||
RequestParams,
|
||||
ResourceTemplateReference,
|
||||
Result,
|
||||
ServerCapabilities,
|
||||
TextContent,
|
||||
Tool,
|
||||
ToolAnnotations,
|
||||
)
|
||||
|
||||
|
||||
class TestConstants:
    """Sanity checks for module-level protocol and error-code constants."""

    def test_protocol_versions(self):
        """Protocol version constants pin the supported spec dates."""
        assert LATEST_PROTOCOL_VERSION == "2025-06-18"
        assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"

    def test_error_codes(self):
        """Error-code constants match the JSON-RPC 2.0 reserved values."""
        expected = [
            (PARSE_ERROR, -32700),
            (INVALID_REQUEST, -32600),
            (METHOD_NOT_FOUND, -32601),
            (INVALID_PARAMS, -32602),
            (INTERNAL_ERROR, -32603),
        ]
        for constant, value in expected:
            assert constant == value
|
||||
|
||||
|
||||
class TestRequestParams:
    """Behavior of RequestParams and its nested Meta model."""

    def test_request_params_basic(self):
        """A bare RequestParams carries no metadata."""
        assert RequestParams().meta is None

    def test_request_params_with_meta(self):
        """Metadata supplied via the _meta alias is exposed as .meta."""
        rp = RequestParams(_meta=RequestParams.Meta(progressToken="test-token"))
        assert rp.meta is not None
        assert rp.meta.progressToken == "test-token"

    def test_request_params_meta_extra_fields(self):
        """RequestParams.Meta tolerates arbitrary extra fields."""
        meta_obj = RequestParams.Meta(progressToken="token", customField="value")
        assert meta_obj.progressToken == "token"
        assert meta_obj.customField == "value"  # type: ignore

    def test_request_params_serialization(self):
        """Dumping with by_alias=True emits metadata under the _meta key."""
        rp = RequestParams(_meta=RequestParams.Meta(progressToken="test"))

        dumped = rp.model_dump(by_alias=True)

        assert "_meta" in dumped
        assert dumped["_meta"] is not None
        assert dumped["_meta"]["progressToken"] == "test"
|
||||
|
||||
|
||||
class TestJSONRPCMessages:
    """Construction and parsing of the JSON-RPC message models."""

    def test_jsonrpc_request(self):
        """Requests carry version, id, method, and params."""
        req = JSONRPCRequest(jsonrpc="2.0", id="test-123", method="test_method", params={"key": "value"})

        assert req.jsonrpc == "2.0"
        assert req.id == "test-123"
        assert req.method == "test_method"
        assert req.params == {"key": "value"}

    def test_jsonrpc_request_numeric_id(self):
        """Numeric request ids are preserved as-is."""
        req = JSONRPCRequest(jsonrpc="2.0", id=123, method="test", params=None)
        assert req.id == 123

    def test_jsonrpc_notification(self):
        """Notifications have a method and params but never an id."""
        note = JSONRPCNotification(jsonrpc="2.0", method="notification_method", params={"data": "test"})

        assert note.jsonrpc == "2.0"
        assert note.method == "notification_method"
        assert not hasattr(note, "id")  # Notifications don't have ID

    def test_jsonrpc_response(self):
        """Responses echo the request id and carry a result payload."""
        resp = JSONRPCResponse(jsonrpc="2.0", id="req-123", result={"success": True})

        assert resp.jsonrpc == "2.0"
        assert resp.id == "req-123"
        assert resp.result == {"success": True}

    def test_jsonrpc_error(self):
        """Error responses wrap an ErrorData payload."""
        payload = ErrorData(code=INVALID_PARAMS, message="Invalid parameters", data={"field": "missing"})
        err = JSONRPCError(jsonrpc="2.0", id="req-123", error=payload)

        assert err.jsonrpc == "2.0"
        assert err.id == "req-123"
        assert err.error.code == INVALID_PARAMS
        assert err.error.message == "Invalid parameters"
        assert err.error.data == {"field": "missing"}

    def test_jsonrpc_message_parsing(self):
        """JSONRPCMessage discriminates request/response/error when parsing."""
        cases = [
            ('{"jsonrpc": "2.0", "id": 1, "method": "test", "params": null}', JSONRPCRequest),
            ('{"jsonrpc": "2.0", "id": 1, "result": {"data": "test"}}', JSONRPCResponse),
            ('{"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "Invalid Request"}}', JSONRPCError),
        ]
        for raw_json, expected_type in cases:
            parsed = JSONRPCMessage.model_validate_json(raw_json)
            assert isinstance(parsed.root, expected_type)
|
||||
|
||||
|
||||
class TestCapabilities:
    """Client and server capability models accept nested capability dicts."""

    def test_client_capabilities(self):
        """ClientCapabilities parses experimental, sampling, and roots data."""
        experimental = {"feature": {"enabled": True}}
        caps = ClientCapabilities(
            experimental=experimental,
            sampling={"model_config": {"extra": "allow"}},
            roots={"listChanged": True},
        )

        assert caps.experimental == experimental
        assert caps.sampling is not None
        assert caps.roots.listChanged is True  # type: ignore

    def test_server_capabilities(self):
        """ServerCapabilities parses nested tool/resource capability flags."""
        caps = ServerCapabilities(
            tools={"listChanged": True},
            resources={"subscribe": True, "listChanged": False},
            prompts={"listChanged": True},
            logging={},
            completions={},
        )

        assert caps.tools.listChanged is True  # type: ignore
        assert caps.resources.subscribe is True  # type: ignore
        assert caps.resources.listChanged is False  # type: ignore
|
||||
|
||||
|
||||
class TestInitialization:
    """Initialization handshake request and result models."""

    def test_initialize_request(self):
        """InitializeRequest fixes its method and carries client metadata."""
        request = InitializeRequest(
            params=InitializeRequestParams(
                protocolVersion=LATEST_PROTOCOL_VERSION,
                capabilities=ClientCapabilities(),
                clientInfo=Implementation(name="test-client", version="1.0.0"),
            )
        )

        assert request.method == "initialize"
        assert request.params.protocolVersion == LATEST_PROTOCOL_VERSION
        assert request.params.clientInfo.name == "test-client"

    def test_initialize_result(self):
        """InitializeResult carries server metadata and optional instructions."""
        result = InitializeResult(
            protocolVersion=LATEST_PROTOCOL_VERSION,
            capabilities=ServerCapabilities(),
            serverInfo=Implementation(name="test-server", version="1.0.0"),
            instructions="Welcome to test server",
        )

        assert result.protocolVersion == LATEST_PROTOCOL_VERSION
        assert result.serverInfo.name == "test-server"
        assert result.instructions == "Welcome to test server"
|
||||
|
||||
|
||||
class TestTools:
    """Tool definitions plus the tools/list and tools/call request flow."""

    def test_tool_creation(self):
        """A fully-populated Tool round-trips all of its fields."""
        annotations = ToolAnnotations(
            title="Test Tool", readOnlyHint=False, destructiveHint=False, idempotentHint=True
        )
        tool = Tool(
            name="test_tool",
            title="Test Tool",
            description="A tool for testing",
            inputSchema={"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]},
            outputSchema={"type": "object", "properties": {"result": {"type": "string"}}},
            annotations=annotations,
        )

        assert tool.name == "test_tool"
        assert tool.title == "Test Tool"
        assert tool.description == "A tool for testing"
        assert tool.inputSchema["properties"]["input"]["type"] == "string"
        assert tool.annotations.idempotentHint is True

    def test_call_tool_request(self):
        """CallToolRequest fixes its method and forwards name/arguments."""
        request = CallToolRequest(
            params=CallToolRequestParams(name="test_tool", arguments={"input": "test value"})
        )

        assert request.method == "tools/call"
        assert request.params.name == "test_tool"
        assert request.params.arguments == {"input": "test value"}

    def test_call_tool_result(self):
        """CallToolResult keeps content, structured content, and error flag."""
        outcome = CallToolResult(
            content=[TextContent(type="text", text="Tool executed successfully")],
            structuredContent={"status": "success", "data": "test"},
            isError=False,
        )

        assert len(outcome.content) == 1
        assert outcome.content[0].text == "Tool executed successfully"  # type: ignore
        assert outcome.structuredContent == {"status": "success", "data": "test"}
        assert outcome.isError is False

    def test_list_tools_request(self):
        """ListToolsRequest fixes its method name."""
        assert ListToolsRequest().method == "tools/list"

    def test_list_tools_result(self):
        """ListToolsResult preserves tool order and count."""
        tools = [Tool(name="tool1", inputSchema={}), Tool(name="tool2", inputSchema={})]

        result = ListToolsResult(tools=tools)

        assert len(result.tools) == 2
        assert result.tools[0].name == "tool1"
        assert result.tools[1].name == "tool2"
|
||||
|
||||
|
||||
class TestContent:
    """Text and image content payload models."""

    def test_text_content(self):
        """TextContent keeps its text and optional annotations."""
        text_item = TextContent(
            type="text",
            text="Hello, world!",
            annotations=Annotations(audience=["user"], priority=0.8),
        )

        assert text_item.type == "text"
        assert text_item.text == "Hello, world!"
        assert text_item.annotations is not None
        assert text_item.annotations.priority == 0.8

    def test_image_content(self):
        """ImageContent stores a base64 payload together with its MIME type."""
        image_item = ImageContent(type="image", data="base64encodeddata", mimeType="image/png")

        assert image_item.type == "image"
        assert image_item.data == "base64encodeddata"
        assert image_item.mimeType == "image/png"
|
||||
|
||||
|
||||
class TestOAuth:
    """OAuth client metadata, registration info, tokens, and server metadata."""

    def test_oauth_client_metadata(self):
        """OAuthClientMetadata accepts the full dynamic-registration payload."""
        client_meta = OAuthClientMetadata(
            client_name="Test Client",
            redirect_uris=["https://example.com/callback"],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="none",
            client_uri="https://example.com",
            scope="read write",
        )

        assert client_meta.client_name == "Test Client"
        assert len(client_meta.redirect_uris) == 1
        assert "authorization_code" in client_meta.grant_types

    def test_oauth_client_information(self):
        """OAuthClientInformation keeps both client id and secret."""
        confidential = OAuthClientInformation(client_id="test-client-id", client_secret="test-secret")

        assert confidential.client_id == "test-client-id"
        assert confidential.client_secret == "test-secret"

    def test_oauth_client_information_without_secret(self):
        """Public clients may omit the secret entirely."""
        public = OAuthClientInformation(client_id="public-client")

        assert public.client_id == "public-client"
        assert public.client_secret is None

    def test_oauth_tokens(self):
        """OAuthTokens stores the full token response."""
        tokens = OAuthTokens(
            access_token="access-token-123",
            token_type="Bearer",
            expires_in=3600,
            refresh_token="refresh-token-456",
            scope="read write",
        )

        assert tokens.access_token == "access-token-123"
        assert tokens.token_type == "Bearer"
        assert tokens.expires_in == 3600
        assert tokens.refresh_token == "refresh-token-456"
        assert tokens.scope == "read write"

    def test_oauth_metadata(self):
        """OAuthMetadata captures authorization-server discovery fields."""
        server_meta = OAuthMetadata(
            authorization_endpoint="https://auth.example.com/authorize",
            token_endpoint="https://auth.example.com/token",
            registration_endpoint="https://auth.example.com/register",
            response_types_supported=["code", "token"],
            grant_types_supported=["authorization_code", "refresh_token"],
            code_challenge_methods_supported=["plain", "S256"],
        )

        assert server_meta.authorization_endpoint == "https://auth.example.com/authorize"
        assert "code" in server_meta.response_types_supported
        assert "S256" in server_meta.code_challenge_methods_supported
|
||||
|
||||
|
||||
class TestNotifications:
    """Progress notifications and the ping request."""

    def test_progress_notification(self):
        """ProgressNotification fixes its method and relays progress fields."""
        notification = ProgressNotification(
            params=ProgressNotificationParams(
                progressToken="progress-123", progress=50.0, total=100.0, message="Processing... 50%"
            )
        )

        assert notification.method == "notifications/progress"
        assert notification.params.progressToken == "progress-123"
        assert notification.params.progress == 50.0
        assert notification.params.total == 100.0
        assert notification.params.message == "Processing... 50%"

    def test_ping_request(self):
        """PingRequest fixes its method and carries no params."""
        ping = PingRequest()

        assert ping.method == "ping"
        assert ping.params is None
|
||||
|
||||
|
||||
class TestCompletion:
    """Completion references, requests, and results."""

    def test_completion_context(self):
        """CompletionContext preserves the supplied template arguments."""
        ctx = CompletionContext(arguments={"template_var": "value"})
        assert ctx.arguments == {"template_var": "value"}

    def test_resource_template_reference(self):
        """ResourceTemplateReference keeps its type tag and URI."""
        resource_ref = ResourceTemplateReference(type="ref/resource", uri="file:///path/to/(unknown)")

        assert resource_ref.type == "ref/resource"
        assert resource_ref.uri == "file:///path/to/(unknown)"

    def test_prompt_reference(self):
        """PromptReference keeps its type tag and prompt name."""
        prompt_ref = PromptReference(type="ref/prompt", name="test_prompt")

        assert prompt_ref.type == "ref/prompt"
        assert prompt_ref.name == "test_prompt"

    def test_complete_request(self):
        """CompleteRequest fixes its method and forwards ref/argument/context."""
        request = CompleteRequest(
            params=CompleteRequestParams(
                ref=PromptReference(type="ref/prompt", name="test_prompt"),
                argument=CompletionArgument(name="arg1", value="val"),
                context=CompletionContext(arguments={"key": "value"}),
            )
        )

        assert request.method == "completion/complete"
        assert request.params.ref.name == "test_prompt"  # type: ignore
        assert request.params.argument.name == "arg1"

    def test_complete_result(self):
        """CompleteResult wraps a Completion with values, total, and hasMore."""
        result = CompleteResult(
            completion=Completion(values=["option1", "option2", "option3"], total=10, hasMore=True)
        )

        assert len(result.completion.values) == 3
        assert result.completion.total == 10
        assert result.completion.hasMore is True
|
||||
|
||||
|
||||
class TestValidation:
    """Model validation rules: version pinning, extra fields, aliases."""

    def test_invalid_jsonrpc_version(self):
        """Anything other than jsonrpc "2.0" is rejected."""
        with pytest.raises(ValidationError):
            JSONRPCRequest(
                jsonrpc="1.0",  # Invalid version
                id=1,
                method="test",
            )

    def test_tool_annotations_validation(self):
        """A fully-specified ToolAnnotations validates cleanly."""
        hints = ToolAnnotations(
            title="Test", readOnlyHint=True, destructiveHint=False, idempotentHint=True, openWorldHint=False
        )
        assert hints.title == "Test"

    def test_extra_fields_allowed(self):
        """Models tolerate (and expose) unknown extra fields."""
        tool = Tool(
            name="test",
            inputSchema={},
            customField="allowed",  # type: ignore
        )
        assert tool.customField == "allowed"  # type: ignore

    def test_result_meta_alias(self):
        """Result exposes _meta input as .meta and dumps it back under _meta."""
        result = Result(_meta={"key": "value"})

        assert result.meta == {"key": "value"}

        serialized = result.model_dump(by_alias=True)
        assert "_meta" in serialized
        assert serialized["_meta"] == {"key": "value"}
|
||||
355
dify/api/tests/unit_tests/core/mcp/test_utils.py
Normal file
355
dify/api/tests/unit_tests/core/mcp/test_utils.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""Unit tests for MCP utils module."""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import httpx
|
||||
import httpx_sse
|
||||
import pytest
|
||||
|
||||
from core.mcp.utils import (
|
||||
STATUS_FORCELIST,
|
||||
create_mcp_error_response,
|
||||
create_ssrf_proxy_mcp_http_client,
|
||||
ssrf_proxy_sse_connect,
|
||||
)
|
||||
|
||||
|
||||
class TestConstants:
    """Module constant sanity checks."""

    def test_status_forcelist(self):
        """STATUS_FORCELIST lists exactly the retryable HTTP status codes."""
        retryable = [
            429,  # Too Many Requests
            500,  # Internal Server Error
            502,  # Bad Gateway
            503,  # Service Unavailable
            504,  # Gateway Timeout
        ]
        assert STATUS_FORCELIST == retryable
        for status in retryable:
            assert status in STATUS_FORCELIST
|
||||
|
||||
|
||||
class TestCreateSSRFProxyMCPHTTPClient:
    """Client construction under the various SSRF proxy configurations."""

    @patch("core.mcp.utils.dify_config")
    def test_create_client_with_all_url_proxy(self, mock_config):
        """A single catch-all proxy URL still yields a usable client."""
        mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
        mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True

        http_client = create_ssrf_proxy_mcp_http_client(
            headers={"Authorization": "Bearer token"}, timeout=httpx.Timeout(30.0)
        )
        try:
            assert isinstance(http_client, httpx.Client)
            assert http_client.headers["Authorization"] == "Bearer token"
            assert http_client.timeout.connect == 30.0
            assert http_client.follow_redirects is True
        finally:
            http_client.close()

    @patch("core.mcp.utils.dify_config")
    def test_create_client_with_http_https_proxies(self, mock_config):
        """Separate HTTP/HTTPS proxy URLs are accepted."""
        mock_config.SSRF_PROXY_ALL_URL = None
        mock_config.SSRF_PROXY_HTTP_URL = "http://http-proxy.example.com:8080"
        mock_config.SSRF_PROXY_HTTPS_URL = "http://https-proxy.example.com:8443"
        mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = False

        http_client = create_ssrf_proxy_mcp_http_client()
        try:
            assert isinstance(http_client, httpx.Client)
            assert http_client.follow_redirects is True
        finally:
            http_client.close()

    @patch("core.mcp.utils.dify_config")
    def test_create_client_without_proxy(self, mock_config):
        """With no proxy configured, headers and timeouts still apply."""
        mock_config.SSRF_PROXY_ALL_URL = None
        mock_config.SSRF_PROXY_HTTP_URL = None
        mock_config.SSRF_PROXY_HTTPS_URL = None
        mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True

        http_client = create_ssrf_proxy_mcp_http_client(
            headers={"X-Custom-Header": "value"},
            timeout=httpx.Timeout(timeout=30.0, connect=5.0, read=10.0, write=30.0),
        )
        try:
            assert isinstance(http_client, httpx.Client)
            assert http_client.headers["X-Custom-Header"] == "value"
            assert http_client.timeout.connect == 5.0
            assert http_client.timeout.read == 10.0
            assert http_client.follow_redirects is True
        finally:
            http_client.close()

    @patch("core.mcp.utils.dify_config")
    def test_create_client_default_params(self, mock_config):
        """All-default construction produces a sane httpx.Client."""
        mock_config.SSRF_PROXY_ALL_URL = None
        mock_config.SSRF_PROXY_HTTP_URL = None
        mock_config.SSRF_PROXY_HTTPS_URL = None
        mock_config.HTTP_REQUEST_NODE_SSL_VERIFY = True

        http_client = create_ssrf_proxy_mcp_http_client()
        try:
            assert isinstance(http_client, httpx.Client)
            # httpx always populates default headers; just check the container type.
            assert isinstance(http_client.headers, httpx.Headers)
            # Without an explicit timeout, httpx supplies its own default.
            assert http_client.timeout is not None
        finally:
            http_client.close()
|
||||
|
||||
|
||||
class TestSSRFProxySSEConnect:
    """SSE connection helper: client reuse, creation, timeouts, and cleanup."""

    @staticmethod
    def _sse_context():
        """Build a context-manager mock whose __enter__ yields an EventSource mock."""
        ctx = MagicMock()
        ctx.__enter__.return_value = Mock(spec=httpx_sse.EventSource)
        return ctx

    @patch("core.mcp.utils.connect_sse")
    @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
    def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse):
        """A caller-supplied client is used directly and never recreated."""
        provided_client = Mock(spec=httpx.Client)
        sse_ctx = self._sse_context()
        mock_connect_sse.return_value = sse_ctx

        result = ssrf_proxy_sse_connect(
            "http://example.com/sse", client=provided_client, method="POST", headers={"Authorization": "Bearer token"}
        )

        mock_create_client.assert_not_called()
        mock_connect_sse.assert_called_once_with(
            provided_client, "POST", "http://example.com/sse", headers={"Authorization": "Bearer token"}
        )
        assert result == sse_ctx

    @patch("core.mcp.utils.connect_sse")
    @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
    @patch("core.mcp.utils.dify_config")
    def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
        """Without a client, one is built from the configured default timeouts."""
        mock_config.SSRF_DEFAULT_TIME_OUT = 30.0
        mock_config.SSRF_DEFAULT_CONNECT_TIME_OUT = 10.0
        mock_config.SSRF_DEFAULT_READ_TIME_OUT = 60.0
        mock_config.SSRF_DEFAULT_WRITE_TIME_OUT = 30.0

        created_client = Mock(spec=httpx.Client)
        mock_create_client.return_value = created_client
        sse_ctx = self._sse_context()
        mock_connect_sse.return_value = sse_ctx

        result = ssrf_proxy_sse_connect("http://example.com/sse", headers={"X-Custom": "value"})

        # The helper must construct a client carrying our headers...
        mock_create_client.assert_called_once()
        creation_kwargs = mock_create_client.call_args[1]
        assert creation_kwargs["headers"] == {"X-Custom": "value"}

        # ...and a timeout built from the configured defaults.
        built_timeout = creation_kwargs["timeout"]
        assert isinstance(built_timeout, httpx.Timeout)
        assert built_timeout.connect == 10.0
        assert built_timeout.read == 60.0
        assert built_timeout.write == 30.0

        mock_connect_sse.assert_called_once_with(
            created_client,
            "GET",  # Default method
            "http://example.com/sse",
        )
        assert result == sse_ctx

    @patch("core.mcp.utils.connect_sse")
    @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
    def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse):
        """An explicit timeout is forwarded untouched to client creation."""
        mock_create_client.return_value = Mock(spec=httpx.Client)
        sse_ctx = self._sse_context()
        mock_connect_sse.return_value = sse_ctx

        explicit_timeout = httpx.Timeout(timeout=60.0, read=120.0)

        result = ssrf_proxy_sse_connect("http://example.com/sse", timeout=explicit_timeout)

        mock_create_client.assert_called_once()
        assert mock_create_client.call_args[1]["timeout"] == explicit_timeout
        assert result == sse_ctx

    @patch("core.mcp.utils.connect_sse")
    @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
    def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse):
        """An internally-created client is closed when the connection fails."""
        owned_client = Mock(spec=httpx.Client)
        mock_create_client.return_value = owned_client
        mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")

        with pytest.raises(httpx.ConnectError):
            ssrf_proxy_sse_connect("http://example.com/sse")

        # The helper owned the client, so it must close it on failure.
        owned_client.close.assert_called_once()

    @patch("core.mcp.utils.connect_sse")
    def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse):
        """A caller-supplied client is left open even when the connection fails."""
        provided_client = Mock(spec=httpx.Client)
        mock_connect_sse.side_effect = httpx.ConnectError("Connection failed")

        with pytest.raises(httpx.ConnectError):
            ssrf_proxy_sse_connect("http://example.com/sse", client=provided_client)

        # Caller owns the client, so the helper must not close it.
        provided_client.close.assert_not_called()
|
||||
|
||||
|
||||
class TestCreateMCPErrorResponse:
    """Test create_mcp_error_response function."""

    @staticmethod
    def _read_json(gen):
        """Consume one chunk from the generator and parse it as JSON-RPC."""
        payload = next(gen)
        assert isinstance(payload, bytes)
        return json.loads(payload.decode("utf-8"))

    def test_create_error_response_basic(self):
        """A plain error yields exactly one JSON-RPC 2.0 error object."""
        gen = create_mcp_error_response(request_id="req-123", code=-32600, message="Invalid Request")

        # The function returns a generator of byte chunks.
        assert isinstance(gen, Generator)

        body = self._read_json(gen)
        assert body["jsonrpc"] == "2.0"
        assert body["id"] == "req-123"
        assert body["error"]["code"] == -32600
        assert body["error"]["message"] == "Invalid Request"
        assert body["error"]["data"] is None

        # Nothing further should be produced.
        with pytest.raises(StopIteration):
            next(gen)

    def test_create_error_response_with_data(self):
        """Extra error data and numeric request IDs are passed through."""
        detail = {"field": "username", "reason": "required"}

        gen = create_mcp_error_response(
            request_id=456,  # numeric IDs are allowed
            code=-32602,
            message="Invalid params",
            data=detail,
        )

        body = self._read_json(gen)
        assert body["id"] == 456
        assert body["error"]["code"] == -32602
        assert body["error"]["message"] == "Invalid params"
        assert body["error"]["data"] == detail

    def test_create_error_response_without_request_id(self):
        """A missing request ID falls back to the default of 1."""
        gen = create_mcp_error_response(request_id=None, code=-32700, message="Parse error")

        body = self._read_json(gen)
        assert body["id"] == 1
        assert body["error"]["code"] == -32700
        assert body["error"]["message"] == "Parse error"

    def test_create_error_response_with_complex_data(self):
        """Nested error payloads survive serialization unchanged."""
        nested = {
            "errors": [{"field": "name", "message": "Too short"}, {"field": "email", "message": "Invalid format"}],
            "timestamp": "2024-01-01T00:00:00Z",
        }

        gen = create_mcp_error_response(
            request_id="complex-req", code=-32602, message="Validation failed", data=nested
        )

        body = self._read_json(gen)
        assert body["error"]["data"] == nested
        assert len(body["error"]["data"]["errors"]) == 2

    def test_create_error_response_encoding(self):
        """Non-ASCII message/data must round-trip through UTF-8."""
        gen = create_mcp_error_response(
            request_id="unicode-req",
            code=-32603,
            message="内部错误",  # Chinese characters
            data={"details": "エラー詳細"},  # Japanese characters
        )

        body = self._read_json(gen)
        assert body["error"]["message"] == "内部错误"
        assert body["error"]["data"]["details"] == "エラー詳細"

    def test_create_error_response_yields_once(self):
        """The generator produces exactly one chunk, then stays exhausted."""
        gen = create_mcp_error_response(request_id="test", code=-32600, message="Test")

        # First draw succeeds.
        assert isinstance(next(gen), bytes)

        # Every subsequent draw raises StopIteration.
        for _ in range(2):
            with pytest.raises(StopIteration):
                next(gen)
|
||||
@@ -0,0 +1,99 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
|
||||
|
||||
# Shorthand for the nested tool-call chunk type used throughout these fixtures.
ToolCall = AssistantPromptMessage.ToolCall


def _chunk(call_id: str, name: str, arguments: str) -> ToolCall:
    """Build one streamed tool-call delta chunk."""
    return ToolCall(id=call_id, type="function", function=ToolCall.ToolCallFunction(name=name, arguments=arguments))


# CASE 1: Single tool call
INPUTS_CASE_1 = [
    _chunk("1", "func_foo", ""),
    _chunk("", "", '{"arg1": '),
    _chunk("", "", '"value"}'),
]
EXPECTED_CASE_1 = [
    _chunk("1", "func_foo", '{"arg1": "value"}'),
]

# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
INPUTS_CASE_2 = [
    _chunk("1", "func_foo", ""),
    _chunk("", "", '{"arg1": '),
    _chunk("", "", '"value"}'),
    _chunk("2", "func_bar", ""),
    _chunk("", "", '{"arg2": '),
    _chunk("", "", '"value"}'),
]
EXPECTED_CASE_2 = [
    _chunk("1", "func_foo", '{"arg1": "value"}'),
    _chunk("2", "func_bar", '{"arg2": "value"}'),
]

# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
INPUTS_CASE_3 = [
    _chunk("1", "func_foo", ""),
    _chunk("1", "", '{"arg1": '),
    _chunk("1", "", '"value"}'),
    _chunk("2", "func_bar", ""),
    _chunk("2", "", '{"arg2": '),
    _chunk("2", "", '"value"}'),
]
EXPECTED_CASE_3 = [
    _chunk("1", "func_foo", '{"arg1": "value"}'),
    _chunk("2", "func_bar", '{"arg2": "value"}'),
]

# CASE 4: Tool call sequences with no IDs
INPUTS_CASE_4 = [
    _chunk("", "func_foo", ""),
    _chunk("", "", '{"arg1": '),
    _chunk("", "", '"value"}'),
    _chunk("", "func_bar", ""),
    _chunk("", "", '{"arg2": '),
    _chunk("", "", '"value"}'),
]
EXPECTED_CASE_4 = [
    _chunk("RANDOM_ID_1", "func_foo", '{"arg1": "value"}'),
    _chunk("RANDOM_ID_2", "func_bar", '{"arg2": "value"}'),
]
|
||||
|
||||
|
||||
def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
    """Merge *inputs* chunk-by-chunk into an accumulator and compare against *expected*."""
    merged: list[ToolCall] = []
    _increase_tool_call(inputs, merged)
    assert merged == expected
|
||||
|
||||
|
||||
def test__increase_tool_call():
    """Exercise chunk merging across the four streaming ID conventions."""
    # Cases 1-3: IDs present (single call, first-chunk-anchored, every-chunk-anchored).
    for inputs, expected in (
        (INPUTS_CASE_1, EXPECTED_CASE_1),
        (INPUTS_CASE_2, EXPECTED_CASE_2),
        (INPUTS_CASE_3, EXPECTED_CASE_3),
    ):
        _run_case(inputs, expected)

    # Case 4: no IDs at all -- pin the ID generator so the output IDs are predictable.
    mock_id_generator = MagicMock()
    mock_id_generator.side_effect = [_exp.id for _exp in EXPECTED_CASE_4]
    with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
        _run_case(INPUTS_CASE_4, EXPECTED_CASE_4)
|
||||
@@ -0,0 +1,148 @@
|
||||
"""Tests for LLMUsage entity."""
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
|
||||
|
||||
class TestLLMUsage:
    """Test cases for LLMUsage class."""

    def test_from_metadata_with_all_tokens(self):
        """Every supported metadata field is carried over to the usage object."""
        meta: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            "prompt_unit_price": 0.001,
            "completion_unit_price": 0.002,
            "total_price": 0.2,
            "currency": "USD",
            "latency": 1.5,
        }

        parsed = LLMUsage.from_metadata(meta)

        assert parsed.prompt_tokens == 100
        assert parsed.completion_tokens == 50
        assert parsed.total_tokens == 150
        # Prices are normalized to Decimal regardless of the input type.
        assert parsed.prompt_unit_price == Decimal("0.001")
        assert parsed.completion_unit_price == Decimal("0.002")
        assert parsed.total_price == Decimal("0.2")
        assert parsed.currency == "USD"
        assert parsed.latency == 1.5

    def test_from_metadata_with_prompt_tokens_only(self):
        """A missing completion_tokens defaults to zero."""
        meta: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "total_tokens": 100,
        }

        parsed = LLMUsage.from_metadata(meta)

        assert parsed.prompt_tokens == 100
        assert parsed.completion_tokens == 0
        assert parsed.total_tokens == 100

    def test_from_metadata_with_completion_tokens_only(self):
        """A missing prompt_tokens defaults to zero."""
        meta: LLMUsageMetadata = {
            "completion_tokens": 50,
            "total_tokens": 50,
        }

        parsed = LLMUsage.from_metadata(meta)

        assert parsed.prompt_tokens == 0
        assert parsed.completion_tokens == 50
        assert parsed.total_tokens == 50

    def test_from_metadata_calculates_total_when_missing(self):
        """total_tokens is derived from prompt + completion when absent."""
        meta: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
        }

        parsed = LLMUsage.from_metadata(meta)

        assert parsed.prompt_tokens == 100
        assert parsed.completion_tokens == 50
        assert parsed.total_tokens == 150  # derived, not supplied

    def test_from_metadata_with_total_but_no_completion(self):
        """
        With total_tokens present and completion_tokens zero, prompt tokens
        must stay where they are (regression test for issue #24360).
        """
        meta: LLMUsageMetadata = {
            "prompt_tokens": 479,
            "completion_tokens": 0,
            "total_tokens": 521,
        }

        parsed = LLMUsage.from_metadata(meta)

        # The key fix: prompt tokens are never reassigned to completion_tokens.
        assert parsed.prompt_tokens == 479
        assert parsed.completion_tokens == 0
        assert parsed.total_tokens == 521

    def test_from_metadata_with_empty_metadata(self):
        """An empty mapping yields an all-zero usage with USD currency."""
        meta: LLMUsageMetadata = {}

        parsed = LLMUsage.from_metadata(meta)

        assert parsed.prompt_tokens == 0
        assert parsed.completion_tokens == 0
        assert parsed.total_tokens == 0
        assert parsed.currency == "USD"
        assert parsed.latency == 0.0

    def test_from_metadata_preserves_zero_completion_tokens(self):
        """
        An explicitly zero completion_tokens must survive conversion --
        this matters for agent nodes that consume prompt tokens only.
        """
        meta: LLMUsageMetadata = {
            "prompt_tokens": 1000,
            "completion_tokens": 0,
            "total_tokens": 1000,
            "prompt_unit_price": 0.15,
            "completion_unit_price": 0.60,
            "prompt_price": 0.00015,
            "completion_price": 0,
            "total_price": 0.00015,
        }

        parsed = LLMUsage.from_metadata(meta)

        assert parsed.prompt_tokens == 1000
        assert parsed.completion_tokens == 0
        assert parsed.total_tokens == 1000
        assert parsed.prompt_price == Decimal("0.00015")
        assert parsed.completion_price == Decimal(0)
        assert parsed.total_price == Decimal("0.00015")

    def test_from_metadata_with_decimal_values(self):
        """String-typed price fields are parsed into exact Decimals."""
        meta: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            "prompt_unit_price": "0.001",
            "completion_unit_price": "0.002",
            "prompt_price": "0.1",
            "completion_price": "0.1",
            "total_price": "0.2",
        }

        parsed = LLMUsage.from_metadata(meta)

        assert parsed.prompt_unit_price == Decimal("0.001")
        assert parsed.completion_unit_price == Decimal("0.002")
        assert parsed.prompt_price == Decimal("0.1")
        assert parsed.completion_price == Decimal("0.1")
        assert parsed.total_price == Decimal("0.2")
|
||||
1
dify/api/tests/unit_tests/core/ops/__init__.py
Normal file
1
dify/api/tests/unit_tests/core/ops/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Unit tests for core ops module
|
||||
416
dify/api/tests/unit_tests/core/ops/test_config_entity.py
Normal file
416
dify/api/tests/unit_tests/core/ops/test_config_entity.py
Normal file
@@ -0,0 +1,416 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.ops.entities.config_entity import (
|
||||
AliyunConfig,
|
||||
ArizeConfig,
|
||||
LangfuseConfig,
|
||||
LangSmithConfig,
|
||||
OpikConfig,
|
||||
PhoenixConfig,
|
||||
TracingProviderEnum,
|
||||
WeaveConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestTracingProviderEnum:
    """Test cases for TracingProviderEnum"""

    def test_enum_values(self):
        """Every supported tracing provider maps to its expected string value."""
        expected_values = {
            TracingProviderEnum.ARIZE: "arize",
            TracingProviderEnum.PHOENIX: "phoenix",
            TracingProviderEnum.LANGFUSE: "langfuse",
            TracingProviderEnum.LANGSMITH: "langsmith",
            TracingProviderEnum.OPIK: "opik",
            TracingProviderEnum.WEAVE: "weave",
            TracingProviderEnum.ALIYUN: "aliyun",
        }
        for member, value in expected_values.items():
            assert member == value
|
||||
|
||||
|
||||
class TestArizeConfig:
    """Test cases for ArizeConfig"""

    def test_valid_config(self):
        """All explicitly supplied fields are stored as-is."""
        cfg = ArizeConfig(
            api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
        )
        assert cfg.api_key == "test_key"
        assert cfg.space_id == "test_space"
        assert cfg.project == "test_project"
        assert cfg.endpoint == "https://custom.arize.com"

    def test_default_values(self):
        """An empty constructor falls back to documented defaults."""
        cfg = ArizeConfig()
        assert cfg.api_key is None
        assert cfg.space_id is None
        assert cfg.project is None
        assert cfg.endpoint == "https://otlp.arize.com"

    def test_project_validation_empty(self):
        """An empty project name is replaced with the "default" project."""
        assert ArizeConfig(project="").project == "default"

    def test_project_validation_none(self):
        """A None project is replaced with the "default" project."""
        assert ArizeConfig(project=None).project == "default"

    def test_endpoint_validation_empty(self):
        """An empty endpoint falls back to the default OTLP endpoint."""
        assert ArizeConfig(endpoint="").endpoint == "https://otlp.arize.com"

    def test_endpoint_validation_with_path(self):
        """Endpoint normalization strips any URL path component."""
        cfg = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
        assert cfg.endpoint == "https://custom.arize.com"

    def test_endpoint_validation_invalid_scheme(self):
        """Schemes other than http(s) are rejected."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            ArizeConfig(endpoint="ftp://invalid.com")

    def test_endpoint_validation_no_scheme(self):
        """Scheme-less URLs are rejected."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            ArizeConfig(endpoint="invalid.com")
|
||||
|
||||
|
||||
class TestPhoenixConfig:
    """Test cases for PhoenixConfig"""

    def test_valid_config(self):
        """All explicitly supplied fields are stored as-is."""
        cfg = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
        assert cfg.api_key == "test_key"
        assert cfg.project == "test_project"
        assert cfg.endpoint == "https://custom.phoenix.com"

    def test_default_values(self):
        """An empty constructor falls back to documented defaults."""
        cfg = PhoenixConfig()
        assert cfg.api_key is None
        assert cfg.project is None
        assert cfg.endpoint == "https://app.phoenix.arize.com"

    def test_project_validation_empty(self):
        """An empty project name is replaced with the "default" project."""
        assert PhoenixConfig(project="").project == "default"

    def test_endpoint_validation_with_path(self):
        """Phoenix endpoints preserve their URL path component."""
        cfg = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
        assert cfg.endpoint == "https://app.phoenix.arize.com/s/dify-integration"

    def test_endpoint_validation_without_path(self):
        """Path-less endpoints are kept unchanged."""
        cfg = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
        assert cfg.endpoint == "https://app.phoenix.arize.com"
|
||||
|
||||
|
||||
class TestLangfuseConfig:
    """Test cases for LangfuseConfig"""

    def test_valid_config(self):
        """All explicitly supplied fields are stored as-is."""
        cfg = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
        assert cfg.public_key == "public_key"
        assert cfg.secret_key == "secret_key"
        assert cfg.host == "https://custom.langfuse.com"

    def test_valid_config_with_path(self):
        """A host URL with a path component is preserved untouched."""
        host = "https://custom.langfuse.com/api/v1"
        cfg = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
        assert cfg.public_key == "public_key"
        assert cfg.secret_key == "secret_key"
        assert cfg.host == host

    def test_default_values(self):
        """Omitting the host selects the hosted Langfuse API."""
        cfg = LangfuseConfig(public_key="public", secret_key="secret")
        assert cfg.host == "https://api.langfuse.com"

    def test_missing_required_fields(self):
        """Both key fields are mandatory, alone or together."""
        for kwargs in ({}, {"public_key": "public"}, {"secret_key": "secret"}):
            with pytest.raises(ValidationError):
                LangfuseConfig(**kwargs)

    def test_host_validation_empty(self):
        """An empty host falls back to the hosted Langfuse API."""
        cfg = LangfuseConfig(public_key="public", secret_key="secret", host="")
        assert cfg.host == "https://api.langfuse.com"
|
||||
|
||||
|
||||
class TestLangSmithConfig:
    """Test cases for LangSmithConfig"""

    def test_valid_config(self):
        """All explicitly supplied fields are stored as-is."""
        cfg = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
        assert cfg.api_key == "test_key"
        assert cfg.project == "test_project"
        assert cfg.endpoint == "https://custom.smith.com"

    def test_default_values(self):
        """Omitting the endpoint selects the hosted LangSmith API."""
        cfg = LangSmithConfig(api_key="key", project="project")
        assert cfg.endpoint == "https://api.smith.langchain.com"

    def test_missing_required_fields(self):
        """api_key and project are both mandatory, alone or together."""
        for kwargs in ({}, {"api_key": "key"}, {"project": "project"}):
            with pytest.raises(ValidationError):
                LangSmithConfig(**kwargs)

    def test_endpoint_validation_https_only(self):
        """Plain-HTTP endpoints are rejected; only HTTPS is allowed."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
|
||||
|
||||
|
||||
class TestOpikConfig:
    """Test cases for OpikConfig"""

    def test_valid_config(self):
        """All explicitly supplied fields are stored as-is."""
        cfg = OpikConfig(
            api_key="test_key",
            project="test_project",
            workspace="test_workspace",
            url="https://custom.comet.com/opik/api/",
        )
        assert cfg.api_key == "test_key"
        assert cfg.project == "test_project"
        assert cfg.workspace == "test_workspace"
        assert cfg.url == "https://custom.comet.com/opik/api/"

    def test_default_values(self):
        """An empty constructor falls back to documented defaults."""
        cfg = OpikConfig()
        assert cfg.api_key is None
        assert cfg.project is None
        assert cfg.workspace is None
        assert cfg.url == "https://www.comet.com/opik/api/"

    def test_project_validation_empty(self):
        """An empty project name is replaced with "Default Project"."""
        assert OpikConfig(project="").project == "Default Project"

    def test_url_validation_empty(self):
        """An empty URL falls back to the hosted Comet endpoint."""
        assert OpikConfig(url="").url == "https://www.comet.com/opik/api/"

    def test_url_validation_missing_suffix(self):
        """URLs must end with the /api/ suffix."""
        with pytest.raises(ValidationError, match="URL should end with /api/"):
            OpikConfig(url="https://custom.comet.com/opik/")

    def test_url_validation_invalid_scheme(self):
        """Only http(s) schemes are accepted."""
        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
            OpikConfig(url="ftp://custom.comet.com/opik/api/")
|
||||
|
||||
|
||||
class TestWeaveConfig:
    """Test cases for WeaveConfig"""

    def test_valid_config(self):
        """All explicitly supplied fields are stored as-is."""
        cfg = WeaveConfig(
            api_key="test_key",
            entity="test_entity",
            project="test_project",
            endpoint="https://custom.wandb.ai",
            host="https://custom.host.com",
        )
        assert cfg.api_key == "test_key"
        assert cfg.entity == "test_entity"
        assert cfg.project == "test_project"
        assert cfg.endpoint == "https://custom.wandb.ai"
        assert cfg.host == "https://custom.host.com"

    def test_default_values(self):
        """Optional fields default to None / the trace endpoint."""
        cfg = WeaveConfig(api_key="key", project="project")
        assert cfg.entity is None
        assert cfg.endpoint == "https://trace.wandb.ai"
        assert cfg.host is None

    def test_missing_required_fields(self):
        """api_key and project are both mandatory, alone or together."""
        for kwargs in ({}, {"api_key": "key"}, {"project": "project"}):
            with pytest.raises(ValidationError):
                WeaveConfig(**kwargs)

    def test_endpoint_validation_https_only(self):
        """Plain-HTTP endpoints are rejected; only HTTPS is allowed."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")

    def test_host_validation_optional(self):
        """host is optional: None and "" pass through, valid URLs are kept."""
        for given in (None, "", "https://valid.host.com"):
            cfg = WeaveConfig(api_key="key", project="project", host=given)
            assert cfg.host == given

    def test_host_validation_invalid_scheme(self):
        """A host with a non-http(s) scheme is rejected when provided."""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
|
||||
|
||||
|
||||
class TestAliyunConfig:
    """Test cases for AliyunConfig"""

    def test_valid_config(self):
        """All explicitly supplied fields are stored as-is."""
        cfg = AliyunConfig(
            app_name="test_app",
            license_key="test_license_key",
            endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
        )
        assert cfg.app_name == "test_app"
        assert cfg.license_key == "test_license_key"
        assert cfg.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"

    def test_default_values(self):
        """Omitting app_name selects the default application name."""
        cfg = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
        assert cfg.app_name == "dify_app"

    def test_missing_required_fields(self):
        """license_key and endpoint are both mandatory, alone or together."""
        for kwargs in (
            {},
            {"license_key": "test_license"},
            {"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"},
        ):
            with pytest.raises(ValidationError):
                AliyunConfig(**kwargs)

    def test_app_name_validation_empty(self):
        """An empty app_name falls back to "dify_app"."""
        cfg = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
        )
        assert cfg.app_name == "dify_app"

    def test_endpoint_validation_empty(self):
        """An empty endpoint falls back to the default Aliyun endpoint."""
        cfg = AliyunConfig(license_key="test_license", endpoint="")
        assert cfg.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"

    def test_endpoint_validation_with_path(self):
        """Aliyun endpoints keep their URL path (unlike Arize normalization)."""
        cfg = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
        )
        assert cfg.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"

    def test_endpoint_validation_invalid_scheme(self):
        """Only http(s) schemes are accepted."""
        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
            AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")

    def test_endpoint_validation_no_scheme(self):
        """Scheme-less endpoints are rejected."""
        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
            AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")

    def test_license_key_required(self):
        """An empty license_key fails validation."""
        with pytest.raises(ValidationError):
            AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")

    def test_valid_endpoint_format_examples(self):
        """Documented public/intranet endpoint formats are all accepted verbatim."""
        valid_endpoints = [
            # cms2.0 public endpoint
            "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
            # cms2.0 intranet endpoint
            "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
            # xtrace public endpoint
            "http://tracing-cn-heyuan.arms.aliyuncs.com",
            # xtrace intranet endpoint
            "http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
        ]

        for endpoint in valid_endpoints:
            assert AliyunConfig(license_key="test_license", endpoint=endpoint).endpoint == endpoint
|
||||
|
||||
|
||||
class TestConfigIntegration:
    """Integration tests for configuration classes"""

    def test_all_configs_can_be_instantiated(self):
        """Every provider config accepts a minimal valid argument set."""
        configs = [
            ArizeConfig(api_key="key"),
            PhoenixConfig(api_key="key"),
            LangfuseConfig(public_key="public", secret_key="secret"),
            LangSmithConfig(api_key="key", project="project"),
            OpikConfig(api_key="key"),
            WeaveConfig(api_key="key", project="project"),
            AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com"),
        ]

        for cfg in configs:
            assert cfg is not None

    def test_url_normalization_consistency(self):
        """Path stripping vs. path preservation differs by provider, by design."""
        arize_cfg = ArizeConfig(endpoint="https://arize.com/api/v1/test")
        phoenix_with_path = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
        phoenix_without_path = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
        aliyun_cfg = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
        )

        # Arize drops the path; Phoenix and Aliyun keep whatever was given.
        assert arize_cfg.endpoint == "https://arize.com"
        assert phoenix_with_path.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
        assert phoenix_without_path.endpoint == "https://app.phoenix.arize.com"
        assert aliyun_cfg.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"

    def test_project_default_values(self):
        """Empty project/app names resolve to provider-specific defaults."""
        arize_cfg = ArizeConfig(project="")
        phoenix_cfg = PhoenixConfig(project="")
        opik_cfg = OpikConfig(project="")
        aliyun_cfg = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
        )

        assert arize_cfg.project == "default"
        assert phoenix_cfg.project == "default"
        assert opik_cfg.project == "Default Project"
        assert aliyun_cfg.app_name == "dify_app"
|
||||
138
dify/api/tests/unit_tests/core/ops/test_utils.py
Normal file
138
dify/api/tests/unit_tests/core/ops/test_utils.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import pytest
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TestValidateUrl:
    """Unit tests covering normalization and scheme checks in validate_url."""

    def test_valid_https_url(self):
        """An https URL passes through unchanged."""
        assert validate_url("https://example.com", "https://default.com") == "https://example.com"

    def test_valid_http_url(self):
        """An http URL passes through unchanged."""
        assert validate_url("http://example.com", "https://default.com") == "http://example.com"

    def test_url_with_path_removed(self):
        """Normalization strips any path component."""
        normalized = validate_url("https://example.com/api/v1/test", "https://default.com")
        assert normalized == "https://example.com"

    def test_url_with_query_removed(self):
        """Normalization drops the query string."""
        normalized = validate_url("https://example.com?param=value", "https://default.com")
        assert normalized == "https://example.com"

    def test_url_with_fragment_removed(self):
        """Normalization drops the fragment."""
        normalized = validate_url("https://example.com#section", "https://default.com")
        assert normalized == "https://example.com"

    def test_empty_url_returns_default(self):
        """An empty string falls back to the default URL."""
        assert validate_url("", "https://default.com") == "https://default.com"

    def test_none_url_returns_default(self):
        """None falls back to the default URL."""
        assert validate_url(None, "https://default.com") == "https://default.com"

    def test_whitespace_url_returns_default(self):
        """Whitespace-only input falls back to the default URL."""
        assert validate_url(" ", "https://default.com") == "https://default.com"

    def test_invalid_scheme_raises_error(self):
        """A non-HTTP scheme such as ftp is rejected."""
        with pytest.raises(ValueError, match="URL scheme must be one of"):
            validate_url("ftp://example.com", "https://default.com")

    def test_no_scheme_raises_error(self):
        """A bare host with no scheme is rejected."""
        with pytest.raises(ValueError, match="URL scheme must be one of"):
            validate_url("example.com", "https://default.com")

    def test_custom_allowed_schemes(self):
        """allowed_schemes narrows which schemes are accepted."""
        only_https = ("https",)
        checked = validate_url("https://example.com", "https://default.com", allowed_schemes=only_https)
        assert checked == "https://example.com"
        # http is no longer permitted once the allow-list excludes it.
        with pytest.raises(ValueError, match="URL scheme must be one of"):
            validate_url("http://example.com", "https://default.com", allowed_schemes=only_https)
||||
|
||||
|
||||
class TestValidateUrlWithPath:
    """Unit tests for validate_url_with_path, which keeps path components intact."""

    def test_valid_url_with_path(self):
        """A URL carrying a path is returned unchanged."""
        checked = validate_url_with_path("https://example.com/api/v1", "https://default.com")
        assert checked == "https://example.com/api/v1"

    def test_valid_url_with_required_suffix(self):
        """A URL ending in the required suffix is accepted."""
        checked = validate_url_with_path("https://example.com/api/", "https://default.com", required_suffix="/api/")
        assert checked == "https://example.com/api/"

    def test_url_without_required_suffix_raises_error(self):
        """Missing the required suffix is a validation error."""
        with pytest.raises(ValueError, match="URL should end with /api/"):
            validate_url_with_path("https://example.com/api", "https://default.com", required_suffix="/api/")

    def test_empty_url_returns_default(self):
        """An empty string falls back to the default URL."""
        assert validate_url_with_path("", "https://default.com") == "https://default.com"

    def test_none_url_returns_default(self):
        """None falls back to the default URL."""
        assert validate_url_with_path(None, "https://default.com") == "https://default.com"

    def test_invalid_scheme_raises_error(self):
        """A non-HTTP scheme such as ftp is rejected."""
        with pytest.raises(ValueError, match="URL must start with https:// or http://"):
            validate_url_with_path("ftp://example.com", "https://default.com")

    def test_no_scheme_raises_error(self):
        """A bare host with no scheme is rejected."""
        with pytest.raises(ValueError, match="URL must start with https:// or http://"):
            validate_url_with_path("example.com", "https://default.com")
||||
|
||||
|
||||
class TestValidateProjectName:
    """Unit tests for validate_project_name fallback and trimming behavior."""

    def test_valid_project_name(self):
        """A normal name is returned as-is."""
        assert validate_project_name("my-project", "default") == "my-project"

    def test_empty_project_name_returns_default(self):
        """An empty name falls back to the default."""
        assert validate_project_name("", "default") == "default"

    def test_none_project_name_returns_default(self):
        """None falls back to the default."""
        assert validate_project_name(None, "default") == "default"

    def test_whitespace_project_name_returns_default(self):
        """A whitespace-only name falls back to the default."""
        assert validate_project_name(" ", "default") == "default"

    def test_project_name_with_whitespace_trimmed(self):
        """Surrounding whitespace is stripped from the name."""
        assert validate_project_name(" my-project ", "default") == "my-project"

    def test_custom_default_name(self):
        """The caller-supplied default is honored."""
        assert validate_project_name("", "Custom Default") == "Custom Default"
||||
0
dify/api/tests/unit_tests/core/plugin/__init__.py
Normal file
0
dify/api/tests/unit_tests/core/plugin/__init__.py
Normal file
460
dify/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
Normal file
460
dify/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
Normal file
@@ -0,0 +1,460 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class TestChunkMerger:
    """Tests for merge_blob_chunks, which reassembles BLOB_CHUNK message streams
    (keyed by chunk id, ordered by sequence) into complete BLOB messages while
    passing all other message types through untouched."""

    def test_file_chunk_initialization(self):
        """Test FileChunk initialization."""
        # FileChunk pre-allocates its buffer: len(data) equals total_length
        # while bytes_written starts at zero.
        chunk = FileChunk(1024)
        assert chunk.bytes_written == 0
        assert chunk.total_length == 1024
        assert len(chunk.data) == 1024

    def test_merge_blob_chunks_with_single_complete_chunk(self):
        """Test merging a single complete blob chunk."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # First chunk (partial)
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=0, total_length=10, blob=b"Hello", end=False
                ),
            )
            # Second chunk (final)
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=1, total_length=10, blob=b"World", end=True
                ),
            )

        result = list(merge_blob_chunks(mock_generator()))
        # Two chunks collapse into exactly one BLOB message.
        assert len(result) == 1
        assert result[0].type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
        # The buffer should contain the complete data
        assert result[0].message.blob[:10] == b"HelloWorld"

    def test_merge_blob_chunks_with_multiple_files(self):
        """Test merging chunks from multiple files."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # Chunks of the two files are interleaved to exercise per-id buffering.
            # File 1, chunk 1
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=0, total_length=4, blob=b"AB", end=False
                ),
            )
            # File 2, chunk 1
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file2", sequence=0, total_length=4, blob=b"12", end=False
                ),
            )
            # File 1, chunk 2 (final)
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=1, total_length=4, blob=b"CD", end=True
                ),
            )
            # File 2, chunk 2 (final)
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file2", sequence=1, total_length=4, blob=b"34", end=True
                ),
            )

        result = list(merge_blob_chunks(mock_generator()))
        assert len(result) == 2
        # Check that both files are properly merged
        assert all(r.type == ToolInvokeMessage.MessageType.BLOB for r in result)

    def test_merge_blob_chunks_passes_through_non_blob_messages(self):
        """Test that non-blob messages pass through unchanged."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # Text message
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.TEXT,
                message=ToolInvokeMessage.TextMessage(text="Hello"),
            )
            # Blob chunk
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=0, total_length=5, blob=b"Test", end=True
                ),
            )
            # Another text message
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.TEXT,
                message=ToolInvokeMessage.TextMessage(text="World"),
            )

        result = list(merge_blob_chunks(mock_generator()))
        # Relative ordering is preserved: text, merged blob, text.
        assert len(result) == 3
        assert result[0].type == ToolInvokeMessage.MessageType.TEXT
        assert isinstance(result[0].message, ToolInvokeMessage.TextMessage)
        assert result[0].message.text == "Hello"
        assert result[1].type == ToolInvokeMessage.MessageType.BLOB
        assert result[2].type == ToolInvokeMessage.MessageType.TEXT
        assert isinstance(result[2].message, ToolInvokeMessage.TextMessage)
        assert result[2].message.text == "World"

    def test_merge_blob_chunks_file_too_large(self):
        """Test that error is raised when file exceeds max size."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # Send a chunk that would exceed the limit
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=0, total_length=100, blob=b"x" * 1024, end=False
                ),
            )

        with pytest.raises(ValueError) as exc_info:
            list(merge_blob_chunks(mock_generator(), max_file_size=1000))
        assert "File is too large" in str(exc_info.value)

    def test_merge_blob_chunks_chunk_too_large(self):
        """Test that error is raised when chunk exceeds max chunk size."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # Send a chunk that exceeds the max chunk size
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=0, total_length=10000, blob=b"x" * 9000, end=False
                ),
            )

        with pytest.raises(ValueError) as exc_info:
            list(merge_blob_chunks(mock_generator(), max_chunk_size=8192))
        assert "File chunk is too large" in str(exc_info.value)

    def test_merge_blob_chunks_with_agent_invoke_message(self):
        """Test that merge_blob_chunks works with AgentInvokeMessage."""

        def mock_generator() -> Generator[AgentInvokeMessage, None, None]:
            # First chunk
            yield AgentInvokeMessage(
                type=AgentInvokeMessage.MessageType.BLOB_CHUNK,
                message=AgentInvokeMessage.BlobChunkMessage(
                    id="agent_file", sequence=0, total_length=8, blob=b"Agent", end=False
                ),
            )
            # Final chunk
            yield AgentInvokeMessage(
                type=AgentInvokeMessage.MessageType.BLOB_CHUNK,
                message=AgentInvokeMessage.BlobChunkMessage(
                    id="agent_file", sequence=1, total_length=8, blob=b"Data", end=True
                ),
            )

        result = list(merge_blob_chunks(mock_generator()))
        assert len(result) == 1
        # The merged message keeps the concrete input message class.
        assert isinstance(result[0], AgentInvokeMessage)
        assert result[0].type == AgentInvokeMessage.MessageType.BLOB

    def test_merge_blob_chunks_preserves_meta(self):
        """Test that meta information is preserved in merged messages."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=0, total_length=4, blob=b"Test", end=True
                ),
                meta={"key": "value"},
            )

        result = list(merge_blob_chunks(mock_generator()))
        assert len(result) == 1
        assert result[0].meta == {"key": "value"}

    def test_merge_blob_chunks_custom_limits(self):
        """Test merge_blob_chunks with custom size limits."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # This should work with custom limits
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=0, total_length=500, blob=b"x" * 400, end=False
                ),
            )
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=1, total_length=500, blob=b"y" * 100, end=True
                ),
            )

        # Should work with custom limits
        result = list(merge_blob_chunks(mock_generator(), max_file_size=1000, max_chunk_size=500))
        assert len(result) == 1

        # Should fail with smaller file size limit
        def mock_generator2() -> Generator[ToolInvokeMessage, None, None]:
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=0, total_length=500, blob=b"x" * 400, end=False
                ),
            )

        with pytest.raises(ValueError):
            list(merge_blob_chunks(mock_generator2(), max_file_size=300))

    def test_merge_blob_chunks_data_integrity(self):
        """Test that merged chunks exactly match the original data."""
        # Create original data
        original_data = b"This is a test message that will be split into chunks for testing purposes."
        chunk_size = 20

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # Split original data into chunks
            chunks = []
            for i in range(0, len(original_data), chunk_size):
                chunk_data = original_data[i : i + chunk_size]
                is_last = (i + chunk_size) >= len(original_data)
                chunks.append((i // chunk_size, chunk_data, is_last))

            # Yield chunks
            for sequence, data, is_end in chunks:
                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                    message=ToolInvokeMessage.BlobChunkMessage(
                        id="test_file",
                        sequence=sequence,
                        total_length=len(original_data),
                        blob=data,
                        end=is_end,
                    ),
                )

        result = list(merge_blob_chunks(mock_generator()))
        assert len(result) == 1
        assert result[0].type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
        # Verify the merged data exactly matches the original
        assert result[0].message.blob == original_data

    def test_merge_blob_chunks_empty_chunk(self):
        """Test handling of empty chunks."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # First chunk with data
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=0, total_length=10, blob=b"Hello", end=False
                ),
            )
            # Empty chunk in the middle
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=1, total_length=10, blob=b"", end=False
                ),
            )
            # Final chunk with data
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="file1", sequence=2, total_length=10, blob=b"World", end=True
                ),
            )

        result = list(merge_blob_chunks(mock_generator()))
        assert len(result) == 1
        assert result[0].type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
        # The final blob should contain "Hello" followed by "World"
        assert result[0].message.blob[:10] == b"HelloWorld"

    def test_merge_blob_chunks_single_chunk_file(self):
        """Test file that arrives as a single complete chunk."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # Single chunk that is both first and last
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="single_chunk_file",
                    sequence=0,
                    total_length=11,
                    blob=b"Single Data",
                    end=True,
                ),
            )

        result = list(merge_blob_chunks(mock_generator()))
        assert len(result) == 1
        assert result[0].type == ToolInvokeMessage.MessageType.BLOB
        assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
        assert result[0].message.blob == b"Single Data"

    def test_merge_blob_chunks_concurrent_files(self):
        """Test that chunks from different files are properly separated."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # Interleave chunks from three different files
            files_data = {
                "file1": b"First file content",
                "file2": b"Second file data",
                "file3": b"Third file",
            }

            # First chunk from each file
            for file_id, data in files_data.items():
                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                    message=ToolInvokeMessage.BlobChunkMessage(
                        id=file_id,
                        sequence=0,
                        total_length=len(data),
                        blob=data[:6],
                        end=False,
                    ),
                )

            # Second chunk from each file (final)
            for file_id, data in files_data.items():
                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                    message=ToolInvokeMessage.BlobChunkMessage(
                        id=file_id,
                        sequence=1,
                        total_length=len(data),
                        blob=data[6:],
                        end=True,
                    ),
                )

        result = list(merge_blob_chunks(mock_generator()))
        assert len(result) == 3

        # Extract the blob data from results
        # Completion order is not asserted here, only that each file's
        # bytes were reassembled without cross-contamination.
        blobs = set()
        for r in result:
            assert isinstance(r.message, ToolInvokeMessage.BlobMessage)
            blobs.add(r.message.blob)
        expected = {b"First file content", b"Second file data", b"Third file"}
        assert blobs == expected

    def test_merge_blob_chunks_exact_buffer_size(self):
        """Test that data fitting exactly in buffer works correctly."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # Create data that exactly fills the declared buffer
            exact_data = b"X" * 100

            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="exact_file",
                    sequence=0,
                    total_length=100,
                    blob=exact_data[:50],
                    end=False,
                ),
            )
            yield ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                message=ToolInvokeMessage.BlobChunkMessage(
                    id="exact_file",
                    sequence=1,
                    total_length=100,
                    blob=exact_data[50:],
                    end=True,
                ),
            )

        result = list(merge_blob_chunks(mock_generator()))
        assert len(result) == 1
        assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
        assert len(result[0].message.blob) == 100
        assert result[0].message.blob == b"X" * 100

    def test_merge_blob_chunks_large_file_simulation(self):
        """Test handling of a large file split into many chunks."""

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # Simulate a 1MB file split into 128 chunks of 8KB each
            chunk_size = 8192
            num_chunks = 128
            total_size = chunk_size * num_chunks

            for i in range(num_chunks):
                # Create unique data for each chunk to verify ordering
                chunk_data = bytes([i % 256]) * chunk_size
                is_last = i == num_chunks - 1

                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                    message=ToolInvokeMessage.BlobChunkMessage(
                        id="large_file",
                        sequence=i,
                        total_length=total_size,
                        blob=chunk_data,
                        end=is_last,
                    ),
                )

        result = list(merge_blob_chunks(mock_generator()))
        assert len(result) == 1
        assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
        assert len(result[0].message.blob) == 1024 * 1024

        # Verify the data pattern is correct
        merged_data = result[0].message.blob
        chunk_size = 8192
        num_chunks = 128
        for i in range(num_chunks):
            chunk_start = i * chunk_size
            chunk_end = chunk_start + chunk_size
            expected_byte = i % 256
            chunk = merged_data[chunk_start:chunk_end]
            assert all(b == expected_byte for b in chunk), f"Chunk {i} has incorrect data"

    def test_merge_blob_chunks_sequential_order_required(self):
        """
        Test note: The current implementation assumes chunks arrive in sequential order.
        Out-of-order chunks would need additional logic to handle properly.
        This test documents the expected behavior with sequential chunks.
        """

        def mock_generator() -> Generator[ToolInvokeMessage, None, None]:
            # Chunks arriving in correct sequential order
            data_parts = [b"First", b"Second", b"Third"]
            total_length = sum(len(part) for part in data_parts)

            for i, part in enumerate(data_parts):
                is_last = i == len(data_parts) - 1
                yield ToolInvokeMessage(
                    type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
                    message=ToolInvokeMessage.BlobChunkMessage(
                        id="ordered_file",
                        sequence=i,
                        total_length=total_length,
                        blob=part,
                        end=is_last,
                    ),
                )

        result = list(merge_blob_chunks(mock_generator()))
        assert len(result) == 1
        assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage)
        assert result[0].message.blob == b"FirstSecondThird"
||||
655
dify/api/tests/unit_tests/core/plugin/utils/test_http_parser.py
Normal file
655
dify/api/tests/unit_tests/core/plugin/utils/test_http_parser.py
Normal file
@@ -0,0 +1,655 @@
|
||||
import pytest
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.utils.http_parser import (
|
||||
deserialize_request,
|
||||
deserialize_response,
|
||||
serialize_request,
|
||||
serialize_response,
|
||||
)
|
||||
|
||||
|
||||
class TestSerializeRequest:
    """Tests that serialize_request renders a Flask request as raw HTTP/1.1 bytes."""

    @staticmethod
    def _base_environ(method, path, query="", **extra):
        """Build a minimal WSGI environ; *extra* keys extend or override the base."""
        environ = {
            "REQUEST_METHOD": method,
            "PATH_INFO": path,
            "QUERY_STRING": query,
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": None,
            "wsgi.url_scheme": "http",
        }
        environ.update(extra)
        return environ

    def test_serialize_simple_get_request(self):
        """A bare GET serializes to a request line followed by headers."""
        request = Request(self._base_environ("GET", "/api/test"))

        raw = serialize_request(request)

        assert raw.startswith(b"GET /api/test HTTP/1.1\r\n")
        # A blank line must separate the header section from the (empty) body.
        assert b"\r\n\r\n" in raw

    def test_serialize_request_with_query_params(self):
        """The query string is rendered into the request target."""
        request = Request(self._base_environ("GET", "/api/search", query="q=test&limit=10"))

        raw = serialize_request(request)

        assert raw.startswith(b"GET /api/search?q=test&limit=10 HTTP/1.1\r\n")

    def test_serialize_post_request_with_body(self):
        """POST bodies and their Content-Type header survive serialization."""
        from io import BytesIO

        payload = b'{"name": "test", "value": 123}'
        environ = self._base_environ(
            "POST",
            "/api/data",
            **{
                "wsgi.input": BytesIO(payload),
                "CONTENT_LENGTH": str(len(payload)),
                "CONTENT_TYPE": "application/json",
                "HTTP_CONTENT_TYPE": "application/json",
            },
        )

        raw = serialize_request(Request(environ))

        assert b"POST /api/data HTTP/1.1\r\n" in raw
        assert b"Content-Type: application/json" in raw
        # The body comes last, after all headers.
        assert raw.endswith(payload)

    def test_serialize_request_with_custom_headers(self):
        """HTTP_* environ entries are emitted as properly cased header lines."""
        environ = self._base_environ(
            "GET",
            "/api/test",
            HTTP_AUTHORIZATION="Bearer token123",
            HTTP_X_CUSTOM_HEADER="custom-value",
        )

        raw = serialize_request(Request(environ))

        assert b"Authorization: Bearer token123" in raw
        assert b"X-Custom-Header: custom-value" in raw
|
||||
|
||||
class TestDeserializeRequest:
    """Tests that deserialize_request reconstructs a Flask request from raw HTTP bytes."""

    def test_deserialize_simple_get_request(self):
        """Method, path, and Host header are recovered from a minimal GET."""
        wire = b"GET /api/test HTTP/1.1\r\nHost: localhost:8000\r\n\r\n"

        request = deserialize_request(wire)

        assert request.method == "GET"
        assert request.path == "/api/test"
        assert request.headers.get("Host") == "localhost:8000"

    def test_deserialize_request_with_query_params(self):
        """The query string is split off the path and parsed into args."""
        wire = b"GET /api/search?q=test&limit=10 HTTP/1.1\r\nHost: example.com\r\n\r\n"

        request = deserialize_request(wire)

        assert request.method == "GET"
        assert request.path == "/api/search"
        assert request.query_string == b"q=test&limit=10"
        # Parsed query arguments are exposed as strings.
        assert request.args.get("q") == "test"
        assert request.args.get("limit") == "10"

    def test_deserialize_post_request_with_body(self):
        """A JSON POST body and its content type round out of the raw bytes."""
        payload = b'{"name": "test", "value": 123}'
        wire = (
            b"POST /api/data HTTP/1.1\r\n"
            b"Host: localhost\r\n"
            b"Content-Type: application/json\r\n"
            b"Content-Length: " + str(len(payload)).encode() + b"\r\n"
            b"\r\n" + payload
        )

        request = deserialize_request(wire)

        assert request.method == "POST"
        assert request.path == "/api/data"
        assert request.content_type == "application/json"
        assert request.get_data() == payload

    def test_deserialize_request_with_custom_headers(self):
        """Arbitrary header lines become entries in request.headers."""
        wire = (
            b"GET /api/protected HTTP/1.1\r\n"
            b"Host: api.example.com\r\n"
            b"Authorization: Bearer token123\r\n"
            b"X-Custom-Header: custom-value\r\n"
            b"User-Agent: TestClient/1.0\r\n"
            b"\r\n"
        )

        request = deserialize_request(wire)

        assert request.method == "GET"
        for name, expected in (
            ("Authorization", "Bearer token123"),
            ("X-Custom-Header", "custom-value"),
            ("User-Agent", "TestClient/1.0"),
        ):
            assert request.headers.get(name) == expected

    def test_deserialize_request_with_multiline_body(self):
        """CRLF sequences inside the body are not confused with header separators."""
        payload = b"line1\r\nline2\r\nline3"
        wire = b"PUT /api/text HTTP/1.1\r\nHost: localhost\r\nContent-Type: text/plain\r\n\r\n" + payload

        request = deserialize_request(wire)

        assert request.method == "PUT"
        assert request.get_data() == payload

    def test_deserialize_invalid_request_line(self):
        """A request line with a single token cannot be parsed."""
        with pytest.raises(ValueError, match="Invalid request line"):
            deserialize_request(b"INVALID\r\n\r\n")

    def test_roundtrip_request(self):
        """serialize_request followed by deserialize_request preserves the request."""
        from io import BytesIO

        payload = b"test body content"
        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/echo",
            "QUERY_STRING": "format=json",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8080",
            "wsgi.input": BytesIO(payload),
            "wsgi.url_scheme": "http",
            "CONTENT_LENGTH": str(len(payload)),
            "CONTENT_TYPE": "text/plain",
            "HTTP_CONTENT_TYPE": "text/plain",
            "HTTP_X_REQUEST_ID": "req-123",
        }
        source = Request(environ)

        restored = deserialize_request(serialize_request(source))

        # Key request properties survive the round trip.
        assert restored.method == source.method
        assert restored.path == source.path
        assert restored.query_string == source.query_string
        assert restored.get_data() == payload
        assert restored.headers.get("X-Request-Id") == "req-123"
|
||||
|
||||
class TestSerializeResponse:
    """serialize_response: Response objects -> raw HTTP/1.1 bytes."""

    def test_serialize_simple_response(self):
        raw = serialize_response(Response("Hello, World!", status=200))

        assert raw.startswith(b"HTTP/1.1 200 OK\r\n")
        assert b"\r\n\r\n" in raw  # header/body separator must be present
        assert raw.endswith(b"Hello, World!")

    def test_serialize_response_with_headers(self):
        headers = {"Content-Type": "application/json", "X-Request-Id": "req-456"}
        raw = serialize_response(Response('{"status": "success"}', status=201, headers=headers))

        assert b"HTTP/1.1 201 CREATED\r\n" in raw
        assert b"Content-Type: application/json" in raw
        assert b"X-Request-Id: req-456" in raw
        assert raw.endswith(b'{"status": "success"}')

    def test_serialize_error_response(self):
        raw = serialize_response(Response("Not Found", status=404, headers={"Content-Type": "text/plain"}))

        assert b"HTTP/1.1 404 NOT FOUND\r\n" in raw
        assert b"Content-Type: text/plain" in raw
        assert raw.endswith(b"Not Found")

    def test_serialize_response_without_body(self):
        raw = serialize_response(Response(status=204))  # No Content

        assert b"HTTP/1.1 204 NO CONTENT\r\n" in raw
        assert raw.endswith(b"\r\n\r\n")  # bare header block, empty body

    def test_serialize_response_with_binary_body(self):
        blob = b"\x00\x01\x02\x03\x04\x05"
        raw = serialize_response(Response(blob, status=200, headers={"Content-Type": "application/octet-stream"}))

        assert b"HTTP/1.1 200 OK\r\n" in raw
        assert b"Content-Type: application/octet-stream" in raw
        assert raw.endswith(blob)
|
||||
|
||||
|
||||
class TestDeserializeResponse:
    """deserialize_response: raw HTTP/1.1 bytes -> Response objects."""

    def test_deserialize_simple_response(self):
        resp = deserialize_response(b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nHello, World!")

        assert resp.status_code == 200
        assert resp.get_data() == b"Hello, World!"
        assert resp.headers.get("Content-Type") == "text/plain"

    def test_deserialize_response_with_json(self):
        payload = b'{"result": "success", "data": [1, 2, 3]}'
        raw = (
            b"HTTP/1.1 201 Created\r\n"
            b"Content-Type: application/json\r\n"
            b"Content-Length: " + str(len(payload)).encode() + b"\r\n"
            b"X-Custom-Header: test-value\r\n"
            b"\r\n" + payload
        )

        resp = deserialize_response(raw)

        assert resp.status_code == 201
        assert resp.get_data() == payload
        assert resp.headers.get("Content-Type") == "application/json"
        assert resp.headers.get("X-Custom-Header") == "test-value"

    def test_deserialize_error_response(self):
        raw = b"HTTP/1.1 404 Not Found\r\nContent-Type: text/html\r\n\r\n<html><body>Page not found</body></html>"

        resp = deserialize_response(raw)

        assert resp.status_code == 404
        assert resp.get_data() == b"<html><body>Page not found</body></html>"

    def test_deserialize_response_without_body(self):
        resp = deserialize_response(b"HTTP/1.1 204 No Content\r\n\r\n")

        assert resp.status_code == 204
        assert resp.get_data() == b""

    def test_deserialize_response_with_multiline_body(self):
        payload = b"Line 1\r\nLine 2\r\nLine 3"

        resp = deserialize_response(b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n" + payload)

        assert resp.status_code == 200
        assert resp.get_data() == payload

    def test_deserialize_response_minimal_status_line(self):
        # A status line with no reason phrase is still parseable.
        resp = deserialize_response(b"HTTP/1.1 200\r\n\r\nOK")

        assert resp.status_code == 200
        assert resp.get_data() == b"OK"

    def test_deserialize_invalid_status_line(self):
        with pytest.raises(ValueError, match="Invalid status line"):
            deserialize_response(b"INVALID\r\n\r\n")

    def test_roundtrip_response(self):
        """serialize_response -> deserialize_response preserves status, body, and headers."""
        source = Response(
            '{"message": "test"}',
            status=200,
            headers={
                "Content-Type": "application/json",
                "X-Request-Id": "abc-123",
                "Cache-Control": "no-cache",
            },
        )

        restored = deserialize_response(serialize_response(source))

        assert restored.status_code == source.status_code
        assert restored.get_data() == source.get_data()
        for name, value in (
            ("Content-Type", "application/json"),
            ("X-Request-Id", "abc-123"),
            ("Cache-Control", "no-cache"),
        ):
            assert restored.headers.get(name) == value
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Boundary conditions for the request/response (de)serializers."""

    def test_request_with_empty_headers(self):
        parsed = deserialize_request(b"GET / HTTP/1.1\r\n\r\n")

        assert parsed.method == "GET"
        assert parsed.path == "/"

    def test_response_with_empty_headers(self):
        resp = deserialize_response(b"HTTP/1.1 200 OK\r\n\r\nSuccess")

        assert resp.status_code == 200
        assert resp.get_data() == b"Success"

    def test_request_with_special_characters_in_path(self):
        # Percent-encoded path and query must survive parsing.
        parsed = deserialize_request(b"GET /api/test%20path?key=%26value HTTP/1.1\r\n\r\n")

        assert parsed.method == "GET"
        assert "/api/test%20path" in parsed.full_path

    def test_response_with_binary_content(self):
        blob = bytes(range(256))  # every possible byte value

        resp = deserialize_response(b"HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\n\r\n" + blob)

        assert resp.status_code == 200
        assert resp.get_data() == blob
|
||||
|
||||
|
||||
class TestFileUploads:
    """Multipart/form-data request (de)serialization tests.

    Part delimiters follow RFC 2046: ``--`` + boundary opens each part and
    ``--`` + boundary + ``--`` terminates the body. (The original fixtures
    used ``f"------{boundary}"`` — two extra literal dashes on top of the
    ``--`` prefix — so the bodies did not match the boundary declared in
    Content-Type. The assertions, being substring checks, passed anyway;
    fixed here so the fixtures are well-formed multipart documents.)
    """

    def test_serialize_request_with_text_file_upload(self):
        """A multipart request with a text file and a plain form field serializes intact."""
        from io import BytesIO

        boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
        text_content = "Hello, this is a test file content!\nWith multiple lines."
        body = (
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="file"; filename="test.txt"\r\n'
            f"Content-Type: text/plain\r\n"
            f"\r\n"
            f"{text_content}\r\n"
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="description"\r\n'
            f"\r\n"
            f"Test file upload\r\n"
            f"--{boundary}--\r\n"
        ).encode()

        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/upload",
            "QUERY_STRING": "",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "http",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
        }
        request = Request(environ)

        raw_data = serialize_request(request)

        assert b"POST /api/upload HTTP/1.1\r\n" in raw_data
        assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
        assert b'Content-Disposition: form-data; name="file"; filename="test.txt"' in raw_data
        assert text_content.encode() in raw_data

    def test_deserialize_request_with_text_file_upload(self):
        """A multipart request with a text file deserializes with its body preserved."""
        boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
        text_content = "Sample text file content\nLine 2\nLine 3"
        body = (
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="document"; filename="document.txt"\r\n'
            f"Content-Type: text/plain\r\n"
            f"\r\n"
            f"{text_content}\r\n"
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="title"\r\n'
            f"\r\n"
            f"My Document\r\n"
            f"--{boundary}--\r\n"
        ).encode()

        raw_data = (
            b"POST /api/documents HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n"
            b"Content-Length: " + str(len(body)).encode() + b"\r\n"
            b"\r\n" + body
        )

        request = deserialize_request(raw_data)

        assert request.method == "POST"
        assert request.path == "/api/documents"
        assert "multipart/form-data" in request.content_type
        # The raw multipart body must round-trip byte-for-byte.
        request_body = request.get_data()
        assert b"document.txt" in request_body
        assert text_content.encode() in request_body

    def test_serialize_request_with_binary_file_upload(self):
        """A multipart request carrying binary data (PNG header bytes) serializes intact."""
        from io import BytesIO

        boundary = "----BoundaryString123"
        # Simulate a small PNG file header.
        binary_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10"

        body_parts = []
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="image"; filename="test.png"')
        body_parts.append(b"Content-Type: image/png")
        body_parts.append(b"")
        body_parts.append(binary_content)
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="caption"')
        body_parts.append(b"")
        body_parts.append(b"Test image")
        body_parts.append(f"--{boundary}--".encode())

        body = b"\r\n".join(body_parts)

        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/images",
            "QUERY_STRING": "",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "http",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
        }
        request = Request(environ)

        raw_data = serialize_request(request)

        assert b"POST /api/images HTTP/1.1\r\n" in raw_data
        assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
        assert b'filename="test.png"' in raw_data
        assert b"Content-Type: image/png" in raw_data
        assert binary_content in raw_data

    def test_deserialize_request_with_binary_file_upload(self):
        """A multipart request carrying binary data (JPEG header bytes) deserializes intact."""
        boundary = "----BoundaryABC123"
        # Simulate a small JPEG file header.
        binary_content = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00"

        body_parts = []
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="photo"; filename="photo.jpg"')
        body_parts.append(b"Content-Type: image/jpeg")
        body_parts.append(b"")
        body_parts.append(binary_content)
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="album"')
        body_parts.append(b"")
        body_parts.append(b"Vacation 2024")
        body_parts.append(f"--{boundary}--".encode())

        body = b"\r\n".join(body_parts)

        raw_data = (
            b"POST /api/photos HTTP/1.1\r\n"
            b"Host: api.example.com\r\n"
            b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n"
            b"Content-Length: " + str(len(body)).encode() + b"\r\n"
            b"Accept: application/json\r\n"
            b"\r\n" + body
        )

        request = deserialize_request(raw_data)

        assert request.method == "POST"
        assert request.path == "/api/photos"
        assert "multipart/form-data" in request.content_type
        assert request.headers.get("Accept") == "application/json"

        # Verify the binary content is preserved byte-for-byte.
        request_body = request.get_data()
        assert b"photo.jpg" in request_body
        assert b"image/jpeg" in request_body
        assert binary_content in request_body
        assert b"Vacation 2024" in request_body

    def test_serialize_request_with_multiple_files(self):
        """A multipart request with two files plus a form field serializes all parts."""
        from io import BytesIO

        boundary = "----MultiFilesBoundary"
        text_file = b"Text file contents"
        binary_file = b"\x00\x01\x02\x03\x04\x05"

        body_parts = []
        # First file (text).
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="files"; filename="doc.txt"')
        body_parts.append(b"Content-Type: text/plain")
        body_parts.append(b"")
        body_parts.append(text_file)
        # Second file (binary).
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="files"; filename="data.bin"')
        body_parts.append(b"Content-Type: application/octet-stream")
        body_parts.append(b"")
        body_parts.append(binary_file)
        # Additional form field.
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="folder"')
        body_parts.append(b"")
        body_parts.append(b"uploads/2024")
        body_parts.append(f"--{boundary}--".encode())

        body = b"\r\n".join(body_parts)

        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/batch-upload",
            "QUERY_STRING": "",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "https",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_X_FORWARDED_PROTO": "https",
        }
        request = Request(environ)

        raw_data = serialize_request(request)

        assert b"POST /api/batch-upload HTTP/1.1\r\n" in raw_data
        assert b"doc.txt" in raw_data
        assert b"data.bin" in raw_data
        assert text_file in raw_data
        assert binary_file in raw_data
        assert b"uploads/2024" in raw_data

    def test_roundtrip_file_upload_request(self):
        """A file-upload request survives serialize -> deserialize, including non-ASCII bytes."""
        from io import BytesIO

        boundary = "----RoundTripBoundary"
        file_content = b"This is my file content with special chars: \xf0\x9f\x98\x80"

        body_parts = []
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="upload"; filename="emoji.txt"')
        body_parts.append(b"Content-Type: text/plain; charset=utf-8")
        body_parts.append(b"")
        body_parts.append(file_content)
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="metadata"')
        body_parts.append(b"")
        body_parts.append(b'{"encoding": "utf-8", "size": 42}')
        body_parts.append(f"--{boundary}--".encode())

        body = b"\r\n".join(body_parts)

        environ = {
            "REQUEST_METHOD": "PUT",
            "PATH_INFO": "/api/files/123",
            "QUERY_STRING": "version=2",
            "SERVER_NAME": "storage.example.com",
            "SERVER_PORT": "443",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "https",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_AUTHORIZATION": "Bearer token123",
            "HTTP_X_FORWARDED_PROTO": "https",
        }
        original_request = Request(environ)

        # Serialize and deserialize.
        raw_data = serialize_request(original_request)
        restored_request = deserialize_request(raw_data)

        # Request-line and header properties are preserved.
        assert restored_request.method == "PUT"
        assert restored_request.path == "/api/files/123"
        assert restored_request.query_string == b"version=2"
        assert "multipart/form-data" in restored_request.content_type
        assert boundary in restored_request.content_type

        # File content is preserved.
        restored_body = restored_request.get_data()
        assert b"emoji.txt" in restored_body
        assert file_content in restored_body
        assert b'{"encoding": "utf-8", "size": 42}' in restored_body
|
||||
0
dify/api/tests/unit_tests/core/prompt/__init__.py
Normal file
0
dify/api/tests/unit_tests/core/prompt/__init__.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from models.model import Conversation
|
||||
|
||||
|
||||
def test__get_completion_model_prompt_messages():
    """Completion-style prompt: context, role-prefixed histories, and inputs are templated in."""
    model_config = MagicMock(spec=ModelConfigEntity)
    model_config.provider = "openai"
    model_config.model = "gpt-3.5-turbo-instruct"

    template_text = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
    template = CompletionModelPromptTemplate(text=template_text)

    memory_config = MemoryConfig(
        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
        window=MemoryConfig.WindowConfig(enabled=False),
    )

    inputs = {"name": "John"}
    context = "I am superman."

    # Memory with a stubbed two-turn history so no database access happens.
    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config)
    history = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
    memory.get_history_prompt_messages = MagicMock(return_value=history)

    transform = AdvancedPromptTransform()
    transform._calculate_rest_token = MagicMock(return_value=2000)
    prompt_messages = transform._get_completion_model_prompt_messages(
        prompt_template=template,
        inputs=inputs,
        query=None,
        files=[],
        context=context,
        memory_config=memory_config,
        memory=memory,
        model_config=model_config,
    )

    assert len(prompt_messages) == 1

    # Expected rendering: each history turn prefixed with its configured role name.
    rendered_histories = "\n".join(
        f"{'Human' if msg.role.value == 'user' else 'Assistant'}: {msg.content}" for msg in history
    )
    expected = PromptTemplateParser(template=template_text).format(
        {"#context#": context, "#histories#": rendered_histories, **inputs}
    )
    assert prompt_messages[0].content == expected
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages(get_chat_model_args):
    """Chat-style prompt: templated messages, memory history, and the new query are assembled."""
    model_config, memory_config, messages, inputs, context = get_chat_model_args

    query = "Hi2."

    # Memory with a stubbed one-exchange history.
    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config)
    history = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")]
    memory.get_history_prompt_messages = MagicMock(return_value=history)

    transform = AdvancedPromptTransform()
    transform._calculate_rest_token = MagicMock(return_value=2000)
    prompt_messages = transform._get_chat_model_prompt_messages(
        prompt_template=messages,
        inputs=inputs,
        query=query,
        files=[],
        context=context,
        memory_config=memory_config,
        memory=memory,
        model_config=model_config,
    )

    # 3 templated messages + 2 history turns + the query itself.
    assert len(prompt_messages) == 6
    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
    assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
        {**inputs, "#context#": context}
    )
    assert prompt_messages[5].content == query
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
    """Without memory or query, only the templated messages are returned."""
    model_config, _, messages, inputs, context = get_chat_model_args

    transform = AdvancedPromptTransform()
    transform._calculate_rest_token = MagicMock(return_value=2000)
    prompt_messages = transform._get_chat_model_prompt_messages(
        prompt_template=messages,
        inputs=inputs,
        query=None,
        files=[],
        context=context,
        memory_config=None,
        memory=None,
        model_config=model_config,
    )

    # Exactly the 3 templated messages, no history, no trailing query.
    assert len(prompt_messages) == 3
    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
    assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
        {**inputs, "#context#": context}
    )
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
    """A file input is appended as an extra user message carrying image content."""
    model_config_mock, _, messages, inputs, context = get_chat_model_args
    # NOTE(review): mutates process-wide config and never restores it — this can leak
    # into other tests; consider pytest's monkeypatch. Left unchanged here.
    dify_config.MULTIMODAL_SEND_FORMAT = "url"

    # One remote-URL image file as the only file input.
    files = [
        File(
            id="file1",
            tenant_id="tenant1",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/image1.jpg",
            storage_key="",
        )
    ]

    prompt_transform = AdvancedPromptTransform()
    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
    # Stub the file -> prompt-message-content conversion so no storage/network access happens.
    with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
        mock_get_encoded_string.return_value = ImagePromptMessageContent(
            url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
        )
        prompt_messages = prompt_transform._get_chat_model_prompt_messages(
            prompt_template=messages,
            inputs=inputs,
            query=None,
            files=files,
            context=context,
            memory_config=None,
            memory=None,
            model_config=model_config_mock,
        )

    # 3 templated messages + 1 user message holding the file content.
    assert len(prompt_messages) == 4
    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
    assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
        {**inputs, "#context#": context}
    )
    # The appended message's content is a list; its first item carries the image URL.
    assert isinstance(prompt_messages[3].content, list)
    assert len(prompt_messages[3].content) == 2
    assert prompt_messages[3].content[0].data == files[0].remote_url
|
||||
|
||||
|
||||
@pytest.fixture
def get_chat_model_args():
    """Shared chat-model scenario: (model config, memory config, template messages, inputs, context)."""
    model_config = MagicMock(spec=ModelConfigEntity)
    model_config.provider = "openai"
    model_config.model = "gpt-4"

    memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

    # System template references {{name}} and {{#context#}}; the other two are fixed turns.
    template_messages = [
        ChatModelMessage(
            text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM
        ),
        ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
        ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT),
    ]

    return model_config, memory_config, template_messages, {"name": "John"}, "I am superman."
|
||||
@@ -0,0 +1,73 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
ModelConfigWithCredentialsEntity,
|
||||
)
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from models.model import Conversation
|
||||
|
||||
|
||||
def test_get_prompt():
    """Agent history is trimmed to the token budget; token counting is faked as message count."""
    prompt_messages = [
        SystemPromptMessage(content="System Template"),
        UserPromptMessage(content="User Query"),
    ]
    history_messages = [
        SystemPromptMessage(content="System Prompt 1"),
        UserPromptMessage(content="User Prompt 1"),
        AssistantPromptMessage(content="Assistant Thought 1"),
        ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"),
        ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"),
        SystemPromptMessage(content="System Prompt 2"),
        UserPromptMessage(content="User Prompt 2"),
        AssistantPromptMessage(content="Assistant Thought 2"),
        ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"),
        ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"),
        UserPromptMessage(content="User Prompt 3"),
        AssistantPromptMessage(content="Assistant Thought 3"),
    ]

    # Token-counting stand-in: one "token" per message in the third positional argument.
    llm = MagicMock(spec=LargeLanguageModel)
    llm.get_num_tokens = MagicMock(side_effect=lambda *args: len(args[2]))

    bundle = MagicMock(spec=ProviderModelBundle)
    bundle.model_type_instance = llm

    model_config = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config.model = "openai"
    model_config.credentials = {}
    model_config.provider_model_bundle = bundle

    transform = AgentHistoryPromptTransform(
        model_config=model_config,
        prompt_messages=prompt_messages,
        history_messages=history_messages,
        memory=TokenBufferMemory(conversation=Conversation(), model_instance=model_config),
    )

    # Tight budget (5): only part of the history is kept.
    transform._calculate_rest_token = MagicMock(return_value=5)
    assert len(transform.get_prompt()) == 4

    # Generous budget (20): the full 12-message history fits.
    transform._calculate_rest_token = MagicMock(return_value=20)
    assert len(transform.get_prompt()) == 12
|
||||
@@ -0,0 +1,91 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
|
||||
|
||||
class MockMessage:
    """Minimal stand-in for a message row.

    Exposes ``id`` and ``parent_message_id`` both as attributes and via
    ``msg["key"]`` subscription, matching how the tests read results.
    """

    def __init__(self, id, parent_message_id):
        self.id = id
        self.parent_message_id = parent_message_id

    def __getitem__(self, key):
        # Dict-style access simply delegates to the attribute of the same name.
        return getattr(self, key)
|
||||
|
||||
|
||||
def test_extract_thread_messages_single_message():
    """A lone root message (parent == UUID_NIL) forms a one-element thread."""
    only = MockMessage(str(uuid4()), UUID_NIL)

    result = extract_thread_messages([only])

    assert len(result) == 1
    assert result[0] == only
|
||||
|
||||
|
||||
def test_extract_thread_messages_linear_thread():
    """A straight parent chain is returned in full, newest first."""
    ids = [str(uuid4()) for _ in range(5)]
    parents = [UUID_NIL] + ids[:-1]  # each message points at the previous id; the oldest at UUID_NIL
    # Messages are listed newest-first, as the extractor receives them.
    messages = [MockMessage(i, p) for i, p in zip(reversed(ids), reversed(parents))]

    result = extract_thread_messages(messages)

    assert len(result) == 5
    assert [msg["id"] for msg in result] == list(reversed(ids))
|
||||
|
||||
|
||||
def test_extract_thread_messages_branched_thread():
    """When two messages share a parent, only the first-listed branch is followed."""
    id1, id2, id3, id4 = (str(uuid4()) for _ in range(4))
    messages = [
        MockMessage(id4, id2),  # first-listed child of id2 — kept
        MockMessage(id3, id2),  # sibling branch — dropped
        MockMessage(id2, id1),
        MockMessage(id1, UUID_NIL),
    ]

    result = extract_thread_messages(messages)

    assert len(result) == 3
    assert [msg["id"] for msg in result] == [id4, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_empty_list():
    """No messages in, no messages out."""
    result = extract_thread_messages([])

    assert len(result) == 0
|
||||
|
||||
|
||||
def test_extract_thread_messages_partially_loaded():
    """A chain whose root parent (id0) is absent from the list still yields the loaded part."""
    id0, id1, id2, id3 = (str(uuid4()) for _ in range(4))
    messages = [
        MockMessage(id3, id2),
        MockMessage(id2, id1),
        MockMessage(id1, id0),  # parent id0 is never loaded
    ]

    result = extract_thread_messages(messages)

    assert len(result) == 3
    assert [msg["id"] for msg in result] == [id3, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_legacy_messages():
    """Legacy messages (every parent is UUID_NIL) are all returned, in their given order."""
    legacy_ids = [str(uuid4()) for _ in range(3)]
    messages = [MockMessage(mid, UUID_NIL) for mid in reversed(legacy_ids)]

    result = extract_thread_messages(messages)

    assert len(result) == 3
    assert [msg["id"] for msg in result] == list(reversed(legacy_ids))
|
||||
|
||||
|
||||
def test_extract_thread_messages_mixed_with_legacy_messages():
    """Threaded messages chain into legacy history; the untouched sibling (id3) is skipped."""
    id1, id2, id3, id4, id5 = (str(uuid4()) for _ in range(5))
    # id5 -> id4 -> id2 is the live branch; id3 is a sibling of id4;
    # id2 and id1 are legacy entries with UUID_NIL parents.
    edges = [(id5, id4), (id4, id2), (id3, id2), (id2, UUID_NIL), (id1, UUID_NIL)]
    messages = [MockMessage(mid, parent) for mid, parent in edges]

    result = extract_thread_messages(messages)

    assert len(result) == 4
    assert [m["id"] for m in result] == [id5, id4, id2, id1]
|
||||
27
dify/api/tests/unit_tests/core/prompt/test_prompt_message.py
Normal file
27
dify/api/tests/unit_tests/core/prompt/test_prompt_message.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
def test_build_prompt_message_with_prompt_message_contents():
    """A UserPromptMessage built from content parts keeps them as a typed list."""
    text_part = TextPromptMessageContent(data="Hello, World!")
    prompt = UserPromptMessage(content=[text_part])

    assert isinstance(prompt.content, list)
    first = prompt.content[0]
    assert isinstance(first, TextPromptMessageContent)
    assert first.data == "Hello, World!"
|
||||
|
||||
|
||||
def test_dump_prompt_message():
    """model_dump() preserves the URL of an image content part."""
    example_url = "https://example.com/image.jpg"
    image_part = ImagePromptMessageContent(
        url=example_url,
        format="jpeg",
        mime_type="image/jpeg",
    )
    prompt = UserPromptMessage(content=[image_part])

    dumped = prompt.model_dump()
    assert dumped["content"][0].get("url") == example_url
|
||||
@@ -0,0 +1,52 @@
|
||||
# from unittest.mock import MagicMock
|
||||
|
||||
# from core.app.app_config.entities import ModelConfigEntity
|
||||
# from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
# from core.model_runtime.entities.message_entities import UserPromptMessage
|
||||
# from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
|
||||
# from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
# from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
# from core.prompt.prompt_transform import PromptTransform
|
||||
|
||||
|
||||
# def test__calculate_rest_token():
|
||||
# model_schema_mock = MagicMock(spec=AIModelEntity)
|
||||
# parameter_rule_mock = MagicMock(spec=ParameterRule)
|
||||
# parameter_rule_mock.name = "max_tokens"
|
||||
# model_schema_mock.parameter_rules = [parameter_rule_mock]
|
||||
# model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
|
||||
|
||||
# large_language_model_mock = MagicMock(spec=LargeLanguageModel)
|
||||
# large_language_model_mock.get_num_tokens.return_value = 6
|
||||
|
||||
# provider_mock = MagicMock(spec=ProviderEntity)
|
||||
# provider_mock.provider = "openai"
|
||||
|
||||
# provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
|
||||
# provider_configuration_mock.provider = provider_mock
|
||||
# provider_configuration_mock.model_settings = None
|
||||
|
||||
# provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
|
||||
# provider_model_bundle_mock.model_type_instance = large_language_model_mock
|
||||
# provider_model_bundle_mock.configuration = provider_configuration_mock
|
||||
|
||||
# model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
# model_config_mock.model = "gpt-4"
|
||||
# model_config_mock.credentials = {}
|
||||
# model_config_mock.parameters = {"max_tokens": 50}
|
||||
# model_config_mock.model_schema = model_schema_mock
|
||||
# model_config_mock.provider_model_bundle = provider_model_bundle_mock
|
||||
|
||||
# prompt_transform = PromptTransform()
|
||||
|
||||
# prompt_messages = [UserPromptMessage(content="Hello, how are you?")]
|
||||
# rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)
|
||||
|
||||
# # Validate based on the mock configuration and expected logic
|
||||
# expected_rest_tokens = (
|
||||
# model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
|
||||
# - model_config_mock.parameters["max_tokens"]
|
||||
# - large_language_model_mock.get_num_tokens.return_value
|
||||
# )
|
||||
# assert rest_tokens == expected_rest_tokens
|
||||
# assert rest_tokens == 6
|
||||
@@ -0,0 +1,246 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
|
||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||
from models.model import AppMode, Conversation
|
||||
|
||||
|
||||
def test_get_common_chat_app_prompt_template_with_pcqm():
    """Chat template for a common provider with pre-prompt, context, query and memory."""
    transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=True,
    )

    rules = template["prompt_rules"]
    expected = (
        rules["context_prompt"]
        + pre_prompt
        + "\n"
        + rules["histories_prompt"]
        + rules["query_prompt"]
    )
    assert template["prompt_template"].template == expected
    assert template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
|
||||
|
||||
|
||||
def test_get_baichuan_chat_app_prompt_template_with_pcqm():
    """Chat template for the baichuan provider with pre-prompt, context, query and memory."""
    transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="baichuan",
        model="Baichuan2-53B",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=True,
    )

    rules = template["prompt_rules"]
    expected = (
        rules["context_prompt"]
        + pre_prompt
        + "\n"
        + rules["histories_prompt"]
        + rules["query_prompt"]
    )
    assert template["prompt_template"].template == expected
    assert template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
|
||||
|
||||
|
||||
def test_get_common_completion_app_prompt_template_with_pcq():
    """Workflow (completion) template for a common provider: pre-prompt + context + query, no memory."""
    transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    template = transform.get_prompt_template(
        app_mode=AppMode.WORKFLOW,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )

    rules = template["prompt_rules"]
    expected = rules["context_prompt"] + pre_prompt + "\n" + rules["query_prompt"]
    assert template["prompt_template"].template == expected
    assert template["special_variable_keys"] == ["#context#", "#query#"]
|
||||
|
||||
|
||||
def test_get_baichuan_completion_app_prompt_template_with_pcq():
    """Workflow (completion) template for baichuan: pre-prompt + context + query, no memory."""
    transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    template = transform.get_prompt_template(
        app_mode=AppMode.WORKFLOW,
        provider="baichuan",
        model="Baichuan2-53B",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )

    rules = template["prompt_rules"]
    expected = rules["context_prompt"] + pre_prompt + "\n" + rules["query_prompt"]
    assert template["prompt_template"].template == expected
    assert template["special_variable_keys"] == ["#context#", "#query#"]
|
||||
|
||||
|
||||
def test_get_common_chat_app_prompt_template_with_q():
    """With no pre-prompt and no context, the template reduces to the query prompt alone."""
    transform = SimplePromptTransform()
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt="",
        has_context=False,
        query_in_prompt=True,
        with_memory_prompt=False,
    )

    rules = template["prompt_rules"]
    assert template["prompt_template"].template == rules["query_prompt"]
    assert template["special_variable_keys"] == ["#query#"]
|
||||
|
||||
|
||||
def test_get_common_chat_app_prompt_template_with_cq():
    """With context but no pre-prompt, the template is context prompt + query prompt."""
    transform = SimplePromptTransform()
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt="",
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )

    rules = template["prompt_rules"]
    assert template["prompt_template"].template == rules["context_prompt"] + rules["query_prompt"]
    assert template["special_variable_keys"] == ["#context#", "#query#"]
|
||||
|
||||
|
||||
def test_get_common_chat_app_prompt_template_with_p():
    """A pre-prompt alone yields the pre-prompt + newline, with its {{name}} var extracted."""
    transform = SimplePromptTransform()
    pre_prompt = "you are {{name}}"
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=False,
        query_in_prompt=False,
        with_memory_prompt=False,
    )

    assert template["prompt_template"].template == pre_prompt + "\n"
    assert template["custom_variable_keys"] == ["name"]
    assert template["special_variable_keys"] == []
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages():
    """_get_chat_model_prompt_messages yields system prompt, history pair, then the query."""
    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config_mock.provider = "openai"
    model_config_mock.model = "gpt-4"

    history = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
    memory_mock = MagicMock(spec=TokenBufferMemory)
    memory_mock.get_history_prompt_messages.return_value = history

    transform = SimplePromptTransform()
    # Avoid touching a real model: token budget is fixed.
    transform._calculate_rest_token = MagicMock(return_value=2000)

    pre_prompt = "You are a helpful assistant {{name}}."
    inputs = {"name": "John"}
    context = "yes or no."
    query = "How are you?"
    prompt_messages, _ = transform._get_chat_model_prompt_messages(
        app_mode=AppMode.CHAT,
        pre_prompt=pre_prompt,
        inputs=inputs,
        query=query,
        files=[],
        context=context,
        memory=memory_mock,
        model_config=model_config_mock,
    )

    # Rebuild the expected system prompt from the same template the transform uses.
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider=model_config_mock.provider,
        model=model_config_mock.model,
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=False,
        with_memory_prompt=False,
    )
    expected_system = template["prompt_template"].format({**inputs, "#context#": context})

    assert len(prompt_messages) == 4
    assert prompt_messages[0].content == expected_system
    assert prompt_messages[1].content == history[0].content
    assert prompt_messages[2].content == history[1].content
    assert prompt_messages[3].content == query
|
||||
|
||||
|
||||
def test__get_completion_model_prompt_messages():
    """_get_completion_model_prompt_messages folds context, query and history into one prompt."""
    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config_mock.provider = "openai"
    model_config_mock.model = "gpt-3.5-turbo-instruct"

    # Real memory object, but with its message fetch stubbed out.
    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
    history = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
    memory.get_history_prompt_messages = MagicMock(return_value=history)

    transform = SimplePromptTransform()
    transform._calculate_rest_token = MagicMock(return_value=2000)

    pre_prompt = "You are a helpful assistant {{name}}."
    inputs = {"name": "John"}
    context = "yes or no."
    query = "How are you?"
    prompt_messages, stops = transform._get_completion_model_prompt_messages(
        app_mode=AppMode.CHAT,
        pre_prompt=pre_prompt,
        inputs=inputs,
        query=query,
        files=[],
        context=context,
        memory=memory,
        model_config=model_config_mock,
    )

    # Rebuild the expected single prompt from the same template the transform uses.
    template = transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider=model_config_mock.provider,
        model=model_config_mock.model,
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=True,
    )
    rules = template["prompt_rules"]
    full_inputs = {
        **inputs,
        "#context#": context,
        "#query#": query,
        "#histories#": memory.get_history_prompt_text(
            max_token_limit=2000,
            human_prefix=rules.get("human_prefix", "Human"),
            ai_prefix=rules.get("assistant_prefix", "Assistant"),
        ),
    }
    expected_prompt = template["prompt_template"].format(full_inputs)

    assert len(prompt_messages) == 1
    assert stops == rules.get("stops")
    assert prompt_messages[0].content == expected_prompt
|
||||
0
dify/api/tests/unit_tests/core/rag/__init__.py
Normal file
0
dify/api/tests/unit_tests/core/rag/__init__.py
Normal file
@@ -0,0 +1,733 @@
|
||||
import json
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
|
||||
AlibabaCloudMySQLVector,
|
||||
AlibabaCloudMySQLVectorConfig,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
|
||||
# Prefer the real driver error type; fall back to a stub so these unit tests
# (which never open a real connection) can run without the driver installed.
try:
    from mysql.connector import Error as MySQLError
except ImportError:
    # Fallback for testing environments where mysql-connector-python might not be installed
    class MySQLError(Exception):
        """Minimal stand-in mirroring mysql.connector.Error's (errno, msg) shape."""

        def __init__(self, errno, msg):
            self.errno = errno  # MySQL error code, e.g. 1146 (table missing)
            self.msg = msg
            super().__init__(msg)
|
||||
|
||||
|
||||
class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
    def setUp(self):
        """Build a shared config, collection name, and sample docs/embeddings for every test."""
        self.config = AlibabaCloudMySQLVectorConfig(
            host="localhost",
            port=3306,
            user="test_user",
            password="test_password",
            database="test_db",
            max_connection=5,
            charset="utf8mb4",
        )
        self.collection_name = "test_collection"

        # Sample documents for testing; doc_id/document_id mirror the metadata
        # keys the vector store reads and filters on.
        self.sample_documents = [
            Document(
                page_content="This is a test document about AI.",
                metadata={"doc_id": "doc1", "document_id": "dataset1", "source": "test"},
            ),
            Document(
                page_content="Another document about machine learning.",
                metadata={"doc_id": "doc2", "document_id": "dataset1", "source": "test"},
            ),
        ]

        # Sample embeddings: one 4-dim vector per sample document.
        self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_init(self, mock_pool_class):
        """Test AlibabaCloudMySQLVector initialization."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        # Mock connection and cursor for vector support check
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # NOTE: side_effect order matters — __init__ consumes the version probe
        # first, then the vector-support probe.
        mock_cursor.fetchone.side_effect = [
            {"VERSION()": "8.0.36"},  # Version check
            {"vector_support": True},  # Vector support check
        ]

        alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)

        # table_name is the lower-cased collection name; cosine is the default distance.
        assert alibabacloud_mysql_vector.collection_name == self.collection_name
        assert alibabacloud_mysql_vector.table_name == self.collection_name.lower()
        assert alibabacloud_mysql_vector.get_type() == "alibabacloud_mysql"
        assert alibabacloud_mysql_vector.distance_function == "cosine"
        assert alibabacloud_mysql_vector.pool is not None
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
    def test_create_collection(self, mock_redis, mock_pool_class):
        """Test collection creation."""
        # Mock Redis operations: no cached "already created" flag, lock is a no-op.
        mock_redis.lock.return_value.__enter__ = MagicMock()
        mock_redis.lock.return_value.__exit__ = MagicMock()
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None

        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [
            {"VERSION()": "8.0.36"},  # Version check
            {"vector_support": True},  # Vector support check
        ]

        alibabacloud_mysql_vector = AlibabaCloudMySQLVector(self.collection_name, self.config)
        # 768 = embedding dimension for the new collection's vector column.
        alibabacloud_mysql_vector._create_collection(768)

        # Verify SQL execution calls - should include table creation and index creation
        assert mock_cursor.execute.called
        assert mock_cursor.execute.call_count >= 3  # CREATE TABLE + 2 indexes
        # The "collection exists" flag must be written back to Redis exactly once.
        mock_redis.set.assert_called_once()
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_vector_support_check_success(self, mock_pool_class):
        """Test successful vector support check."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Version probe then vector-support probe, both positive.
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]

        # Should not raise an exception
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        assert vector_store is not None
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_vector_support_check_failure(self, mock_pool_class):
        """Test vector support check failure."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Older server version and an explicit "no vector support" answer.
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.35"}, {"vector_support": False}]

        # Construction must fail fast when the server lacks vector functions.
        with pytest.raises(ValueError) as context:
            AlibabaCloudMySQLVector(self.collection_name, self.config)

        assert "RDS MySQL Vector functions are not available" in str(context.value)
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_vector_support_check_function_error(self, mock_pool_class):
        """Test vector support check with function not found error."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.return_value = {"VERSION()": "8.0.36"}
        # First execute (version probe) succeeds; second (VEC_FromText probe)
        # raises MySQL error 1305 "FUNCTION ... does not exist".
        mock_cursor.execute.side_effect = [None, MySQLError(errno=1305, msg="FUNCTION VEC_FromText does not exist")]

        with pytest.raises(ValueError) as context:
            AlibabaCloudMySQLVector(self.collection_name, self.config)

        assert "RDS MySQL Vector functions are not available" in str(context.value)
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
    def test_create_documents(self, mock_redis, mock_pool_class):
        """Test creating documents with embeddings."""
        # Setup mocks
        # NOTE(review): _setup_mocks is not visible in this chunk — presumably a
        # helper defined later in the class that wires redis + pool mocks; verify it exists.
        self._setup_mocks(mock_redis, mock_pool_class)

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        # create() returns the stored doc_ids taken from each document's metadata.
        result = vector_store.create(self.sample_documents, self.sample_embeddings)

        assert len(result) == 2
        assert "doc1" in result
        assert "doc2" in result
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_add_texts(self, mock_pool_class):
        """Test adding texts to the vector store."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        result = vector_store.add_texts(self.sample_documents, self.sample_embeddings)

        assert len(result) == 2
        # All rows are inserted via a single batched executemany call.
        mock_cursor.executemany.assert_called_once()
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_text_exists(self, mock_pool_class):
        """Test checking if text exists."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Two init probes, then the existence lookup returning a row.
        mock_cursor.fetchone.side_effect = [
            {"VERSION()": "8.0.36"},
            {"vector_support": True},
            {"id": "doc1"},  # Text exists
        ]

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        exists = vector_store.text_exists("doc1")

        assert exists
        # Check that the correct SQL was executed (last call after init)
        execute_calls = mock_cursor.execute.call_args_list
        last_call = execute_calls[-1]
        # call[0] is the positional args tuple: (sql, params).
        assert "SELECT id FROM" in last_call[0][0]
        assert last_call[0][1] == ("doc1",)
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_text_not_exists(self, mock_pool_class):
        """Test checking if text does not exist."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Two init probes, then the existence lookup returning no row.
        mock_cursor.fetchone.side_effect = [
            {"VERSION()": "8.0.36"},
            {"vector_support": True},
            None,  # Text does not exist
        ]

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        exists = vector_store.text_exists("nonexistent")

        assert not exists
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_get_by_ids(self, mock_pool_class):
        """Test getting documents by IDs."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        # The select iterates the cursor directly; fake two result rows
        # with JSON-encoded metadata, matching the dict-cursor row shape.
        mock_cursor.__iter__ = lambda self: iter(
            [
                {"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1"},
                {"meta": json.dumps({"doc_id": "doc2", "source": "test"}), "text": "Test document 2"},
            ]
        )

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        docs = vector_store.get_by_ids(["doc1", "doc2"])

        assert len(docs) == 2
        assert docs[0].page_content == "Test document 1"
        assert docs[1].page_content == "Test document 2"
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_get_by_ids_empty_list(self, mock_pool_class):
        """Test getting documents with empty ID list."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        # Empty input short-circuits to an empty result, no query issued.
        docs = vector_store.get_by_ids([])

        assert len(docs) == 0
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_delete_by_ids(self, mock_pool_class):
        """Test deleting documents by IDs."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        vector_store.delete_by_ids(["doc1", "doc2"])

        # Check that delete SQL was executed — exactly one DELETE covering both ids.
        execute_calls = mock_cursor.execute.call_args_list
        delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
        assert len(delete_calls) == 1
        delete_call = delete_calls[0]
        assert "DELETE FROM" in delete_call[0][0]
        assert delete_call[0][1] == ["doc1", "doc2"]
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_delete_by_ids_empty_list(self, mock_pool_class):
        """Test deleting with empty ID list."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        vector_store.delete_by_ids([])  # Should not raise an exception

        # Verify no delete SQL was executed
        execute_calls = mock_cursor.execute.call_args_list
        delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
        assert len(delete_calls) == 0
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_delete_by_ids_table_not_exists(self, mock_pool_class):
        """Test deleting when table doesn't exist."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]

        # Simulate table doesn't exist error on delete
        # (MySQL errno 1146); non-DELETE statements still succeed.

        def execute_side_effect(*args, **kwargs):
            if "DELETE" in args[0]:
                raise MySQLError(errno=1146, msg="Table doesn't exist")

        mock_cursor.execute.side_effect = execute_side_effect

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        # Should not raise an exception
        vector_store.delete_by_ids(["doc1"])
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_delete_by_metadata_field(self, mock_pool_class):
        """Test deleting documents by metadata field."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        vector_store.delete_by_metadata_field("document_id", "dataset1")

        # Check that the correct SQL was executed: the field is matched via a
        # JSON path ("$.document_id") into the meta column.
        execute_calls = mock_cursor.execute.call_args_list
        delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
        assert len(delete_calls) == 1
        delete_call = delete_calls[0]
        assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
        assert delete_call[0][1] == ("$.document_id", "dataset1")
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_vector_cosine(self, mock_pool_class):
        """Test vector search with cosine distance."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        # One fake result row; "distance" comes back from the SQL query.
        mock_cursor.__iter__ = lambda self: iter(
            [{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 0.1}]
        )

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        query_vector = [0.1, 0.2, 0.3, 0.4]
        docs = vector_store.search_by_vector(query_vector, top_k=5)

        assert len(docs) == 1
        assert docs[0].page_content == "Test document 1"
        # Cosine score is 1 - distance.
        assert abs(docs[0].metadata["score"] - 0.9) < 0.1  # 1 - 0.1 = 0.9
        assert docs[0].metadata["distance"] == 0.1
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_vector_euclidean(self, mock_pool_class):
        """Test vector search with euclidean distance."""
        # A dedicated config is built here because the shared self.config
        # does not use the euclidean distance function.
        config = AlibabaCloudMySQLVectorConfig(
            host="localhost",
            port=3306,
            user="test_user",
            password="test_password",
            database="test_db",
            max_connection=5,
            distance_function="euclidean",
        )

        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Constructor probes: server version first, then vector capability.
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        mock_cursor.__iter__ = lambda self: iter(
            [{"meta": json.dumps({"doc_id": "doc1", "source": "test"}), "text": "Test document 1", "distance": 2.0}]
        )

        vector_store = AlibabaCloudMySQLVector(self.collection_name, config)
        query_vector = [0.1, 0.2, 0.3, 0.4]
        docs = vector_store.search_by_vector(query_vector, top_k=5)

        assert len(docs) == 1
        # Euclidean similarity score is expected to be 1 / (1 + distance).
        assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01  # 1/(1+2) = 1/3
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_vector_with_filter(self, mock_pool_class):
        """Test vector search with document ID filter."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Constructor probes: server version first, then vector capability.
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        # No rows are returned; this test only inspects the generated SQL.
        mock_cursor.__iter__ = lambda self: iter([])

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        query_vector = [0.1, 0.2, 0.3, 0.4]
        docs = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["dataset1"])

        # Verify the SQL contains the WHERE clause for filtering
        execute_calls = mock_cursor.execute.call_args_list
        search_calls = [call for call in execute_calls if "VEC_DISTANCE" in str(call)]
        assert len(search_calls) > 0
        search_call = search_calls[0]
        assert "WHERE JSON_UNQUOTE" in search_call[0][0]
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_vector_with_score_threshold(self, mock_pool_class):
        """Test vector search with score threshold."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Constructor probes: server version first, then vector capability.
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        # Two rows on either side of the 0.5 score threshold.
        mock_cursor.__iter__ = lambda self: iter(
            [
                {
                    "meta": json.dumps({"doc_id": "doc1", "source": "test"}),
                    "text": "High similarity document",
                    "distance": 0.1,  # High similarity (score = 0.9)
                },
                {
                    "meta": json.dumps({"doc_id": "doc2", "source": "test"}),
                    "text": "Low similarity document",
                    "distance": 0.8,  # Low similarity (score = 0.2)
                },
            ]
        )

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        query_vector = [0.1, 0.2, 0.3, 0.4]
        docs = vector_store.search_by_vector(query_vector, top_k=5, score_threshold=0.5)

        # Only the high similarity document should be returned
        assert len(docs) == 1
        assert docs[0].page_content == "High similarity document"
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_vector_invalid_top_k(self, mock_pool_class):
|
||||
"""Test vector search with invalid top_k."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_vector(query_vector, top_k=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_vector(query_vector, top_k="invalid")
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_full_text(self, mock_pool_class):
        """Test full-text search."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Constructor probes: server version first, then vector capability.
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        mock_cursor.__iter__ = lambda self: iter(
            [
                {
                    # NOTE(review): meta is a plain dict here, while the vector-search
                    # tests in this module serialize it with json.dumps — confirm this
                    # matches what search_by_full_text actually reads from the cursor.
                    "meta": {"doc_id": "doc1", "source": "test"},
                    "text": "This document contains machine learning content",
                    "score": 1.5,
                }
            ]
        )

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        docs = vector_store.search_by_full_text("machine learning", top_k=5)

        assert len(docs) == 1
        assert docs[0].page_content == "This document contains machine learning content"
        assert docs[0].metadata["score"] == 1.5
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_search_by_full_text_with_filter(self, mock_pool_class):
        """Test full-text search with document ID filter."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Constructor probes: server version first, then vector capability.
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
        # No rows are returned; this test only inspects the generated SQL.
        mock_cursor.__iter__ = lambda self: iter([])

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        docs = vector_store.search_by_full_text("machine learning", top_k=5, document_ids_filter=["dataset1"])

        # Verify the SQL contains the AND clause for filtering
        execute_calls = mock_cursor.execute.call_args_list
        search_calls = [call for call in execute_calls if "MATCH" in str(call)]
        assert len(search_calls) > 0
        search_call = search_calls[0]
        assert "AND JSON_UNQUOTE" in search_call[0][0]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
|
||||
"""Test full-text search with invalid top_k."""
|
||||
# Mock the connection pool
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_class.return_value = mock_pool
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_pool.get_connection.return_value = mock_conn
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_full_text("test", top_k=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_full_text("test", top_k="invalid")
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_delete_collection(self, mock_pool_class):
        """Test deleting the entire collection."""
        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Constructor probes: server version first, then vector capability.
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]

        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        vector_store.delete()

        # Check that DROP TABLE SQL was executed
        execute_calls = mock_cursor.execute.call_args_list
        drop_calls = [call for call in execute_calls if "DROP TABLE" in str(call)]
        assert len(drop_calls) == 1
        drop_call = drop_calls[0]
        # The implementation is expected to lowercase the table name.
        assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
|
||||
|
||||
    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    def test_unsupported_distance_function(self, mock_pool_class):
        """Test that Pydantic validation rejects unsupported distance functions."""
        # Test that creating config with unsupported distance function raises ValidationError
        # (pydantic's ValidationError subclasses ValueError, so pytest.raises(ValueError) catches it).
        with pytest.raises(ValueError) as context:
            AlibabaCloudMySQLVectorConfig(
                host="localhost",
                port=3306,
                user="test_user",
                password="test_password",
                database="test_db",
                max_connection=5,
                distance_function="manhattan",  # Unsupported - not in Literal["cosine", "euclidean"]
            )

        # The error should be related to validation
        assert "Input should be 'cosine' or 'euclidean'" in str(context.value) or "manhattan" in str(context.value)
|
||||
|
||||
    def _setup_mocks(self, mock_redis, mock_pool_class):
        """Helper method to setup common mocks."""
        # NOTE(review): no test visible in this chunk calls this helper —
        # confirm it is still needed or wire the tests above through it.
        # Mock Redis operations
        mock_redis.lock.return_value.__enter__ = MagicMock()
        mock_redis.lock.return_value.__exit__ = MagicMock()
        mock_redis.get.return_value = None
        mock_redis.set.return_value = None

        # Mock the connection pool
        mock_pool = MagicMock()
        mock_pool_class.return_value = mock_pool

        # Mock connection and cursor
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_pool.get_connection.return_value = mock_conn
        mock_conn.cursor.return_value = mock_cursor
        # Constructor probes: server version first, then vector capability.
        mock_cursor.fetchone.side_effect = [{"VERSION()": "8.0.36"}, {"vector_support": True}]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "invalid_config_override",
    [
        {"host": ""},  # Test empty host
        {"port": 0},  # Test invalid port
        {"max_connection": 0},  # Test invalid max_connection
    ],
)
def test_config_validation_parametrized(invalid_config_override):
    """Test configuration validation for various invalid inputs using parametrize."""
    base_config = {
        "host": "localhost",
        "port": 3306,
        "user": "test",
        "password": "test",
        "database": "test",
        "max_connection": 5,
    }
    # Overlay the single invalid field onto an otherwise-valid config.
    invalid_config = {**base_config, **invalid_config_override}

    with pytest.raises(ValueError):
        AlibabaCloudMySQLVectorConfig(**invalid_config)
|
||||
|
||||
|
||||
# Allow running this module directly with the stdlib unittest runner.
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -0,0 +1,18 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
|
||||
|
||||
|
||||
def test_default_value():
    """Each required Milvus field must be validated; database defaults to 'default'."""
    valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"}

    # Dropping any single required key must surface a helpful message.
    for missing_key in valid_config:
        partial = {k: v for k, v in valid_config.items() if k != missing_key}
        with pytest.raises(ValidationError) as e:
            MilvusConfig.model_validate(partial)
        assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{missing_key.upper()} is required"

    # A fully specified config falls back to the default database name.
    config = MilvusConfig.model_validate(valid_config)
    assert config.database == "default"
|
||||
@@ -0,0 +1,27 @@
|
||||
import os
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
|
||||
|
||||
|
||||
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
    """Crawl mode should return the job id issued by the (mocked) Firecrawl API."""
    url = "https://firecrawl.dev"
    # Fall back to a placeholder key so the test runs without credentials.
    api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
    base_url = "https://api.firecrawl.dev"
    firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url)
    params = {
        "includePaths": [],
        "excludePaths": [],
        "maxDepth": 1,
        "limit": 1,
    }
    mocked_firecrawl = {
        "id": "test",
    }
    # The HTTP layer is mocked, so no real network request is made.
    mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl))
    job_id = firecrawl_app.crawl_url(url, params)

    assert job_id is not None
    assert isinstance(job_id, str)
|
||||
@@ -0,0 +1,22 @@
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
|
||||
|
||||
def test_markdown_to_tups():
    """markdown_to_tups should split markdown into (header, text) tuples.

    Text appearing before the first header is keyed by ``None``; each header
    then starts a new tuple with its section body.
    """
    markdown = """
this is some text without header

# title 1
this is balabala text

## title 2
this is more specific text.
"""
    extractor = MarkdownExtractor(file_path="dummy_path")
    updated_output = extractor.markdown_to_tups(markdown)
    assert len(updated_output) == 3

    # Preamble before any header carries None as its key.
    key, header_value = updated_output[0]
    assert key is None  # identity comparison per PEP 8, not == None
    assert header_value.strip() == "this is some text without header"

    title_1, value = updated_output[1]
    assert title_1.strip() == "title 1"
    assert value.strip() == "this is balabala text"
|
||||
@@ -0,0 +1,93 @@
|
||||
from unittest import mock
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor import notion_extractor
|
||||
|
||||
# Shared identifiers used across the mocked Notion payloads below.
user_id = "user1"
database_id = "database1"
page_id = "page1"


# Module-level extractor reused by all tests; the credentials are dummies
# because every HTTP call is mocked in the individual tests.
extractor = notion_extractor.NotionExtractor(
    notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x"
)
|
||||
|
||||
|
||||
def _generate_page(page_title: str):
    """Return a minimal Notion page payload titled *page_title*."""
    title_entry = {"type": "text", "text": {"content": page_title}, "plain_text": page_title}
    return {
        "object": "page",
        "id": page_id,
        "properties": {"Page": {"type": "title", "title": [title_entry]}},
    }
|
||||
|
||||
|
||||
def _generate_block(block_id: str, block_type: str, block_text: str):
    """Return a minimal childless Notion block of *block_type* holding *block_text*."""
    rich_text_entry = {
        "type": "text",
        "text": {"content": block_text},
        "plain_text": block_text,
    }
    return {
        "object": "block",
        "id": block_id,
        "parent": {"type": "page_id", "page_id": page_id},
        "type": block_type,
        "has_children": False,
        # Payload key mirrors the block type, as in real Notion responses.
        block_type: {"rich_text": [rich_text_entry]},
    }
|
||||
|
||||
|
||||
def _mock_response(data):
|
||||
response = mock.Mock()
|
||||
response.status_code = 200
|
||||
response.json.return_value = data
|
||||
return response
|
||||
|
||||
|
||||
def _remove_multiple_new_lines(text):
|
||||
while "\n\n" in text:
|
||||
text = text.replace("\n\n", "\n")
|
||||
return text.strip()
|
||||
|
||||
|
||||
def test_notion_page(mocker: MockerFixture):
    """Page extraction should render heading blocks as markdown headers."""
    texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
    mocked_notion_page = {
        "object": "list",
        "results": [
            _generate_block("b1", "heading_1", texts[0]),
            _generate_block("b2", "heading_2", texts[1]),
            _generate_block("b3", "paragraph", texts[2]),
            _generate_block("b4", "heading_3", texts[3]),
        ],
        # next_cursor=None means a single page of results (no pagination).
        "next_cursor": None,
    }
    mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page))

    page_docs = extractor._load_data_as_documents(page_id, "page")
    assert len(page_docs) == 1
    content = _remove_multiple_new_lines(page_docs[0].page_content)
    # heading_N maps to N '#' characters; paragraphs pass through unchanged.
    assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
|
||||
|
||||
|
||||
def test_notion_database(mocker: MockerFixture):
    """Database extraction should emit one document listing every page title."""
    page_title_list = ["page1", "page2", "page3"]
    mocked_notion_database = {
        "object": "list",
        "results": [_generate_page(i) for i in page_title_list],
        "next_cursor": None,
    }
    mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database))
    database_docs = extractor._load_data_as_documents(database_id, "database")
    assert len(database_docs) == 1
    content = _remove_multiple_new_lines(database_docs[0].page_content)
    # Each database row is rendered as "Page:<title>" on its own line.
    assert content == "\n".join([f"Page:{i}" for i in page_title_list])
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Primarily used for testing merged cell scenarios"""
|
||||
|
||||
from docx import Document
|
||||
|
||||
from core.rag.extractor.word_extractor import WordExtractor
|
||||
|
||||
|
||||
def _generate_table_with_merged_cells():
    """Build a 3x3 docx table containing one horizontal and one vertical merge.

    Returns the table together with the expected per-row parse results.
    """
    doc = Document()

    """
    The table looks like this:
    +-----+-----+-----+
    | 1-1 & 1-2 | 1-3 |
    +-----+-----+-----+
    | 2-1 | 2-2 | 2-3 |
    | & |-----+-----+
    | 3-1 | 3-2 | 3-3 |
    +-----+-----+-----+
    """
    table = doc.add_table(rows=3, cols=3)
    table.style = "Table Grid"

    # Fill every cell with its "row-col" coordinates before merging.
    for i in range(3):
        for j in range(3):
            cell = table.cell(i, j)
            cell.text = f"{i + 1}-{j + 1}"

    # Merge cells
    # Horizontal merge: first two cells of row 0.
    cell_0_0 = table.cell(0, 0)
    cell_0_1 = table.cell(0, 1)
    merged_cell_1 = cell_0_0.merge(cell_0_1)
    merged_cell_1.text = "1-1 & 1-2"

    # Vertical merge: last two cells of column 0.
    cell_1_0 = table.cell(1, 0)
    cell_2_0 = table.cell(2, 0)
    merged_cell_2 = cell_1_0.merge(cell_2_0)
    merged_cell_2.text = "2-1 & 3-1"

    ground_truth = [["1-1 & 1-2", "", "1-3"], ["2-1 & 3-1", "2-2", "2-3"], ["2-1 & 3-1", "3-2", "3-3"]]

    return doc.tables[0], ground_truth
|
||||
|
||||
|
||||
def test_parse_row():
    """_parse_row should produce the expected text for rows containing merged cells."""
    table, gt = _generate_table_with_merged_cells()
    # Bypass __init__ (which expects a real file); _parse_row needs no instance state here.
    extractor = object.__new__(WordExtractor)
    for idx, row in enumerate(table.rows):
        assert extractor._parse_row(row, {}, 3) == gt[idx]
|
||||
301
dify/api/tests/unit_tests/core/rag/pipeline/test_queue.py
Normal file
301
dify/api/tests/unit_tests/core/rag/pipeline/test_queue.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Unit tests for TenantIsolatedTaskQueue.
|
||||
|
||||
These tests verify the Redis-based task queue functionality for tenant-specific
|
||||
task management with proper serialization and deserialization.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
|
||||
|
||||
|
||||
class TestTaskWrapper:
    """Test cases for TaskWrapper serialization/deserialization."""

    def test_serialize_simple_data(self):
        """Test serialization of simple data types."""
        data = {"key": "value", "number": 42, "list": [1, 2, 3]}
        wrapper = TaskWrapper(data=data)

        serialized = wrapper.serialize()
        assert isinstance(serialized, str)

        # Verify it's valid JSON
        # The payload is nested under a top-level "data" key.
        parsed = json.loads(serialized)
        assert parsed["data"] == data

    def test_serialize_complex_data(self):
        """Test serialization of complex nested data."""
        data = {
            "nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5]}},
            "unicode": "测试中文",
            "special_chars": "!@#$%^&*()",
        }
        wrapper = TaskWrapper(data=data)

        serialized = wrapper.serialize()
        parsed = json.loads(serialized)
        assert parsed["data"] == data

    def test_deserialize_valid_data(self):
        """Test deserialization of valid JSON data."""
        original_data = {"key": "value", "number": 42}
        # Serialize using TaskWrapper to get the correct format
        wrapper = TaskWrapper(data=original_data)
        serialized = wrapper.serialize()

        # Round-trip: deserialize must restore the exact original payload.
        wrapper = TaskWrapper.deserialize(serialized)
        assert wrapper.data == original_data

    def test_deserialize_invalid_json(self):
        """Test deserialization handles invalid JSON gracefully."""
        invalid_json = "{invalid json}"

        # Pydantic will raise ValidationError for invalid JSON
        with pytest.raises(ValidationError):
            TaskWrapper.deserialize(invalid_json)

    def test_serialize_ensure_ascii_false(self):
        """Test that serialization preserves Unicode characters."""
        data = {"chinese": "中文测试", "emoji": "🚀"}
        wrapper = TaskWrapper(data=data)

        # Non-ASCII must appear literally, not as \uXXXX escapes.
        serialized = wrapper.serialize()
        assert "中文测试" in serialized
        assert "🚀" in serialized
|
||||
|
||||
|
||||
class TestTenantIsolatedTaskQueue:
    """Test cases for TenantIsolatedTaskQueue functionality."""

    @pytest.fixture
    def mock_redis_client(self):
        """Mock Redis client for testing."""
        mock_redis = MagicMock()
        return mock_redis

    @pytest.fixture
    def sample_queue(self, mock_redis_client):
        """Create a sample TenantIsolatedTaskQueue instance."""
        # NOTE(review): the mock_redis_client argument is not used here —
        # the queue talks to the module-level redis_client patched per test.
        return TenantIsolatedTaskQueue("tenant-123", "test-key")

    def test_initialization(self, sample_queue):
        """Test queue initialization with correct key generation."""
        assert sample_queue._tenant_id == "tenant-123"
        assert sample_queue._unique_key == "test-key"
        assert sample_queue._queue == "tenant_self_test-key_task_queue:tenant-123"
        assert sample_queue._task_key == "tenant_test-key_task:tenant-123"

    @patch("core.rag.pipeline.queue.redis_client")
    def test_get_task_key_exists(self, mock_redis, sample_queue):
        """Test getting task key when it exists."""
        mock_redis.get.return_value = "1"

        result = sample_queue.get_task_key()

        assert result == "1"
        mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")

    @patch("core.rag.pipeline.queue.redis_client")
    def test_get_task_key_not_exists(self, mock_redis, sample_queue):
        """Test getting task key when it doesn't exist."""
        mock_redis.get.return_value = None

        result = sample_queue.get_task_key()

        assert result is None
        mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")

    @patch("core.rag.pipeline.queue.redis_client")
    def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue):
        """Test setting task waiting flag with default TTL."""
        sample_queue.set_task_waiting_time()

        mock_redis.setex.assert_called_once_with(
            "tenant_test-key_task:tenant-123",
            3600,  # DEFAULT_TASK_TTL
            1,
        )

    @patch("core.rag.pipeline.queue.redis_client")
    def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue):
        """Test setting task waiting flag with custom TTL."""
        custom_ttl = 1800
        sample_queue.set_task_waiting_time(custom_ttl)

        mock_redis.setex.assert_called_once_with("tenant_test-key_task:tenant-123", custom_ttl, 1)

    @patch("core.rag.pipeline.queue.redis_client")
    def test_delete_task_key(self, mock_redis, sample_queue):
        """Test deleting task key."""
        sample_queue.delete_task_key()

        mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123")

    @patch("core.rag.pipeline.queue.redis_client")
    def test_push_tasks_string_list(self, mock_redis, sample_queue):
        """Test pushing string tasks directly."""
        tasks = ["task1", "task2", "task3"]

        sample_queue.push_tasks(tasks)

        # Strings are pushed as-is, in order, in a single lpush call.
        mock_redis.lpush.assert_called_once_with(
            "tenant_self_test-key_task_queue:tenant-123", "task1", "task2", "task3"
        )

    @patch("core.rag.pipeline.queue.redis_client")
    def test_push_tasks_mixed_types(self, mock_redis, sample_queue):
        """Test pushing mixed string and object tasks."""
        tasks = ["string_task", {"object_task": "data", "id": 123}, "another_string"]

        sample_queue.push_tasks(tasks)

        # Verify lpush was called
        mock_redis.lpush.assert_called_once()
        call_args = mock_redis.lpush.call_args

        # Check queue name
        assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123"

        # Check serialized tasks
        serialized_tasks = call_args[0][1:]
        assert len(serialized_tasks) == 3
        assert serialized_tasks[0] == "string_task"
        assert serialized_tasks[2] == "another_string"

        # Check object task is serialized as TaskWrapper JSON (without prefix)
        # It should be a valid JSON string that can be deserialized by TaskWrapper
        wrapper = TaskWrapper.deserialize(serialized_tasks[1])
        assert wrapper.data == {"object_task": "data", "id": 123}

    @patch("core.rag.pipeline.queue.redis_client")
    def test_push_tasks_empty_list(self, mock_redis, sample_queue):
        """Test pushing empty task list."""
        sample_queue.push_tasks([])

        # An empty push must not touch Redis at all.
        mock_redis.lpush.assert_not_called()

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_default_count(self, mock_redis, sample_queue):
        """Test pulling tasks with default count (1)."""
        mock_redis.rpop.side_effect = ["task1", None]

        result = sample_queue.pull_tasks()

        assert result == ["task1"]
        assert mock_redis.rpop.call_count == 1

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_custom_count(self, mock_redis, sample_queue):
        """Test pulling tasks with custom count."""
        # First test: pull 3 tasks
        mock_redis.rpop.side_effect = ["task1", "task2", "task3", None]

        result = sample_queue.pull_tasks(3)

        assert result == ["task1", "task2", "task3"]
        assert mock_redis.rpop.call_count == 3

        # Reset mock for second test
        mock_redis.reset_mock()
        mock_redis.rpop.side_effect = ["task1", "task2", None]

        # Queue drains early: rpop is still attempted 3 times but only
        # two tasks come back.
        result = sample_queue.pull_tasks(3)

        assert result == ["task1", "task2"]
        assert mock_redis.rpop.call_count == 3

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_zero_count(self, mock_redis, sample_queue):
        """Test pulling tasks with zero count returns empty list."""
        result = sample_queue.pull_tasks(0)

        assert result == []
        mock_redis.rpop.assert_not_called()

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_negative_count(self, mock_redis, sample_queue):
        """Test pulling tasks with negative count returns empty list."""
        result = sample_queue.pull_tasks(-1)

        assert result == []
        mock_redis.rpop.assert_not_called()

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue):
        """Test pulling tasks that include wrapped objects."""
        # Create a wrapped task
        task_data = {"task_id": 123, "data": "test"}
        wrapper = TaskWrapper(data=task_data)
        wrapped_task = wrapper.serialize()

        mock_redis.rpop.side_effect = [
            "string_task",
            wrapped_task.encode("utf-8"),  # Simulate bytes from Redis
            None,
        ]

        result = sample_queue.pull_tasks(2)

        # Wrapped entries are unwrapped back to their original object form.
        assert len(result) == 2
        assert result[0] == "string_task"
        assert result[1] == {"task_id": 123, "data": "test"}

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue):
        """Test pulling tasks with invalid JSON falls back to string."""
        # Invalid JSON string that cannot be deserialized
        invalid_json = "invalid json data"
        mock_redis.rpop.side_effect = [invalid_json, None]

        result = sample_queue.pull_tasks(1)

        assert result == [invalid_json]

    @patch("core.rag.pipeline.queue.redis_client")
    def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue):
        """Test pulling tasks handles bytes from Redis correctly."""
        mock_redis.rpop.side_effect = [
            b"task1",  # bytes
            "task2",  # string
            None,
        ]

        result = sample_queue.pull_tasks(2)

        assert result == ["task1", "task2"]

    @patch("core.rag.pipeline.queue.redis_client")
    def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue):
        """Test complex object serialization and deserialization roundtrip."""
        complex_task = {
            "id": uuid4().hex,
            "data": {"nested": {"deep": [1, 2, 3], "unicode": "测试中文", "special": "!@#$%^&*()"}},
            "metadata": {"created_at": "2024-01-01T00:00:00Z", "tags": ["tag1", "tag2", "tag3"]},
        }

        # Push the complex task
        sample_queue.push_tasks([complex_task])

        # Verify it was serialized as TaskWrapper JSON
        call_args = mock_redis.lpush.call_args
        wrapped_task = call_args[0][1]
        # Verify it's a valid TaskWrapper JSON (starts with {"data":)
        assert wrapped_task.startswith('{"data":')

        # Verify it can be deserialized
        wrapper = TaskWrapper.deserialize(wrapped_task)
        assert wrapper.data == complex_task

        # Simulate pulling it back
        mock_redis.rpop.return_value = wrapped_task
        result = sample_queue.pull_tasks(1)

        assert len(result) == 1
        assert result[0] == complex_task
|
||||
1
dify/api/tests/unit_tests/core/repositories/__init__.py
Normal file
1
dify/api/tests/unit_tests/core/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Unit tests for core repositories module
|
||||
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Unit tests for CeleryWorkflowExecutionRepository.
|
||||
|
||||
These tests verify the Celery-based asynchronous storage functionality
|
||||
for workflow execution data.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session_factory():
    """Mock SQLAlchemy session factory."""
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # Create a real sessionmaker with in-memory SQLite for testing
    # (no schema is created here — the repository only needs a factory object).
    engine = create_engine("sqlite:///:memory:")
    return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_account():
    """Provide a Mock standing in for an Account, with id and tenant populated."""
    stub = Mock(spec=Account)
    stub.id = str(uuid4())
    stub.current_tenant_id = str(uuid4())
    return stub
|
||||
|
||||
|
||||
@pytest.fixture
def mock_end_user():
    """Provide a Mock standing in for an EndUser, with id and tenant populated."""
    stub = Mock(spec=EndUser)
    stub.id = str(uuid4())
    stub.tenant_id = str(uuid4())
    return stub
|
||||
|
||||
|
||||
@pytest.fixture
def sample_workflow_execution():
    """Build a minimal WorkflowExecution suitable for exercising save()."""
    started = naive_utc_now()
    return WorkflowExecution.new(
        id_=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"input1": "value1"},
        started_at=started,
    )
|
||||
|
||||
|
||||
class TestCeleryWorkflowExecutionRepository:
    """Behavioural tests for the Celery-backed workflow execution repository."""

    @staticmethod
    def _build_repo(session_factory, user, app_id="test-app", triggered_from=WorkflowRunTriggeredFrom.APP_RUN):
        """Construct a repository; keyword defaults cover the common test setup."""
        return CeleryWorkflowExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=app_id,
            triggered_from=triggered_from,
        )

    @staticmethod
    def _new_execution(inputs):
        """Create a fresh WorkflowExecution carrying the given inputs mapping."""
        return WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs=inputs,
            started_at=naive_utc_now(),
        )

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Initialization records tenant, app, trigger, and creator context."""
        repository = self._build_repo(
            mock_session_factory,
            mock_account,
            app_id="test-app-id",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        assert repository._tenant_id == mock_account.current_tenant_id
        assert repository._app_id == "test-app-id"
        assert repository._triggered_from == WorkflowRunTriggeredFrom.APP_RUN
        assert repository._creator_user_id == mock_account.id
        assert repository._creator_user_role is not None

    def test_init_basic_functionality(self, mock_session_factory, mock_account):
        """Construction with the DEBUGGING trigger stores that context."""
        repository = self._build_repo(
            mock_session_factory, mock_account, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
        )

        assert repository._tenant_id == mock_account.current_tenant_id
        assert repository._app_id == "test-app"
        assert repository._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """An EndUser's tenant_id becomes the repository tenant."""
        repository = self._build_repo(mock_session_factory, mock_end_user)

        assert repository._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Construction fails fast when the user carries no tenant."""
        tenantless_user = Mock(spec=Account)
        tenantless_user.current_tenant_id = None
        tenantless_user.id = str(uuid4())

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            self._build_repo(mock_session_factory, tenantless_user)

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
        """save() enqueues one Celery task carrying the full execution context."""
        repository = self._build_repo(mock_session_factory, mock_account)

        repository.save(sample_workflow_execution)

        mock_task.delay.assert_called_once()
        task_kwargs = mock_task.delay.call_args[1]

        assert task_kwargs["execution_data"] == sample_workflow_execution.model_dump()
        assert task_kwargs["tenant_id"] == mock_account.current_tenant_id
        assert task_kwargs["app_id"] == "test-app"
        assert task_kwargs["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN
        assert task_kwargs["creator_user_id"] == mock_account.id

        # Fire-and-forget: the repository keeps no bookkeeping of queued saves.
        assert not hasattr(repository, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """A broker failure propagates out of save() unchanged."""
        mock_task.delay.side_effect = Exception("Celery is down")
        repository = self._build_repo(mock_session_factory, mock_account)

        with pytest.raises(Exception, match="Celery is down"):
            repository.save(sample_workflow_execution)

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_operation_fire_and_forget(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """save() neither blocks nor retains per-call state."""
        repository = self._build_repo(mock_session_factory, mock_account)

        repository.save(sample_workflow_execution)

        assert not hasattr(repository, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account):
        """Back-to-back saves succeed and accumulate no state."""
        repository = self._build_repo(mock_session_factory, mock_account)

        repository.save(self._new_execution({"input1": "value1"}))
        repository.save(self._new_execution({"input2": "value2"}))

        assert not hasattr(repository, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user):
        """Saving under an EndUser threads that user's context to the task."""
        repository = self._build_repo(mock_session_factory, mock_end_user)

        repository.save(self._new_execution({"input1": "value1"}))

        mock_task.delay.assert_called_once()
        task_kwargs = mock_task.delay.call_args[1]
        assert task_kwargs["tenant_id"] == mock_end_user.tenant_id
        assert task_kwargs["creator_user_id"] == mock_end_user.id
|
||||
@@ -0,0 +1,349 @@
|
||||
"""
|
||||
Unit tests for CeleryWorkflowNodeExecutionRepository.
|
||||
|
||||
These tests verify the Celery-based asynchronous storage functionality
|
||||
for workflow node execution data.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session_factory():
    """Provide a real sessionmaker bound to a throwaway in-memory SQLite engine."""
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # A genuine sessionmaker is simpler (and safer) than mocking SQLAlchemy's API surface.
    return sessionmaker(bind=create_engine("sqlite:///:memory:"))
|
||||
|
||||
|
||||
@pytest.fixture
def mock_account():
    """Provide a Mock standing in for an Account, with id and tenant populated."""
    stub = Mock(spec=Account)
    stub.id = str(uuid4())
    stub.current_tenant_id = str(uuid4())
    return stub
|
||||
|
||||
|
||||
@pytest.fixture
def mock_end_user():
    """Provide a Mock standing in for an EndUser, with id and tenant populated."""
    stub = Mock(spec=EndUser)
    stub.id = str(uuid4())
    stub.tenant_id = str(uuid4())
    return stub
|
||||
|
||||
|
||||
@pytest.fixture
def sample_workflow_node_execution():
    """Build a RUNNING WorkflowNodeExecution suitable for save()/lookup tests."""
    created = naive_utc_now()
    return WorkflowNodeExecution(
        id=str(uuid4()),
        node_execution_id=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_execution_id=str(uuid4()),
        index=1,
        node_id="test_node",
        node_type=NodeType.START,
        title="Test Node",
        inputs={"input1": "value1"},
        status=WorkflowNodeExecutionStatus.RUNNING,
        created_at=created,
    )
|
||||
|
||||
|
||||
class TestCeleryWorkflowNodeExecutionRepository:
    """Behavioural tests for the Celery-backed workflow node execution repository."""

    @staticmethod
    def _build_repo(
        session_factory,
        user,
        app_id="test-app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    ):
        """Construct a repository; keyword defaults cover the common test setup."""
        return CeleryWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=app_id,
            triggered_from=triggered_from,
        )

    @staticmethod
    def _node_execution(workflow_execution_id, *, index, node_id, node_type, title, inputs):
        """Create a RUNNING node execution attached to the given workflow run."""
        return WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_execution_id,
            index=index,
            node_id=node_id,
            node_type=node_type,
            title=title,
            inputs=inputs,
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=naive_utc_now(),
        )

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Initialization records tenant, app, trigger, and creator context."""
        repository = self._build_repo(
            mock_session_factory,
            mock_account,
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        assert repository._tenant_id == mock_account.current_tenant_id
        assert repository._app_id == "test-app-id"
        assert repository._triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
        assert repository._creator_user_id == mock_account.id
        assert repository._creator_user_role is not None

    def test_init_with_cache_initialized(self, mock_session_factory, mock_account):
        """A fresh repository starts with empty cache structures."""
        repository = self._build_repo(
            mock_session_factory, mock_account, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP
        )

        assert repository._execution_cache == {}
        assert repository._workflow_execution_mapping == {}

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """An EndUser's tenant_id becomes the repository tenant."""
        repository = self._build_repo(mock_session_factory, mock_end_user)

        assert repository._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Construction fails fast when the user carries no tenant."""
        tenantless_user = Mock(spec=Account)
        tenantless_user.current_tenant_id = None
        tenantless_user.id = str(uuid4())

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            self._build_repo(mock_session_factory, tenantless_user)

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_caches_and_queues_celery_task(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """save() enqueues a Celery task and records the execution in the cache."""
        repository = self._build_repo(mock_session_factory, mock_account)

        repository.save(sample_workflow_node_execution)

        mock_task.delay.assert_called_once()
        task_kwargs = mock_task.delay.call_args[1]

        assert task_kwargs["execution_data"] == sample_workflow_node_execution.model_dump()
        assert task_kwargs["tenant_id"] == mock_account.current_tenant_id
        assert task_kwargs["app_id"] == "test-app"
        assert task_kwargs["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
        assert task_kwargs["creator_user_id"] == mock_account.id

        # The execution itself is cached for in-process reads...
        assert sample_workflow_node_execution.id in repository._execution_cache
        assert repository._execution_cache[sample_workflow_node_execution.id] == sample_workflow_node_execution

        # ...and indexed under its workflow run for get_by_workflow_run().
        assert sample_workflow_node_execution.workflow_execution_id in repository._workflow_execution_mapping
        assert (
            sample_workflow_node_execution.id
            in repository._workflow_execution_mapping[sample_workflow_node_execution.workflow_execution_id]
        )

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """A broker failure propagates out of save() unchanged."""
        mock_task.delay.side_effect = Exception("Celery is down")
        repository = self._build_repo(mock_session_factory, mock_account)

        with pytest.raises(Exception, match="Celery is down"):
            repository.save(sample_workflow_node_execution)

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_get_by_workflow_run_from_cache(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """get_by_workflow_run serves previously saved executions from the cache."""
        repository = self._build_repo(mock_session_factory, mock_account)
        repository.save(sample_workflow_node_execution)

        result = repository.get_by_workflow_run(
            sample_workflow_node_execution.workflow_execution_id,
            OrderConfig(order_by=["index"], order_direction="asc"),
        )

        assert len(result) == 1
        assert result[0].id == sample_workflow_node_execution.id
        assert result[0] is sample_workflow_node_execution

    def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account):
        """With nothing cached, the lookup yields an empty list."""
        repository = self._build_repo(mock_session_factory, mock_account)

        result = repository.get_by_workflow_run("workflow-run-id")

        assert len(result) == 0

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
        """Saving populates the cache; retrieval reads it back."""
        repository = self._build_repo(mock_session_factory, mock_account)

        repository.save(sample_workflow_node_execution)
        assert sample_workflow_node_execution.id in repository._execution_cache

        result = repository.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id)
        assert len(result) == 1
        assert result[0].id == sample_workflow_node_execution.id

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account):
        """Several executions of one run are all cached and retrievable."""
        repository = self._build_repo(mock_session_factory, mock_account)
        workflow_run_id = str(uuid4())

        repository.save(
            self._node_execution(
                workflow_run_id,
                index=1,
                node_id="node1",
                node_type=NodeType.START,
                title="Node 1",
                inputs={"input1": "value1"},
            )
        )
        repository.save(
            self._node_execution(
                workflow_run_id,
                index=2,
                node_id="node2",
                node_type=NodeType.LLM,
                title="Node 2",
                inputs={"input2": "value2"},
            )
        )

        assert len(repository._execution_cache) == 2
        assert len(repository._workflow_execution_mapping[workflow_run_id]) == 2

        result = repository.get_by_workflow_run(workflow_run_id)
        assert len(result) == 2

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account):
        """OrderConfig sorts cached executions ascending or descending by index."""
        repository = self._build_repo(mock_session_factory, mock_account)
        workflow_run_id = str(uuid4())

        # Saved out of order on purpose so the sort is observable.
        repository.save(
            self._node_execution(
                workflow_run_id, index=2, node_id="node2", node_type=NodeType.START, title="Node 2", inputs={}
            )
        )
        repository.save(
            self._node_execution(
                workflow_run_id, index=1, node_id="node1", node_type=NodeType.LLM, title="Node 1", inputs={}
            )
        )

        ascending = repository.get_by_workflow_run(
            workflow_run_id, OrderConfig(order_by=["index"], order_direction="asc")
        )
        assert len(ascending) == 2
        assert [execution.index for execution in ascending] == [1, 2]

        descending = repository.get_by_workflow_run(
            workflow_run_id, OrderConfig(order_by=["index"], order_direction="desc")
        )
        assert len(descending) == 2
        assert [execution.index for execution in descending] == [2, 1]
|
||||
244
dify/api/tests/unit_tests/core/repositories/test_factory.py
Normal file
244
dify/api/tests/unit_tests/core/repositories/test_factory.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Unit tests for the RepositoryFactory.
|
||||
|
||||
This module tests the factory pattern implementation for creating repository instances
|
||||
based on configuration, including error handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from libs.module_loading import import_string
|
||||
from models import Account, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
class TestRepositoryFactory:
    """Test cases for DifyCoreRepositoryFactory and the import_string helper."""

    @staticmethod
    def _stub_repository_class(spec):
        """Return a (class, instance) pair where calling the class yields the instance."""
        repository_class = MagicMock()
        repository_instance = MagicMock(spec=spec)
        repository_class.return_value = repository_instance
        return repository_class, repository_instance

    def test_import_string_success(self):
        """A dotted path to a real class resolves to that class object."""
        assert import_string("unittest.mock.MagicMock") is MagicMock

    def test_import_string_invalid_path(self):
        """A nonexistent module raises ImportError naming the module."""
        with pytest.raises(ImportError) as exc_info:
            import_string("invalid.module.path")
        assert "No module named" in str(exc_info.value)

    def test_import_string_invalid_class_name(self):
        """A missing attribute on a real module raises ImportError."""
        with pytest.raises(ImportError) as exc_info:
            import_string("unittest.mock.NonExistentClass")
        assert "does not define" in str(exc_info.value)

    def test_import_string_malformed_path(self):
        """A dotless string is rejected as not being a module path."""
        with pytest.raises(ImportError) as exc_info:
            import_string("invalidpath")
        assert "doesn't look like a module path" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_execution_repository_success(self, mock_config):
        """The factory imports the configured class and forwards every kwarg."""
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        session_factory = MagicMock(spec=sessionmaker)
        user = MagicMock(spec=Account)
        repository_class, repository_instance = self._stub_repository_class(WorkflowExecutionRepository)

        with patch("core.repositories.factory.import_string", return_value=repository_class):
            result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                session_factory=session_factory,
                user=user,
                app_id="test-app-id",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            )

        repository_class.assert_called_once_with(
            session_factory=session_factory,
            user=user,
            app_id="test-app-id",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        assert result is repository_instance

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_execution_repository_import_error(self, mock_config):
        """An unimportable configured class surfaces as RepositoryImportError."""
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"

        with pytest.raises(RepositoryImportError) as exc_info:
            DifyCoreRepositoryFactory.create_workflow_execution_repository(
                session_factory=MagicMock(spec=sessionmaker),
                user=MagicMock(spec=Account),
                app_id="test-app-id",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            )
        assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
        """A constructor failure is wrapped in RepositoryImportError."""
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        failing_class = MagicMock()
        failing_class.side_effect = Exception("Instantiation failed")

        with patch("core.repositories.factory.import_string", return_value=failing_class):
            with pytest.raises(RepositoryImportError) as exc_info:
                DifyCoreRepositoryFactory.create_workflow_execution_repository(
                    session_factory=MagicMock(spec=sessionmaker),
                    user=MagicMock(spec=Account),
                    app_id="test-app-id",
                    triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
                )
        assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_node_execution_repository_success(self, mock_config):
        """The factory imports the configured node-execution class and forwards kwargs."""
        mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        session_factory = MagicMock(spec=sessionmaker)
        user = MagicMock(spec=EndUser)
        repository_class, repository_instance = self._stub_repository_class(WorkflowNodeExecutionRepository)

        with patch("core.repositories.factory.import_string", return_value=repository_class):
            result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                session_factory=session_factory,
                user=user,
                app_id="test-app-id",
                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
            )

        repository_class.assert_called_once_with(
            session_factory=session_factory,
            user=user,
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        assert result is repository_instance

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_node_execution_repository_import_error(self, mock_config):
        """An unimportable configured class surfaces as RepositoryImportError."""
        mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"

        with pytest.raises(RepositoryImportError) as exc_info:
            DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                session_factory=MagicMock(spec=sessionmaker),
                user=MagicMock(spec=EndUser),
                app_id="test-app-id",
                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
            )
        assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)

    @patch("core.repositories.factory.dify_config")
    def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
        """A constructor failure is wrapped in RepositoryImportError."""
        mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        failing_class = MagicMock()
        failing_class.side_effect = Exception("Instantiation failed")

        with patch("core.repositories.factory.import_string", return_value=failing_class):
            with pytest.raises(RepositoryImportError) as exc_info:
                DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                    session_factory=MagicMock(spec=sessionmaker),
                    user=MagicMock(spec=EndUser),
                    app_id="test-app-id",
                    triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
                )
        assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)

    def test_repository_import_error_exception(self):
        """RepositoryImportError carries its message through str()."""
        error = RepositoryImportError("Custom error message")
        assert str(error) == "Custom error message"

    @patch("core.repositories.factory.dify_config")
    def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
        """The factory accepts an Engine anywhere a sessionmaker is accepted."""
        mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"

        engine = MagicMock(spec=Engine)
        user = MagicMock(spec=Account)
        repository_class, repository_instance = self._stub_repository_class(WorkflowExecutionRepository)

        with patch("core.repositories.factory.import_string", return_value=repository_class):
            result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                session_factory=engine,  # Engine in place of a sessionmaker
                user=user,
                app_id="test-app-id",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            )

        repository_class.assert_called_once_with(
            session_factory=engine,
            user=user,
            app_id="test-app-id",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
        assert result is repository_instance
|
||||
@@ -0,0 +1,210 @@
|
||||
"""Unit tests for workflow node execution conflict handling."""
|
||||
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
class TestWorkflowNodeExecutionConflictHandling:
    """Test cases for handling duplicate key conflicts in workflow node execution.

    The repository under test is expected to:
    - update an existing row instead of inserting when ``session.get`` finds one;
    - retry an insert with a fresh ID on a duplicate-key IntegrityError;
    - give up after a bounded number of retries;
    - re-raise non-duplicate IntegrityErrors immediately.
    """

    def setup_method(self):
        """Set up test fixtures."""
        # Create a mock user with tenant_id
        self.mock_user = Mock(spec=Account)
        self.mock_user.id = "test-user-id"
        self.mock_user.current_tenant_id = "test-tenant-id"

        # Create mock session factory; each test wires its return_value to a
        # MagicMock session configured for that scenario.
        self.mock_session_factory = Mock(spec=sessionmaker)

        # Create repository instance
        self.repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=self.mock_session_factory,
            user=self.mock_user,
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

    def test_save_with_duplicate_key_retries_with_new_uuid(self):
        """Test that save retries with a new UUID v7 when encountering duplicate key error."""
        # Create a mock session usable as a context manager
        mock_session = MagicMock()
        mock_session.__enter__ = Mock(return_value=mock_session)
        mock_session.__exit__ = Mock(return_value=None)
        self.mock_session_factory.return_value = mock_session

        # Mock session.get to return None (no existing record), forcing the
        # insert path rather than the update path.
        mock_session.get.return_value = None

        # Create IntegrityError for duplicate key with proper psycopg2.errors.UniqueViolation
        # (the repository presumably inspects ``orig`` to classify the error —
        # TODO confirm against the repository implementation).
        mock_unique_violation = Mock(spec=psycopg2.errors.UniqueViolation)
        duplicate_error = IntegrityError(
            "duplicate key value violates unique constraint",
            params=None,
            orig=mock_unique_violation,
        )

        # First call to session.add raises IntegrityError, second succeeds
        mock_session.add.side_effect = [duplicate_error, None]
        mock_session.commit.side_effect = [None, None]

        # Create test execution
        execution = WorkflowNodeExecution(
            id="original-id",
            workflow_id="test-workflow-id",
            workflow_execution_id="test-workflow-execution-id",
            node_execution_id="test-node-execution-id",
            node_id="test-node-id",
            node_type=NodeType.START,
            title="Test Node",
            index=1,
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=naive_utc_now(),
        )

        original_id = execution.id

        # Save should succeed after retry
        self.repository.save(execution)

        # Verify that session.add was called twice (initial attempt + retry)
        assert mock_session.add.call_count == 2

        # Verify that the ID was changed (new UUID v7 generated)
        assert execution.id != original_id

    def test_save_with_existing_record_updates_instead_of_insert(self):
        """Test that save updates existing record instead of inserting duplicate."""
        # Create a mock session usable as a context manager
        mock_session = MagicMock()
        mock_session.__enter__ = Mock(return_value=mock_session)
        mock_session.__exit__ = Mock(return_value=None)
        self.mock_session_factory.return_value = mock_session

        # Mock existing record: session.get finds a row, so the repository
        # should take the update path.
        mock_existing = MagicMock()
        mock_session.get.return_value = mock_existing
        mock_session.commit.return_value = None

        # Create test execution
        execution = WorkflowNodeExecution(
            id="existing-id",
            workflow_id="test-workflow-id",
            workflow_execution_id="test-workflow-execution-id",
            node_execution_id="test-node-execution-id",
            node_id="test-node-id",
            node_type=NodeType.START,
            title="Test Node",
            index=1,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            created_at=naive_utc_now(),
        )

        # Save should update existing record
        self.repository.save(execution)

        # Verify that session.add was not called (update path)
        mock_session.add.assert_not_called()

        # Verify that session.commit was called
        mock_session.commit.assert_called_once()

    def test_save_exceeds_max_retries_raises_error(self):
        """Test that save raises error after exceeding max retries."""
        # Create a mock session usable as a context manager
        mock_session = MagicMock()
        mock_session.__enter__ = Mock(return_value=mock_session)
        mock_session.__exit__ = Mock(return_value=None)
        self.mock_session_factory.return_value = mock_session

        # Mock session.get to return None (no existing record)
        mock_session.get.return_value = None

        # Create IntegrityError for duplicate key with proper psycopg2.errors.UniqueViolation
        mock_unique_violation = Mock(spec=psycopg2.errors.UniqueViolation)
        duplicate_error = IntegrityError(
            "duplicate key value violates unique constraint",
            params=None,
            orig=mock_unique_violation,
        )

        # All attempts fail with duplicate error (a bare exception as
        # side_effect re-raises on every call, unlike a list).
        mock_session.add.side_effect = duplicate_error

        # Create test execution
        execution = WorkflowNodeExecution(
            id="test-id",
            workflow_id="test-workflow-id",
            workflow_execution_id="test-workflow-execution-id",
            node_execution_id="test-node-execution-id",
            node_id="test-node-id",
            node_type=NodeType.START,
            title="Test Node",
            index=1,
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=naive_utc_now(),
        )

        # Save should raise IntegrityError after max retries
        with pytest.raises(IntegrityError):
            self.repository.save(execution)

        # Verify that session.add was called 3 times (max_retries)
        assert mock_session.add.call_count == 3

    def test_save_non_duplicate_integrity_error_raises_immediately(self):
        """Test that non-duplicate IntegrityErrors are raised immediately without retry."""
        # Create a mock session usable as a context manager
        mock_session = MagicMock()
        mock_session.__enter__ = Mock(return_value=mock_session)
        mock_session.__exit__ = Mock(return_value=None)
        self.mock_session_factory.return_value = mock_session

        # Mock session.get to return None (no existing record)
        mock_session.get.return_value = None

        # Create IntegrityError for non-duplicate constraint; ``orig=None``
        # means it cannot be classified as a UniqueViolation.
        other_error = IntegrityError(
            "null value in column violates not-null constraint",
            params=None,
            orig=None,
        )

        # First call raises non-duplicate error
        mock_session.add.side_effect = other_error

        # Create test execution
        execution = WorkflowNodeExecution(
            id="test-id",
            workflow_id="test-workflow-id",
            workflow_execution_id="test-workflow-execution-id",
            node_execution_id="test-node-execution-id",
            node_id="test-node-id",
            node_type=NodeType.START,
            title="Test Node",
            index=1,
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=naive_utc_now(),
        )

        # Save should raise error immediately
        with pytest.raises(IntegrityError):
            self.repository.save(execution)

        # Verify that session.add was called only once (no retry)
        assert mock_session.add.call_count == 1
|
||||
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Unit tests for WorkflowNodeExecution truncation functionality.
|
||||
|
||||
Tests the truncation and offloading logic for large inputs and outputs
|
||||
in the SQLAlchemyWorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from models import Account, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import ExecutionOffLoadType
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
||||
|
||||
|
||||
@dataclass
class TruncationTestCase:
    """Test case data for truncation scenarios."""

    # Human-readable identifier for the scenario (useful in parametrized IDs).
    name: str
    # Inputs to store on the execution; None means "no inputs at all".
    inputs: dict[str, Any] | None
    # Outputs to store on the execution; None means "no outputs at all".
    outputs: dict[str, Any] | None
    # Expected truncation/offload decision for the inputs payload.
    should_truncate_inputs: bool
    # Expected truncation/offload decision for the outputs payload.
    should_truncate_outputs: bool
    # One-line explanation asserted in failure messages.
    description: str
|
||||
|
||||
|
||||
# Size (in characters of serialized data) above which inputs/outputs are
# expected to be offloaded. The original code referenced a
# TRUNCATION_SIZE_THRESHOLD name that is never imported or defined in this
# module, so calling create_test_cases() raised NameError.
# NOTE(review): value assumed to be 10KB per the original inline comment —
# confirm it matches the threshold used by the repository under test.
TRUNCATION_SIZE_THRESHOLD = 10 * 1024


def create_test_cases(size_threshold: int = TRUNCATION_SIZE_THRESHOLD) -> list[TruncationTestCase]:
    """Create test cases for different truncation scenarios.

    Args:
        size_threshold: Size above which payloads are expected to be
            truncated/offloaded. Defaults to the module-level threshold,
            so existing callers are unaffected.

    Returns:
        Test cases covering small, large, and None inputs/outputs.
    """
    # Create large data that will definitely exceed the threshold
    large_data = {"data": "x" * (size_threshold + 1000)}
    small_data = {"data": "small"}

    return [
        TruncationTestCase(
            name="small_data_no_truncation",
            inputs=small_data,
            outputs=small_data,
            should_truncate_inputs=False,
            should_truncate_outputs=False,
            description="Small data should not be truncated",
        ),
        TruncationTestCase(
            name="large_inputs_truncation",
            inputs=large_data,
            outputs=small_data,
            should_truncate_inputs=True,
            should_truncate_outputs=False,
            description="Large inputs should be truncated",
        ),
        TruncationTestCase(
            name="large_outputs_truncation",
            inputs=small_data,
            outputs=large_data,
            should_truncate_inputs=False,
            should_truncate_outputs=True,
            description="Large outputs should be truncated",
        ),
        TruncationTestCase(
            name="large_both_truncation",
            inputs=large_data,
            outputs=large_data,
            should_truncate_inputs=True,
            should_truncate_outputs=True,
            description="Both large inputs and outputs should be truncated",
        ),
        TruncationTestCase(
            name="none_inputs_outputs",
            inputs=None,
            outputs=None,
            should_truncate_inputs=False,
            should_truncate_outputs=False,
            description="None inputs and outputs should not be truncated",
        ),
    ]
|
||||
|
||||
|
||||
def create_workflow_node_execution(
    execution_id: str = "test-execution-id",
    inputs: dict[str, Any] | None = None,
    outputs: dict[str, Any] | None = None,
) -> WorkflowNodeExecution:
    """Build a WorkflowNodeExecution fixture with fixed test identifiers.

    Only the id, inputs, and outputs vary per call; every other field is a
    stable constant so assertions can rely on it.
    """
    fixed_fields: dict[str, Any] = {
        "node_execution_id": "test-node-execution-id",
        "workflow_id": "test-workflow-id",
        "workflow_execution_id": "test-workflow-execution-id",
        "index": 1,
        "node_id": "test-node-id",
        "node_type": NodeType.LLM,
        "title": "Test Node",
        "status": WorkflowNodeExecutionStatus.SUCCEEDED,
    }
    return WorkflowNodeExecution(
        id=execution_id,
        inputs=inputs,
        outputs=outputs,
        created_at=datetime.now(UTC),
        **fixed_fields,
    )
|
||||
|
||||
|
||||
def mock_user() -> Account:
    """Create a mock Account user for testing.

    Note: MagicMock is already imported at module scope; the original
    function re-imported it locally, which was redundant.
    """
    account = MagicMock(spec=Account)
    account.id = "test-user-id"
    account.current_tenant_id = "test-tenant-id"
    return account
|
||||
|
||||
|
||||
class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation:
    """Test class for truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository."""

    def create_repository(self) -> SQLAlchemyWorkflowNodeExecutionRepository:
        """Create a repository instance for testing.

        Uses a mocked Engine as the session factory since these tests never
        touch a real database.
        """
        return SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=MagicMock(spec=Engine),
            user=mock_user(),
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

    def test_to_domain_model_without_offload_data(self):
        """Test _to_domain_model correctly handles models without offload data."""
        repo = self.create_repository()

        # Create a mock database model without offload data; JSON columns are
        # stored as serialized strings, as in the real table.
        db_model = WorkflowNodeExecutionModel()
        db_model.id = "test-id"
        db_model.node_execution_id = "node-exec-id"
        db_model.workflow_id = "workflow-id"
        db_model.workflow_run_id = "run-id"
        db_model.index = 1
        db_model.predecessor_node_id = None
        db_model.node_id = "node-id"
        db_model.node_type = NodeType.LLM
        db_model.title = "Test Node"
        db_model.inputs = json.dumps({"value": "inputs"})
        db_model.process_data = json.dumps({"value": "process_data"})
        db_model.outputs = json.dumps({"value": "outputs"})
        db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED
        db_model.error = None
        db_model.elapsed_time = 1.0
        db_model.execution_metadata = "{}"
        db_model.created_at = datetime.now(UTC)
        db_model.finished_at = None
        db_model.offload_data = []

        domain_model = repo._to_domain_model(db_model)

        # Check that no truncated data was set on the converted domain model
        assert domain_model.get_truncated_inputs() is None
        assert domain_model.get_truncated_outputs() is None
|
||||
|
||||
|
||||
class TestWorkflowNodeExecutionModelTruncatedProperties:
    """Test the truncated properties on WorkflowNodeExecutionModel.

    Each *_truncated property is expected to report True exactly when an
    offload entry of the matching ExecutionOffLoadType is present.
    """

    def test_inputs_truncated_with_offload_data(self):
        """Test inputs_truncated property when offload data exists."""
        model = WorkflowNodeExecutionModel()
        offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
        model.offload_data = [offload]

        # Only the INPUTS flag should flip; the other two remain False.
        assert model.inputs_truncated is True
        assert model.process_data_truncated is False
        assert model.outputs_truncated is False

    def test_outputs_truncated_with_offload_data(self):
        """Test outputs_truncated property when offload data exists."""
        model = WorkflowNodeExecutionModel()

        # Mock offload data with outputs file
        offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
        model.offload_data = [offload]

        assert model.inputs_truncated is False
        assert model.process_data_truncated is False
        assert model.outputs_truncated is True

    def test_process_data_truncated_with_offload_data(self):
        """Test process_data_truncated property when offload data exists."""
        model = WorkflowNodeExecutionModel()
        offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
        model.offload_data = [offload]
        assert model.process_data_truncated is True
        assert model.inputs_truncated is False
        assert model.outputs_truncated is False

    def test_truncated_properties_without_offload_data(self):
        """Test truncated properties when no offload data exists."""
        model = WorkflowNodeExecutionModel()
        model.offload_data = []

        assert model.inputs_truncated is False
        assert model.outputs_truncated is False
        assert model.process_data_truncated is False

    def test_truncated_properties_without_offload_attribute(self):
        """Test truncated properties when offload_data attribute doesn't exist."""
        model = WorkflowNodeExecutionModel()
        # Don't set offload_data attribute at all — the properties are
        # expected to tolerate an unloaded/absent relationship.
        assert model.inputs_truncated is False
        assert model.outputs_truncated is False
        assert model.process_data_truncated is False
|
||||
1
dify/api/tests/unit_tests/core/schemas/__init__.py
Normal file
1
dify/api/tests/unit_tests/core/schemas/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Core schemas unit tests
|
||||
769
dify/api/tests/unit_tests/core/schemas/test_resolver.py
Normal file
769
dify/api/tests/unit_tests/core/schemas/test_resolver.py
Normal file
@@ -0,0 +1,769 @@
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.schemas import resolve_dify_schema_refs
|
||||
from core.schemas.registry import SchemaRegistry
|
||||
from core.schemas.resolver import (
|
||||
MaxDepthExceededError,
|
||||
SchemaResolver,
|
||||
_has_dify_refs,
|
||||
_has_dify_refs_hybrid,
|
||||
_has_dify_refs_recursive,
|
||||
_is_dify_schema_ref,
|
||||
_remove_metadata_fields,
|
||||
parse_dify_schema_uri,
|
||||
)
|
||||
|
||||
|
||||
class TestSchemaResolver:
    """Test cases for schema reference resolution.

    The resolver replaces Dify-hosted ``$ref`` URIs with their registered
    schema bodies, stripping registry metadata fields; non-Dify refs are left
    untouched. A process-wide cache is cleared around each test for isolation.
    """

    def setup_method(self):
        """Setup method to initialize test resources"""
        self.registry = SchemaRegistry.default_registry()
        # Clear cache before each test
        SchemaResolver.clear_cache()

    def teardown_method(self):
        """Cleanup after each test"""
        SchemaResolver.clear_cache()

    def test_simple_ref_resolution(self):
        """Test resolving a simple $ref to a complete schema"""
        schema_with_ref = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}

        resolved = resolve_dify_schema_refs(schema_with_ref)

        # Should be resolved to the actual qa_structure schema
        assert resolved["type"] == "object"
        assert resolved["title"] == "Q&A Structure"
        assert "qa_chunks" in resolved["properties"]
        assert resolved["properties"]["qa_chunks"]["type"] == "array"

        # Metadata fields should be removed
        assert "$id" not in resolved
        assert "$schema" not in resolved
        assert "version" not in resolved

    def test_nested_object_with_refs(self):
        """Test resolving $refs within nested object structures"""
        nested_schema = {
            "type": "object",
            "properties": {
                "file_data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
                "metadata": {"type": "string", "description": "Additional metadata"},
            },
        }

        resolved = resolve_dify_schema_refs(nested_schema)

        # Original structure should be preserved
        assert resolved["type"] == "object"
        assert "metadata" in resolved["properties"]
        assert resolved["properties"]["metadata"]["type"] == "string"

        # $ref should be resolved
        file_schema = resolved["properties"]["file_data"]
        assert file_schema["type"] == "object"
        assert file_schema["title"] == "File"
        assert "name" in file_schema["properties"]

        # Metadata fields should be removed from resolved schema
        assert "$id" not in file_schema
        assert "$schema" not in file_schema
        assert "version" not in file_schema

    def test_array_items_ref_resolution(self):
        """Test resolving $refs in array items"""
        array_schema = {
            "type": "array",
            "items": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"},
            "description": "Array of general structures",
        }

        resolved = resolve_dify_schema_refs(array_schema)

        # Array structure should be preserved
        assert resolved["type"] == "array"
        assert resolved["description"] == "Array of general structures"

        # Items $ref should be resolved
        items_schema = resolved["items"]
        assert items_schema["type"] == "array"
        assert items_schema["title"] == "General Structure"

    def test_non_dify_ref_unchanged(self):
        """Test that non-Dify $refs are left unchanged"""
        external_ref_schema = {
            "type": "object",
            "properties": {
                "external_data": {"$ref": "https://example.com/external-schema.json"},
                "dify_data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
            },
        }

        resolved = resolve_dify_schema_refs(external_ref_schema)

        # External $ref should remain unchanged
        assert resolved["properties"]["external_data"]["$ref"] == "https://example.com/external-schema.json"

        # Dify $ref should be resolved
        assert resolved["properties"]["dify_data"]["type"] == "object"
        assert resolved["properties"]["dify_data"]["title"] == "File"

    def test_no_refs_schema_unchanged(self):
        """Test that schemas without $refs are returned unchanged"""
        simple_schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name field"},
                "items": {"type": "array", "items": {"type": "number"}},
            },
            "required": ["name"],
        }

        resolved = resolve_dify_schema_refs(simple_schema)

        # Should be identical to input
        assert resolved == simple_schema
        assert resolved["type"] == "object"
        assert resolved["properties"]["name"]["type"] == "string"
        assert resolved["properties"]["items"]["items"]["type"] == "number"
        assert resolved["required"] == ["name"]

    def test_recursion_depth_protection(self):
        """Test that excessive recursion depth is prevented"""
        # Create a moderately nested structure
        deep_schema = {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}

        # Wrap it in fewer layers to make the test more reasonable
        for _ in range(2):
            deep_schema = {"type": "object", "properties": {"nested": deep_schema}}

        # Should handle normal cases fine with reasonable depth
        resolved = resolve_dify_schema_refs(deep_schema, max_depth=25)
        assert resolved is not None
        assert resolved["type"] == "object"

        # Should raise error with very low max_depth
        with pytest.raises(MaxDepthExceededError) as exc_info:
            resolve_dify_schema_refs(deep_schema, max_depth=5)
        assert exc_info.value.max_depth == 5

    def test_circular_reference_detection(self):
        """Test that circular references are detected and handled"""
        # Mock registry with circular reference: every lookup returns a schema
        # that refers back to the same URI.
        mock_registry = MagicMock()
        mock_registry.get_schema.side_effect = lambda uri: {
            "$ref": "https://dify.ai/schemas/v1/circular.json",
            "type": "object",
        }

        schema = {"$ref": "https://dify.ai/schemas/v1/circular.json"}
        resolved = resolve_dify_schema_refs(schema, registry=mock_registry)

        # Should mark circular reference
        assert "$circular_ref" in resolved

    def test_schema_not_found_handling(self):
        """Test handling of missing schemas"""
        # Mock registry that returns None for unknown schemas
        mock_registry = MagicMock()
        mock_registry.get_schema.return_value = None

        schema = {"$ref": "https://dify.ai/schemas/v1/unknown.json"}
        resolved = resolve_dify_schema_refs(schema, registry=mock_registry)

        # Should keep the original $ref when schema not found
        assert resolved["$ref"] == "https://dify.ai/schemas/v1/unknown.json"

    def test_primitive_types_unchanged(self):
        """Test that primitive types are returned unchanged"""
        assert resolve_dify_schema_refs("string") == "string"
        assert resolve_dify_schema_refs(123) == 123
        assert resolve_dify_schema_refs(True) is True
        assert resolve_dify_schema_refs(None) is None
        assert resolve_dify_schema_refs(3.14) == 3.14

    def test_cache_functionality(self):
        """Test that caching works correctly"""
        schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}

        # First resolution should fetch from registry
        resolved1 = resolve_dify_schema_refs(schema)

        # Mock the registry to return different data
        with patch.object(self.registry, "get_schema") as mock_get:
            mock_get.return_value = {"type": "different"}

            # Second resolution should use cache
            resolved2 = resolve_dify_schema_refs(schema)

            # Should be the same as first resolution (from cache)
            assert resolved1 == resolved2
            # Mock should not have been called
            mock_get.assert_not_called()

        # Clear cache and try again
        SchemaResolver.clear_cache()

        # Now it should fetch again
        resolved3 = resolve_dify_schema_refs(schema)
        assert resolved3 == resolved1

    def test_thread_safety(self):
        """Test that the resolver is thread-safe"""
        schema = {
            "type": "object",
            "properties": {f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"} for i in range(10)},
        }

        results = []

        def resolve_in_thread():
            try:
                result = resolve_dify_schema_refs(schema)
                results.append(result)
                return True
            except Exception as e:
                results.append(e)
                return False

        # Run multiple threads concurrently
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(resolve_in_thread) for _ in range(20)]
            success = all(f.result() for f in futures)

        assert success
        # All results should be the same
        first_result = results[0]
        assert all(r == first_result for r in results if not isinstance(r, Exception))

    def test_mixed_nested_structures(self):
        """Test resolving refs in complex mixed structures"""
        complex_schema = {
            "type": "object",
            "properties": {
                "files": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}},
                "nested": {
                    "type": "object",
                    "properties": {
                        "qa": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"},
                        "data": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "general": {"$ref": "https://dify.ai/schemas/v1/general_structure.json"}
                                },
                            },
                        },
                    },
                },
            },
        }

        resolved = resolve_dify_schema_refs(complex_schema, max_depth=20)

        # Check structure is preserved
        assert resolved["type"] == "object"
        assert "files" in resolved["properties"]
        assert "nested" in resolved["properties"]

        # Check refs are resolved
        assert resolved["properties"]["files"]["items"]["type"] == "object"
        assert resolved["properties"]["files"]["items"]["title"] == "File"
        assert resolved["properties"]["nested"]["properties"]["qa"]["type"] == "object"
        assert resolved["properties"]["nested"]["properties"]["qa"]["title"] == "Q&A Structure"
|
||||
|
||||
|
||||
class TestUtilityFunctions:
|
||||
"""Test utility functions"""
|
||||
|
||||
    def test_is_dify_schema_ref(self):
        """Test _is_dify_schema_ref function.

        Only https://dify.ai/schemas/<version>/<name>.json URIs count as
        Dify schema refs; non-string values must be rejected, not raise.
        """
        # Valid Dify refs
        assert _is_dify_schema_ref("https://dify.ai/schemas/v1/file.json")
        assert _is_dify_schema_ref("https://dify.ai/schemas/v2/complex_name.json")
        assert _is_dify_schema_ref("https://dify.ai/schemas/v999/test-file.json")

        # Invalid refs (wrong host/path, malformed, or non-string input)
        assert not _is_dify_schema_ref("https://example.com/schema.json")
        assert not _is_dify_schema_ref("https://dify.ai/other/path.json")
        assert not _is_dify_schema_ref("not a uri")
        assert not _is_dify_schema_ref("")
        assert not _is_dify_schema_ref(None)
        assert not _is_dify_schema_ref(123)
        assert not _is_dify_schema_ref(["list"])
|
||||
|
||||
    def test_has_dify_refs(self):
        """Test _has_dify_refs function.

        Detection must recurse through dicts and lists, return True only for
        Dify-hosted $refs, and tolerate primitives without raising.
        """
        # Schemas with Dify refs (at top level, nested, in lists, deep)
        assert _has_dify_refs({"$ref": "https://dify.ai/schemas/v1/file.json"})
        assert _has_dify_refs(
            {"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}}
        )
        assert _has_dify_refs([{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/file.json"}])
        assert _has_dify_refs(
            {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {"nested": {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}},
                },
            }
        )

        # Schemas without Dify refs
        assert not _has_dify_refs({"type": "string"})
        assert not _has_dify_refs(
            {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}}
        )
        assert not _has_dify_refs(
            [{"type": "string"}, {"type": "number"}, {"type": "object", "properties": {"name": {"type": "string"}}}]
        )

        # Schemas with non-Dify refs (should return False)
        assert not _has_dify_refs({"$ref": "https://example.com/schema.json"})
        assert not _has_dify_refs(
            {"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}}
        )

        # Primitive types
        assert not _has_dify_refs("string")
        assert not _has_dify_refs(123)
        assert not _has_dify_refs(True)
        assert not _has_dify_refs(None)
|
||||
|
||||
def test_has_dify_refs_hybrid_vs_recursive(self):
    """The hybrid and recursive detectors must agree on every fixture."""
    fixtures = [
        # No refs at all
        {"type": "string"},
        {"type": "object", "properties": {"name": {"type": "string"}}},
        [{"type": "string"}, {"type": "number"}],
        # Dify refs present
        {"$ref": "https://dify.ai/schemas/v1/file.json"},
        {"type": "object", "properties": {"data": {"$ref": "https://dify.ai/schemas/v1/file.json"}}},
        [{"type": "string"}, {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"}],
        # Refs pointing elsewhere
        {"$ref": "https://example.com/schema.json"},
        {"type": "object", "properties": {"external": {"$ref": "https://example.com/external.json"}}},
        # Deeply nested Dify ref
        {
            "type": "object",
            "properties": {
                "level1": {
                    "type": "object",
                    "properties": {
                        "level2": {"type": "array", "items": {"$ref": "https://dify.ai/schemas/v1/file.json"}}
                    },
                }
            },
        },
        # Edge cases
        {"description": "This mentions $ref but is not a reference"},
        {"$ref": "not-a-url"},
        # Primitive types
        "string",
        123,
        True,
        None,
        [],
    ]

    for candidate in fixtures:
        fast = _has_dify_refs_hybrid(candidate)
        slow = _has_dify_refs_recursive(candidate)
        assert fast == slow, f"Mismatch for schema: {candidate}"
|
||||
|
||||
def test_parse_dify_schema_uri(self):
    """parse_dify_schema_uri splits valid Dify URIs into (version, name); otherwise ("", "")."""
    expectations = {
        # Valid Dify schema URIs
        "https://dify.ai/schemas/v1/file.json": ("v1", "file"),
        "https://dify.ai/schemas/v2/complex_name.json": ("v2", "complex_name"),
        "https://dify.ai/schemas/v999/test-file.json": ("v999", "test-file"),
        # Invalid URIs fall back to empty parts
        "https://example.com/schema.json": ("", ""),
        "invalid": ("", ""),
        "": ("", ""),
    }
    for uri, expected in expectations.items():
        assert parse_dify_schema_uri(uri) == expected
|
||||
|
||||
def test_remove_metadata_fields(self):
    """_remove_metadata_fields strips $id/$schema/version into a copy, leaving other keys."""
    source = {
        "$id": "should be removed",
        "$schema": "should be removed",
        "version": "should be removed",
        "type": "object",
        "title": "should remain",
        "properties": {},
    }

    stripped = _remove_metadata_fields(source)

    for metadata_key in ("$id", "$schema", "version"):
        assert metadata_key not in stripped
    assert stripped["type"] == "object"
    assert stripped["title"] == "should remain"
    assert "properties" in stripped

    # The input dict must not be mutated in place.
    assert "$id" in source
|
||||
|
||||
|
||||
class TestSchemaResolverClass:
    """Behavioral and performance tests for the SchemaResolver class itself."""

    def test_resolver_initialization(self):
        """Constructor defaults and explicit registry/max_depth arguments."""
        default_resolver = SchemaResolver()
        assert default_resolver.max_depth == 10
        assert default_resolver.registry is not None

        fake_registry = MagicMock()
        tuned_resolver = SchemaResolver(registry=fake_registry, max_depth=5)
        assert tuned_resolver.max_depth == 5
        assert tuned_resolver.registry is fake_registry

    def test_cache_sharing(self):
        """Resolved refs are cached at class level and reused across instances."""
        SchemaResolver.clear_cache()
        ref_schema = {"$ref": "https://dify.ai/schemas/v1/file.json"}

        # A first resolver warms the shared cache.
        first = SchemaResolver().resolve(ref_schema)

        # A second resolver must serve from cache, never touching its registry.
        second_resolver = SchemaResolver()
        with patch.object(second_resolver.registry, "get_schema") as get_schema_mock:
            second = second_resolver.resolve(ref_schema)
            get_schema_mock.assert_not_called()

        assert first == second

    def test_resolver_with_list_schema(self):
        """A list may be the root schema; each element is resolved independently."""
        root = [
            {"$ref": "https://dify.ai/schemas/v1/file.json"},
            {"type": "string"},
            {"$ref": "https://dify.ai/schemas/v1/qa_structure.json"},
        ]

        outcome = SchemaResolver().resolve(root)

        assert isinstance(outcome, list)
        assert len(outcome) == 3
        assert outcome[0]["type"] == "object"
        assert outcome[0]["title"] == "File"
        assert outcome[1] == {"type": "string"}
        assert outcome[2]["type"] == "object"
        assert outcome[2]["title"] == "Q&A Structure"

    def test_cache_performance(self):
        """A warm cache must not make repeated resolution dramatically slower."""
        SchemaResolver.clear_cache()

        # Many properties all pointing at the same schema.
        schema = {
            "type": "object",
            "properties": {
                f"prop_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
                for i in range(50)  # kept small to stay within depth limits
            },
        }

        # Cold runs: cache cleared before each timing.
        cold_samples = []
        for _ in range(3):
            SchemaResolver.clear_cache()
            began = time.perf_counter()
            cold_result = resolve_dify_schema_refs(schema)
            cold_samples.append(time.perf_counter() - began)
        cold_avg = sum(cold_samples) / len(cold_samples)

        # Warm runs: cache retained between timings.
        warm_samples = []
        for _ in range(3):
            began = time.perf_counter()
            warm_result = resolve_dify_schema_refs(schema)
            warm_samples.append(time.perf_counter() - began)
        warm_avg = sum(warm_samples) / len(warm_samples)

        assert cold_result == warm_result
        # Lenient bound: only guard against a gross regression, since wall-clock
        # timings vary between runs.
        ratio = warm_avg / cold_avg if cold_avg > 0 else 1.0
        assert ratio <= 2.0, f"Cache performance degraded too much: {ratio}"

    def test_fast_path_performance_no_refs(self):
        """Schemas without $refs take the fast path: input returned as-is, no copy/BFS."""
        # Moderately complex schema with no refs (typical plugin output_schema).
        plain_schema = {
            "type": "object",
            "properties": {
                f"property_{i}": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "value": {"type": "number"},
                        "items": {"type": "array", "items": {"type": "string"}},
                    },
                }
                for i in range(50)
            },
        }

        fast_samples = []
        for _ in range(10):
            began = time.perf_counter()
            fast_result = resolve_dify_schema_refs(plain_schema)
            fast_samples.append(time.perf_counter() - began)
        avg_fast = sum(fast_samples) / len(fast_samples)

        # Identity, not mere equality: the fast path hands back the input object.
        assert fast_result is plain_schema

        # Comparable schema that does contain refs, to exercise the slow path.
        ref_schema = {
            "type": "object",
            "properties": {
                f"property_{i}": {"$ref": "https://dify.ai/schemas/v1/file.json"}
                for i in range(20)  # fewer entries to avoid depth issues
            },
        }

        SchemaResolver.clear_cache()
        slow_samples = []
        for _ in range(10):
            SchemaResolver.clear_cache()
            began = time.perf_counter()
            resolve_dify_schema_refs(ref_schema, max_depth=50)
            slow_samples.append(time.perf_counter() - began)
        avg_slow = sum(slow_samples) / len(slow_samples)

        print(f"Fast path (no refs): {avg_fast:.6f}s")
        print(f"Slow path (with refs): {avg_slow:.6f}s")

        # Lenient: the real win is skipping the deep copy and BFS, not a fixed
        # speedup factor, so just require the fast path not be much slower.
        assert avg_fast < avg_slow * 2

    def test_batch_processing_performance(self):
        """Many ref-free schemas (plugin-tool scenario) should all take the fast path quickly."""
        batch = [
            {
                "type": "object",
                "properties": {f"field_{j}": {"type": "string" if j % 2 else "number"} for j in range(10)},
            }
            for _ in range(100)
        ]

        began = time.perf_counter()
        resolved = [resolve_dify_schema_refs(entry) for entry in batch]
        elapsed = time.perf_counter() - began

        # Each result must be the identical input object (fast path used).
        for source, outcome in zip(batch, resolved):
            assert outcome is source

        # Average per-schema cost should be well under a millisecond.
        assert elapsed / len(batch) < 0.001

    def test_has_dify_refs_performance(self):
        """Ref detection stays fast even on a very deeply nested ref-free schema."""
        deep_schema = {"type": "object", "properties": {}}
        cursor = deep_schema
        for depth in range(100):
            cursor["properties"][f"level_{depth}"] = {"type": "object", "properties": {}}
            cursor = cursor["properties"][f"level_{depth}"]

        samples = []
        for _ in range(50):
            began = time.perf_counter()
            detected = _has_dify_refs(deep_schema)
            samples.append(time.perf_counter() - began)

        assert not detected
        assert sum(samples) / len(samples) < 0.01  # under 10ms on average

    def test_hybrid_vs_recursive_performance(self):
        """Hybrid and recursive detection: same answers, roughly comparable speed."""
        cases = [
            # Small schema without refs (the common case)
            {
                "name": "small_no_refs",
                "schema": {"type": "object", "properties": {"name": {"type": "string"}, "value": {"type": "number"}}},
                "expected": False,
            },
            # Medium schema without refs
            {
                "name": "medium_no_refs",
                "schema": {
                    "type": "object",
                    "properties": {
                        f"field_{i}": {
                            "type": "object",
                            "properties": {
                                "name": {"type": "string"},
                                "value": {"type": "number"},
                                "items": {"type": "array", "items": {"type": "string"}},
                            },
                        }
                        for i in range(20)
                    },
                },
                "expected": False,
            },
            # Large schema without refs (nesting added below)
            {"name": "large_no_refs", "schema": {"type": "object", "properties": {}}, "expected": False},
            # Schema with a Dify ref
            {
                "name": "with_dify_refs",
                "schema": {
                    "type": "object",
                    "properties": {
                        "file": {"$ref": "https://dify.ai/schemas/v1/file.json"},
                        "data": {"type": "string"},
                    },
                },
                "expected": True,
            },
            # Schema with a non-Dify ref
            {
                "name": "with_external_refs",
                "schema": {
                    "type": "object",
                    "properties": {"external": {"$ref": "https://example.com/schema.json"}, "data": {"type": "string"}},
                },
                "expected": False,
            },
        ]

        # Grow the "large" case into a 50-level nest.
        cursor = cases[2]["schema"]
        for depth in range(50):
            cursor["properties"][f"level_{depth}"] = {"type": "object", "properties": {}}
            cursor = cursor["properties"][f"level_{depth}"]

        for case in cases:
            schema = case["schema"]
            expected = case["expected"]
            label = case["name"]

            # Correctness before timing.
            assert _has_dify_refs_hybrid(schema) == expected
            assert _has_dify_refs_recursive(schema) == expected

            hybrid_samples = []
            for _ in range(10):
                began = time.perf_counter()
                hybrid_outcome = _has_dify_refs_hybrid(schema)
                hybrid_samples.append(time.perf_counter() - began)

            recursive_samples = []
            for _ in range(10):
                began = time.perf_counter()
                recursive_outcome = _has_dify_refs_recursive(schema)
                recursive_samples.append(time.perf_counter() - began)

            avg_hybrid = sum(hybrid_samples) / len(hybrid_samples)
            avg_recursive = sum(recursive_samples) / len(recursive_samples)

            print(f"{label}: hybrid={avg_hybrid:.6f}s, recursive={avg_recursive:.6f}s")

            assert hybrid_outcome == recursive_outcome == expected

            if not expected:  # ref-free case
                # JSON-serialization overhead in the hybrid detector is tolerated,
                # but must stay within a generous bound.
                assert avg_hybrid < avg_recursive * 5

    def test_string_matching_edge_cases(self):
        """String-prefilter edge cases: $ref mentions in text, mixed URLs, unserializable data."""
        # A '$ref' mention in a description is not a reference.
        mention_only = {
            "type": "object",
            "properties": {
                "description": {"type": "string", "description": "This field explains how $ref works in JSON Schema"}
            },
        }
        assert not _has_dify_refs_hybrid(mention_only)
        assert not _has_dify_refs_recursive(mention_only)

        # A dify.ai URL in a default value must not mask the genuine $ref sibling.
        mixed_urls = {
            "type": "object",
            "properties": {
                "config": {
                    "type": "object",
                    "properties": {
                        "dify_url": {"type": "string", "default": "https://dify.ai/schemas/info"},
                        "actual_ref": {"$ref": "https://dify.ai/schemas/v1/file.json"},
                    },
                }
            },
        }
        assert _has_dify_refs_hybrid(mixed_urls)
        assert _has_dify_refs_recursive(mixed_urls)

        # Non-JSON-serializable payloads force the hybrid detector onto the
        # recursive fallback, which must still find the ref.
        import datetime

        unserializable = {
            "type": "object",
            "timestamp": datetime.datetime.now(),
            "data": {"$ref": "https://dify.ai/schemas/v1/file.json"},
        }
        assert _has_dify_refs_hybrid(unserializable)
        assert _has_dify_refs_recursive(unserializable)
|
||||
56
dify/api/tests/unit_tests/core/test_file.py
Normal file
56
dify/api/tests/unit_tests/core/test_file.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import json
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType, FileUploadConfig
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
def test_file_to_dict():
    """Serializing a File exposes a url but hides the private storage key."""
    image = File(
        id="file1",
        tenant_id="tenant1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/image1.jpg",
        storage_key="storage_key",
    )

    serialized = image.to_dict()

    assert "_storage_key" not in serialized
    assert "url" in serialized
|
||||
|
||||
|
||||
def test_workflow_features_with_image():
    """Legacy image-only file_upload configs are upgraded by the Workflow.features property."""
    # Old-style feature dict with an image-specific upload section.
    legacy_features = {
        "file_upload": {
            "image": {"enabled": True, "number_limits": 5, "transfer_methods": ["remote_url", "local_file"]}
        }
    }

    workflow = Workflow(
        tenant_id="tenant-1",
        app_id="app-1",
        type="chat",
        version="1.0",
        graph="{}",
        features=json.dumps(legacy_features),
        created_by="user-1",
        environment_variables=[],
        conversation_variables=[],
    )

    # The features property performs the conversion on read.
    upgraded = json.loads(workflow.features)
    upload_config = FileUploadConfig.model_validate(upgraded["file_upload"])

    assert upload_config.number_limits == 5
    assert list(upload_config.allowed_file_types) == [FileType.IMAGE]
    assert list(upload_config.allowed_file_upload_methods) == [
        FileTransferMethod.REMOTE_URL,
        FileTransferMethod.LOCAL_FILE,
    ]
    assert list(upload_config.allowed_file_extensions) == []
|
||||
73
dify/api/tests/unit_tests/core/test_model_manager.py
Normal file
73
dify/api/tests/unit_tests/core/test_model_manager.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||
from core.model_manager import LBModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
@pytest.fixture
def lb_model_manager():
    """LBModelManager over three configs; "id1" is mocked as being in cooldown."""
    configs = [
        ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}),
        ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}),
        ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}),
    ]

    manager = LBModelManager(
        tenant_id="tenant_id",
        provider="openai",
        model_type=ModelType.LLM,
        model="gpt-4",
        load_balancing_configs=configs,
        managed_credentials={"openai_api_key": "fake_key"},
    )

    # Neutralize cooldown writes; report only the inherited config as cooling down.
    manager.cooldown = MagicMock(return_value=None)
    manager.in_cooldown = MagicMock(side_effect=lambda config: config.id == "id1")

    return manager
|
||||
|
||||
|
||||
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
    """fetch_next skips cooled-down configs and round-robins across the rest."""
    # A real redis client object backs the patched incr/set/expire below.
    redis_client.initialize(redis.Redis())

    configs = lb_model_manager._load_balancing_configs
    assert len(configs) == 3
    first_cfg, second_cfg, third_cfg = configs

    # Per the fixture, only the inherited config is cooling down.
    assert lb_model_manager.in_cooldown(first_cfg) is True
    assert lb_model_manager.in_cooldown(second_cfg) is False
    assert lb_model_manager.in_cooldown(third_cfg) is False

    counter = 0

    def fake_incr(key):
        # Simulate a monotonically increasing round-robin index.
        nonlocal counter
        counter += 1
        return counter

    with (
        patch.object(redis_client, "incr", side_effect=fake_incr),
        patch.object(redis_client, "set", return_value=None),
        patch.object(redis_client, "expire", return_value=None),
    ):
        # id1 is skipped (cooldown); the remaining two are returned in order.
        assert lb_model_manager.fetch_next() == second_cfg
        assert lb_model_manager.fetch_next() == third_cfg
|
||||
485
dify/api/tests/unit_tests/core/test_provider_configuration.py
Normal file
485
dify/api/tests/unit_tests/core/test_provider_configuration.py
Normal file
@@ -0,0 +1,485 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
ModelSettings,
|
||||
ProviderQuotaType,
|
||||
QuotaConfiguration,
|
||||
QuotaUnit,
|
||||
RestrictModel,
|
||||
SystemConfiguration,
|
||||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormOption,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
@pytest.fixture
def mock_provider_entity():
    """Minimal OpenAI ProviderEntity supporting predefined LLM models only."""
    return ProviderEntity(
        provider="openai",
        label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
        description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"),
        icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
        icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"),
        background="background.png",
        help=None,
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
        provider_credential_schema=None,
        model_credential_schema=None,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_system_configuration():
    """Enabled system configuration holding one TRIAL quota that restricts gpt-4."""
    trial_quota = QuotaConfiguration(
        quota_type=ProviderQuotaType.TRIAL,
        quota_unit=QuotaUnit.TOKENS,
        quota_limit=1000,
        quota_used=0,
        is_valid=True,
        restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)],
    )

    return SystemConfiguration(
        enabled=True,
        credentials={"openai_api_key": "test_key"},
        quota_configurations=[trial_quota],
        current_quota_type=ProviderQuotaType.TRIAL,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_custom_configuration():
    """Empty custom configuration: no provider credentials, no per-model configs."""
    return CustomConfiguration(provider=None, models=[])
|
||||
|
||||
|
||||
@pytest.fixture
def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration):
    """ProviderConfiguration wired to the mock fixtures, preferring the SYSTEM provider."""
    # Isolate the module-level cache of original configurate methods per test.
    with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
        return ProviderConfiguration(
            tenant_id="test_tenant",
            provider=mock_provider_entity,
            preferred_provider_type=ProviderType.SYSTEM,
            using_provider_type=ProviderType.SYSTEM,
            system_configuration=mock_system_configuration,
            custom_configuration=mock_custom_configuration,
            model_settings=[],
        )
|
||||
|
||||
|
||||
class TestProviderConfiguration:
|
||||
"""Test cases for ProviderConfiguration class"""
|
||||
|
||||
def test_get_current_credentials_system_provider_success(self, provider_configuration):
    """The system provider yields the system-level credentials."""
    provider_configuration.using_provider_type = ProviderType.SYSTEM

    resolved = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")

    assert resolved == {"openai_api_key": "test_key"}
|
||||
|
||||
def test_get_current_credentials_model_disabled(self, provider_configuration):
    """A disabled model setting makes credential lookup raise ValueError."""
    provider_configuration.model_settings = [
        ModelSettings(
            model="gpt-4",
            model_type=ModelType.LLM,
            enabled=False,
            load_balancing_configs=[],
            has_invalid_load_balancing_configs=False,
        )
    ]

    with pytest.raises(ValueError, match="Model gpt-4 is disabled"):
        provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
def test_get_current_credentials_custom_provider_with_models(self, provider_configuration):
    """The custom provider pulls credentials from the matching model configuration."""
    provider_configuration.using_provider_type = ProviderType.CUSTOM

    model_config = Mock()
    model_config.model_type = ModelType.LLM
    model_config.model = "gpt-4"
    model_config.credentials = {"openai_api_key": "custom_key"}
    provider_configuration.custom_configuration.models = [model_config]

    resolved = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")

    assert resolved == {"openai_api_key": "custom_key"}
|
||||
|
||||
def test_get_system_configuration_status_active(self, provider_configuration):
    """An enabled system configuration reports ACTIVE."""
    provider_configuration.system_configuration.enabled = True

    assert provider_configuration.get_system_configuration_status() == SystemConfigurationStatus.ACTIVE
|
||||
|
||||
def test_get_system_configuration_status_unsupported(self, provider_configuration):
    """A disabled system configuration reports UNSUPPORTED."""
    provider_configuration.system_configuration.enabled = False

    assert provider_configuration.get_system_configuration_status() == SystemConfigurationStatus.UNSUPPORTED
|
||||
|
||||
def test_get_system_configuration_status_quota_exceeded(self, provider_configuration):
    """An enabled configuration with an invalid quota reports QUOTA_EXCEEDED."""
    provider_configuration.system_configuration.enabled = True
    provider_configuration.system_configuration.quota_configurations[0].is_valid = False

    assert provider_configuration.get_system_configuration_status() == SystemConfigurationStatus.QUOTA_EXCEEDED
|
||||
|
||||
def test_is_custom_configuration_available_with_provider(self, provider_configuration):
    """Provider-level credentials alone make the custom configuration available."""
    provider_stub = Mock()
    provider_stub.available_credentials = ["openai_api_key"]
    provider_configuration.custom_configuration.provider = provider_stub
    provider_configuration.custom_configuration.models = []

    assert provider_configuration.is_custom_configuration_available() is True
|
||||
|
||||
def test_is_custom_configuration_available_with_models(self, provider_configuration):
    """Model-level configurations alone make the custom configuration available."""
    provider_configuration.custom_configuration.provider = None
    provider_configuration.custom_configuration.models = [Mock()]

    assert provider_configuration.is_custom_configuration_available() is True
|
||||
|
||||
def test_is_custom_configuration_available_false(self, provider_configuration):
    """With neither provider credentials nor model configs, availability is False."""
    provider_configuration.custom_configuration.provider = None
    provider_configuration.custom_configuration.models = []

    assert provider_configuration.is_custom_configuration_available() is False
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
def test_get_provider_record_found(self, mock_session, provider_configuration):
    """_get_provider_record returns the row the session query produces."""
    provider_row = Mock(spec=Provider)
    session_stub = Mock()
    mock_session.return_value.__enter__.return_value = session_stub
    session_stub.execute.return_value.scalar_one_or_none.return_value = provider_row

    assert provider_configuration._get_provider_record(session_stub) == provider_row
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
def test_get_provider_record_not_found(self, mock_session, provider_configuration):
    """_get_provider_record returns None when the query matches nothing."""
    session_stub = Mock()
    mock_session.return_value.__enter__.return_value = session_stub
    session_stub.execute.return_value.scalar_one_or_none.return_value = None

    assert provider_configuration._get_provider_record(session_stub) is None
|
||||
|
||||
def test_init_with_customizable_model_only(
    self, mock_provider_entity, mock_system_configuration, mock_custom_configuration
):
    """Providers declaring only CUSTOMIZABLE_MODEL gain PREDEFINED_MODEL during init."""
    mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL]

    with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
        configuration = ProviderConfiguration(
            tenant_id="test_tenant",
            provider=mock_provider_entity,
            preferred_provider_type=ProviderType.SYSTEM,
            using_provider_type=ProviderType.SYSTEM,
            system_configuration=mock_system_configuration,
            custom_configuration=mock_custom_configuration,
            model_settings=[],
        )

    assert ConfigurateMethod.PREDEFINED_MODEL in configuration.provider.configurate_methods
|
||||
|
||||
def test_get_current_credentials_with_restricted_models(self, provider_configuration):
    """A model outside the quota's restriction list still resolves system credentials."""
    provider_configuration.using_provider_type = ProviderType.SYSTEM

    resolved = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo")

    assert resolved is not None
    assert "openai_api_key" in resolved
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
def test_get_specific_provider_credential_success(self, mock_session, provider_configuration):
    """_get_specific_provider_credential returns decrypted credentials for a known id."""
    # NOTE(review): this test patches the very method under test, so the final
    # assertion only checks the mock's canned value, not the real query/decrypt
    # path set up on the session below — consider exercising the real method.
    credential_id = "test_credential_id"
    credential_row = Mock()
    credential_row.encrypted_config = '{"openai_api_key": "encrypted_key"}'

    session_stub = Mock()
    mock_session.return_value.__enter__.return_value = session_stub
    session_stub.execute.return_value.scalar_one_or_none.return_value = credential_row

    with patch.object(provider_configuration, "_get_specific_provider_credential") as get_credential_mock:
        get_credential_mock.return_value = {"openai_api_key": "test_key"}
        outcome = provider_configuration._get_specific_provider_credential(credential_id)

    assert outcome == {"openai_api_key": "test_key"}
|
||||
|
||||
@patch("core.entities.provider_configuration.Session")
|
||||
def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration):
|
||||
"""Test getting specific provider credential when not found"""
|
||||
# Arrange
|
||||
credential_id = "nonexistent_credential_id"
|
||||
|
||||
mock_session_instance = Mock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get:
|
||||
mock_get.return_value = None
|
||||
result = provider_configuration._get_specific_provider_credential(credential_id)
|
||||
assert result is None
|
||||
|
||||
# Act
|
||||
credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4")
|
||||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "test_key"}
|
||||
|
||||
def test_extract_secret_variables_with_secret_input(self, provider_configuration):
    """Only SECRET_INPUT fields are reported as secret variables."""
    # Arrange: two secret fields and one plain text field.
    schemas = [
        CredentialFormSchema(
            variable="api_key",
            label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
            type=FormType.SECRET_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="model_name",
            label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="secret_token",
            label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
            type=FormType.SECRET_INPUT,
            required=False,
        ),
    ]

    # Act
    secret_variables = provider_configuration.extract_secret_variables(schemas)

    # Assert
    assert len(secret_variables) == 2
    assert "api_key" in secret_variables
    assert "secret_token" in secret_variables
    assert "model_name" not in secret_variables


def test_extract_secret_variables_no_secret_input(self, provider_configuration):
    """No secrets are reported when no SECRET_INPUT fields exist."""
    # Arrange
    schemas = [
        CredentialFormSchema(
            variable="model_name",
            label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="temperature",
            label=I18nObject(en_US="Temperature", zh_Hans="温度"),
            type=FormType.SELECT,
            required=True,
            options=[FormOption(label=I18nObject(en_US="0.1", zh_Hans="0.1"), value="0.1")],
        ),
    ]

    # Act
    secret_variables = provider_configuration.extract_secret_variables(schemas)

    # Assert
    assert len(secret_variables) == 0


def test_extract_secret_variables_empty_list(self, provider_configuration):
    """An empty schema list yields no secret variables."""
    # Act
    secret_variables = provider_configuration.extract_secret_variables([])

    # Assert
    assert len(secret_variables) == 0
|
||||
|
||||
@patch("core.entities.provider_configuration.encrypter")
def test_obfuscated_credentials_with_secret_variables(self, mock_encrypter, provider_configuration):
    """Only values of SECRET_INPUT fields are run through the obfuscator."""
    # Arrange
    credentials = {
        "api_key": "sk-1234567890abcdef",
        "model_name": "gpt-4",
        "secret_token": "secret_value_123",
        "temperature": "0.7",
    }
    schemas = [
        CredentialFormSchema(
            variable="api_key",
            label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
            type=FormType.SECRET_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="model_name",
            label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="secret_token",
            label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
            type=FormType.SECRET_INPUT,
            required=False,
        ),
        CredentialFormSchema(
            variable="temperature",
            label=I18nObject(en_US="Temperature", zh_Hans="温度"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
    ]
    # Fake obfuscation: keep only the last 4 characters.
    mock_encrypter.obfuscated_token.side_effect = lambda x: f"***{x[-4:]}"

    # Act
    obfuscated = provider_configuration.obfuscated_credentials(credentials, schemas)

    # Assert
    assert obfuscated["api_key"] == "***cdef"
    assert obfuscated["model_name"] == "gpt-4"  # Not obfuscated
    assert obfuscated["secret_token"] == "***_123"
    assert obfuscated["temperature"] == "0.7"  # Not obfuscated

    # Obfuscator invoked for the two secret fields only.
    assert mock_encrypter.obfuscated_token.call_count == 2
    mock_encrypter.obfuscated_token.assert_any_call("sk-1234567890abcdef")
    mock_encrypter.obfuscated_token.assert_any_call("secret_value_123")


def test_obfuscated_credentials_no_secret_variables(self, provider_configuration):
    """Credentials pass through untouched when no field is secret."""
    # Arrange
    credentials = {
        "model_name": "gpt-4",
        "temperature": "0.7",
        "max_tokens": "1000",
    }
    schemas = [
        CredentialFormSchema(
            variable="model_name",
            label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="temperature",
            label=I18nObject(en_US="Temperature", zh_Hans="温度"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
        CredentialFormSchema(
            variable="max_tokens",
            label=I18nObject(en_US="Max Tokens", zh_Hans="最大令牌数"),
            type=FormType.TEXT_INPUT,
            required=True,
        ),
    ]

    # Act
    obfuscated = provider_configuration.obfuscated_credentials(credentials, schemas)

    # Assert
    assert obfuscated == credentials  # No changes expected


def test_obfuscated_credentials_empty_credentials(self, provider_configuration):
    """Empty credentials and schemas obfuscate to an empty dict."""
    # Act
    obfuscated = provider_configuration.obfuscated_credentials({}, [])

    # Assert
    assert obfuscated == {}
|
||||
192
dify/api/tests/unit_tests/core/test_provider_manager.py
Normal file
192
dify/api/tests/unit_tests/core/test_provider_manager.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelSettings
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
|
||||
|
||||
@pytest.fixture
def mock_provider_entity(mocker: MockerFixture):
    """Provider entity mock with empty, iterable credential form schemas."""
    entity = mocker.Mock()
    entity.provider = "openai"
    entity.configurate_methods = ["predefined-model"]
    entity.supported_model_types = [ModelType.LLM]

    # PropertyMock keeps credential_form_schemas iterable (a bare Mock attribute is not).
    provider_credential_schema = mocker.Mock()
    type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
    entity.provider_credential_schema = provider_credential_schema

    model_credential_schema = mocker.Mock()
    type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
    entity.model_credential_schema = model_credential_schema

    return entity
|
||||
|
||||
|
||||
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
    """An enabled setting with two LB entries maps to one ModelSettings carrying both."""
    setting = ProviderModelSetting(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        enabled=True,
        load_balancing_enabled=True,
    )
    setting.id = "id"

    lb_inherit = LoadBalancingModelConfig(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        name="__inherit__",
        encrypted_config=None,
        enabled=True,
    )
    lb_inherit.id = "id1"

    lb_first = LoadBalancingModelConfig(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        name="first",
        encrypted_config='{"openai_api_key": "fake_key"}',
        enabled=True,
    )
    lb_first.id = "id2"

    # Credentials resolve from the cache rather than real decryption.
    mocker.patch(
        "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
    )

    result = ProviderManager()._to_model_settings(
        provider_entity=mock_provider_entity,
        provider_model_settings=[setting],
        load_balancing_model_configs=[lb_inherit, lb_first],
    )

    assert len(result) == 1
    assert isinstance(result[0], ModelSettings)
    assert result[0].model == "gpt-4"
    assert result[0].model_type == ModelType.LLM
    assert result[0].enabled is True
    assert len(result[0].load_balancing_configs) == 2
    assert result[0].load_balancing_configs[0].name == "__inherit__"
    assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
    """A lone '__inherit__' LB entry yields no effective load-balancing configs."""
    setting = ProviderModelSetting(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        enabled=True,
        load_balancing_enabled=True,
    )
    setting.id = "id"

    lb_inherit = LoadBalancingModelConfig(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        name="__inherit__",
        encrypted_config=None,
        enabled=True,
    )
    lb_inherit.id = "id1"

    mocker.patch(
        "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
    )

    result = ProviderManager()._to_model_settings(
        provider_entity=mock_provider_entity,
        provider_model_settings=[setting],
        load_balancing_model_configs=[lb_inherit],
    )

    assert len(result) == 1
    assert isinstance(result[0], ModelSettings)
    assert result[0].model == "gpt-4"
    assert result[0].model_type == ModelType.LLM
    assert result[0].enabled is True
    assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
|
||||
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
    """load_balancing_enabled=False drops LB configs even when two entries exist."""
    setting = ProviderModelSetting(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        enabled=True,
        load_balancing_enabled=False,
    )
    setting.id = "id"

    lb_inherit = LoadBalancingModelConfig(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        name="__inherit__",
        encrypted_config=None,
        enabled=True,
    )
    lb_inherit.id = "id1"

    lb_first = LoadBalancingModelConfig(
        tenant_id="tenant_id",
        provider_name="openai",
        model_name="gpt-4",
        model_type="text-generation",
        name="first",
        encrypted_config='{"openai_api_key": "fake_key"}',
        enabled=True,
    )
    lb_first.id = "id2"

    mocker.patch(
        "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
    )

    result = ProviderManager()._to_model_settings(
        provider_entity=mock_provider_entity,
        provider_model_settings=[setting],
        load_balancing_model_configs=[lb_inherit, lb_first],
    )

    assert len(result) == 1
    assert isinstance(result[0], ModelSettings)
    assert result[0].model == "gpt-4"
    assert result[0].model_type == ModelType.LLM
    assert result[0].enabled is True
    assert len(result[0].load_balancing_configs) == 0
|
||||
@@ -0,0 +1,102 @@
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
|
||||
from core.trigger.debug import event_selectors
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
|
||||
|
||||
|
||||
class _DummyRedis:
|
||||
def __init__(self):
|
||||
self.store: dict[str, str] = {}
|
||||
|
||||
def get(self, key: str):
|
||||
return self.store.get(key)
|
||||
|
||||
def setex(self, name: str, time: int, value: str):
|
||||
self.store[name] = value
|
||||
|
||||
def expire(self, name: str, ttl: int):
|
||||
# Expiration not required for these tests.
|
||||
pass
|
||||
|
||||
def delete(self, name: str):
|
||||
self.store.pop(name, None)
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_schedule_config() -> ScheduleConfig:
    """Deterministic every-minute cron schedule in Asia/Shanghai."""
    return ScheduleConfig(
        node_id="node-1",
        cron_expression="* * * * *",
        timezone="Asia/Shanghai",
    )


@pytest.fixture(autouse=True)
def patch_schedule_service(monkeypatch: pytest.MonkeyPatch, dummy_schedule_config: ScheduleConfig):
    """Force every schedule lookup to resolve the deterministic config above."""
    monkeypatch.setattr(
        "services.trigger.schedule_service.ScheduleService.to_schedule_config",
        staticmethod(lambda *_args, **_kwargs: dummy_schedule_config),
    )


def _make_poller(
    monkeypatch: pytest.MonkeyPatch, redis_client: _DummyRedis
) -> event_selectors.ScheduleTriggerDebugEventPoller:
    """Build a debug event poller wired to the given fake redis client."""
    monkeypatch.setattr(event_selectors, "redis_client", redis_client)
    return event_selectors.ScheduleTriggerDebugEventPoller(
        tenant_id="tenant-1",
        user_id="user-1",
        app_id="app-1",
        node_config={"id": "node-1", "data": {"mode": "cron"}},
        node_id="node-1",
    )
|
||||
|
||||
|
||||
def test_schedule_poller_handles_aware_next_run(monkeypatch: pytest.MonkeyPatch):
    """An offset-aware next_run_at earlier than 'now' still produces a debug event."""
    fake_redis = _DummyRedis()
    poller = _make_poller(monkeypatch, fake_redis)

    base_now = datetime(2025, 1, 1, 12, 0, 10)
    aware_next_run = datetime(2025, 1, 1, 12, 0, 5, tzinfo=UTC)

    monkeypatch.setattr(event_selectors, "naive_utc_now", lambda: base_now)
    monkeypatch.setattr(event_selectors, "calculate_next_run_at", lambda *_: aware_next_run)

    event = poller.poll()

    assert event is not None
    assert event.node_id == "node-1"
    assert event.workflow_args["inputs"] == {}


def test_schedule_runtime_cache_normalizes_timezone(
    monkeypatch: pytest.MonkeyPatch, dummy_schedule_config: ScheduleConfig
):
    """A cached, timezone-localized next_run_at is normalized to naive UTC on load."""
    fake_redis = _DummyRedis()
    poller = _make_poller(monkeypatch, fake_redis)

    localized_time = pytz.timezone("Asia/Shanghai").localize(datetime(2025, 1, 1, 20, 0, 0))

    cron_hash = hashlib.sha256(dummy_schedule_config.cron_expression.encode()).hexdigest()
    cache_key = poller.schedule_debug_runtime_key(cron_hash)

    # Seed the cache directly with an aware timestamp.
    fake_redis.store[cache_key] = json.dumps(
        {
            "cache_key": cache_key,
            "timezone": dummy_schedule_config.timezone,
            "cron_expression": dummy_schedule_config.cron_expression,
            "next_run_at": localized_time.isoformat(),
        }
    )

    runtime = poller.get_or_create_schedule_debug_runtime()

    expected = localized_time.astimezone(UTC).replace(tzinfo=None)
    assert runtime.next_run_at == expected
    assert runtime.next_run_at.tzinfo is None
|
||||
29
dify/api/tests/unit_tests/core/tools/test_tool_entities.py
Normal file
29
dify/api/tests/unit_tests/core/tools/test_tool_entities.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
|
||||
|
||||
|
||||
def _make_identity() -> ToolIdentity:
    """Minimal valid ToolIdentity for constructing entities in these tests."""
    return ToolIdentity(
        author="author",
        name="tool",
        label=I18nObject(en_US="Label"),
        provider="builtin",
    )


def test_log_message_metadata_none_defaults_to_empty_dict():
    """Passing metadata=None yields an empty metadata dict."""
    message = ToolInvokeMessage.LogMessage(
        id="log-1",
        label="Log entry",
        status=ToolInvokeMessage.LogMessage.LogStatus.START,
        data={},
        metadata=None,
    )

    assert message.metadata == {}


def test_tool_entity_output_schema_none_defaults_to_empty_dict():
    """Passing output_schema=None yields an empty schema dict."""
    entity = ToolEntity(identity=_make_identity(), output_schema=None)

    assert entity.output_schema == {}
|
||||
@@ -0,0 +1,49 @@
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
|
||||
def test_get_parameter_type():
    """as_normal_type collapses select/secret to 'string' and keeps the rest."""
    expected = {
        ToolParameter.ToolParameterType.STRING: "string",
        ToolParameter.ToolParameterType.SELECT: "string",
        ToolParameter.ToolParameterType.SECRET_INPUT: "string",
        ToolParameter.ToolParameterType.BOOLEAN: "boolean",
        ToolParameter.ToolParameterType.NUMBER: "number",
        ToolParameter.ToolParameterType.FILE: "file",
        ToolParameter.ToolParameterType.FILES: "files",
    }
    for param_type, normal_type in expected.items():
        assert param_type.as_normal_type() == normal_type


def test_cast_parameter_by_type():
    """cast_value coerces raw values according to the parameter type."""
    # string-like types stringify scalars; None becomes "".
    string_like = (
        ToolParameter.ToolParameterType.STRING,
        ToolParameter.ToolParameterType.SECRET_INPUT,
        ToolParameter.ToolParameterType.SELECT,
    )
    for param_type in string_like:
        assert param_type.cast_value("test") == "test"
        assert param_type.cast_value(1) == "1"
        assert param_type.cast_value(1.0) == "1.0"
        assert param_type.cast_value(None) == ""

    # boolean: any unrecognized non-empty string (e.g. "something") casts to True.
    for value in [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]:
        assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is True
    for value in [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]:
        assert ToolParameter.ToolParameterType.BOOLEAN.cast_value(value) is False

    # number: numeric strings parse to int/float; None passes through unchanged.
    number = ToolParameter.ToolParameterType.NUMBER
    assert number.cast_value("1") == 1
    assert number.cast_value("1.0") == 1.0
    assert number.cast_value("-1.0") == -1.0
    assert number.cast_value(1) == 1
    assert number.cast_value(1.0) == 1.0
    assert number.cast_value(-1.0) == -1.0
    assert number.cast_value(None) is None
|
||||
181
dify/api/tests/unit_tests/core/tools/utils/test_encryption.py
Normal file
181
dify/api/tests/unit_tests/core/tools/utils/test_encryption.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import copy
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_encryption import ProviderConfigEncrypter
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# A no-op cache
|
||||
# ---------------------------
|
||||
class NoopCache:
    """Cache stub with no storage: every lookup misses, writes and deletes are ignored."""

    def get(self):
        # Always a cache miss.
        return None

    def set(self, config):
        # Intentionally discards the value.
        return None

    def delete(self):
        # Nothing to remove.
        return None
|
||||
|
||||
|
||||
@pytest.fixture
def secret_field() -> BasicProviderConfig:
    """SECRET_INPUT provider config field named 'password'."""
    return BasicProviderConfig(
        name="password",
        type=BasicProviderConfig.Type.SECRET_INPUT,
    )


@pytest.fixture
def normal_field() -> BasicProviderConfig:
    """TEXT_INPUT provider config field named 'username'."""
    return BasicProviderConfig(
        name="username",
        type=BasicProviderConfig.Type.TEXT_INPUT,
    )


@pytest.fixture
def encrypter_obj(secret_field, normal_field):
    """ProviderConfigEncrypter under test.

    Configured with tenant_id 'tenant123', one secret field (password),
    one plain field (username), and a no-op cache.
    """
    return ProviderConfigEncrypter(
        tenant_id="tenant123",
        config=[secret_field, normal_field],
        provider_config_cache=NoopCache(),
    )
|
||||
|
||||
|
||||
# ============================================================
|
||||
# ProviderConfigEncrypter.encrypt()
|
||||
# ============================================================
|
||||
|
||||
|
||||
def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj):
    """encrypt() ciphers only the secret field and does not mutate its input."""
    data_in = {"username": "alice", "password": "plain_pwd"}
    snapshot = copy.deepcopy(data_in)

    with patch("core.helper.provider_encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt:
        out = encrypter_obj.encrypt(data_in)

    assert out["username"] == "alice"
    assert out["password"] == "CIPHERTEXT"
    mock_encrypt.assert_called_once_with("tenant123", "plain_pwd")
    assert data_in == snapshot  # input copied, not mutated


def test_encrypt_missing_secret_key_is_ok(encrypter_obj):
    """A missing secret key is tolerated: nothing is encrypted, no error raised."""
    with patch("core.helper.provider_encryption.encrypter.encrypt_token") as mock_encrypt:
        out = encrypter_obj.encrypt({"username": "alice"})

    assert out["username"] == "alice"
    mock_encrypt.assert_not_called()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# ProviderConfigEncrypter.mask_plugin_credentials()
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("raw", "prefix", "suffix"),
    [
        ("longsecret", "lo", "et"),
        ("abcdefg", "ab", "fg"),
        ("1234567", "12", "67"),
    ],
)
def test_mask_tool_credentials_long_secret(encrypter_obj, raw, prefix, suffix):
    """Secrets longer than 6 chars keep their 2-char edges with '*' in between."""
    data_in = {"username": "alice", "password": raw}
    snapshot = copy.deepcopy(data_in)

    masked = encrypter_obj.mask_plugin_credentials(data_in)["password"]

    assert masked.startswith(prefix)
    assert masked.endswith(suffix)
    assert "*" in masked
    assert len(masked) == len(raw)
    assert data_in == snapshot  # input copied, not mutated


@pytest.mark.parametrize("raw", ["", "1", "12", "123", "123456"])
def test_mask_tool_credentials_short_secret(encrypter_obj, raw):
    """Secrets of 6 chars or fewer are fully replaced by '*' of equal length."""
    out = encrypter_obj.mask_plugin_credentials({"password": raw})

    assert out["password"] == "*" * len(raw)


def test_mask_tool_credentials_missing_key_noop(encrypter_obj):
    """With the secret key absent, masking leaves other fields unchanged."""
    data_in = {"username": "alice"}
    snapshot = copy.deepcopy(data_in)

    out = encrypter_obj.mask_plugin_credentials(data_in)

    assert out["username"] == "alice"
    assert data_in == snapshot
|
||||
|
||||
|
||||
# ============================================================
|
||||
# ProviderConfigEncrypter.decrypt()
|
||||
# ============================================================
|
||||
|
||||
|
||||
def test_decrypt_normal_flow(encrypter_obj):
    """decrypt() replaces the secret with its plaintext and copies the input."""
    data_in = {"username": "alice", "password": "ENC"}
    snapshot = copy.deepcopy(data_in)

    with patch("core.helper.provider_encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt:
        out = encrypter_obj.decrypt(data_in)

    assert out["username"] == "alice"
    assert out["password"] == "PLAIN"
    mock_decrypt.assert_called_once_with("tenant123", "ENC")
    assert data_in == snapshot  # input copied, not mutated


@pytest.mark.parametrize("empty_val", ["", None])
def test_decrypt_skip_empty_values(encrypter_obj, empty_val):
    """Empty or None secrets pass through without calling decrypt_token."""
    with patch("core.helper.provider_encryption.encrypter.decrypt_token") as mock_decrypt:
        out = encrypter_obj.decrypt({"password": empty_val})

    mock_decrypt.assert_not_called()
    assert out["password"] == empty_val


def test_decrypt_swallow_exception_and_keep_original(encrypter_obj):
    """A failing decrypt_token is swallowed and the stored value kept as-is."""
    with patch("core.helper.provider_encryption.encrypter.decrypt_token", side_effect=Exception("boom")):
        out = encrypter_obj.decrypt({"password": "ENC_ERR"})

    assert out["password"] == "ENC_ERR"
|
||||
191
dify/api/tests/unit_tests/core/tools/utils/test_parser.py
Normal file
191
dify/api/tests/unit_tests/core/tools/utils/test_parser.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
|
||||
|
||||
@pytest.fixture
def app():
    """Bare Flask app used only to provide a request context for the parser."""
    return Flask(__name__)
|
||||
|
||||
|
||||
def test_parse_openapi_to_tool_bundle_operation_id(app):
    """Missing operationIds are synthesized from the path; explicit ones are kept."""
    spec = {
        "openapi": "3.0.0",
        "info": {"title": "Simple API", "version": "1.0.0"},
        "servers": [{"url": "http://localhost:3000"}],
        "paths": {
            "/": {
                "get": {
                    "summary": "Root endpoint",
                    "responses": {
                        "200": {
                            "description": "Successful response",
                        }
                    },
                }
            },
            "/api/resources": {
                "get": {
                    "summary": "Non-root endpoint without an operationId",
                    "responses": {
                        "200": {
                            "description": "Successful response",
                        }
                    },
                },
                "post": {
                    "summary": "Non-root endpoint with an operationId",
                    "operationId": "createResource",
                    "responses": {
                        "201": {
                            "description": "Resource created",
                        }
                    },
                },
            },
        },
    }

    with app.test_request_context():
        tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(spec)

    assert len(tool_bundles) == 3
    assert tool_bundles[0].operation_id == "<root>_get"
    assert tool_bundles[1].operation_id == "apiresources_get"
    assert tool_bundles[2].operation_id == "createResource"
|
||||
|
||||
|
||||
def test_parse_openapi_to_tool_bundle_properties_all_of(app):
    """allOf property schemas resolve to a usable parameter type and description."""
    spec = {
        "openapi": "3.0.0",
        "info": {"title": "Simple API", "version": "1.0.0"},
        "servers": [{"url": "http://localhost:3000"}],
        "paths": {
            "/api/resource": {
                "get": {
                    "requestBody": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "$ref": "#/components/schemas/Request",
                                },
                            },
                        },
                        "required": True,
                    },
                },
            },
        },
        "components": {
            "schemas": {
                "Request": {
                    "type": "object",
                    "properties": {
                        "prop1": {
                            "enum": ["option1"],
                            "description": "desc prop1",
                            "allOf": [
                                {"$ref": "#/components/schemas/AllOfItem"},
                                {
                                    "enum": ["option2"],
                                },
                            ],
                        },
                    },
                },
                "AllOfItem": {
                    "type": "string",
                    "enum": ["option3"],
                    "description": "desc allOf item",
                },
            }
        },
    }

    with app.test_request_context():
        tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(spec)

    assert tool_bundles[0].parameters[0].type == "string"
    assert tool_bundles[0].parameters[0].llm_description == "desc prop1"
    # TODO: support enum in OpenAPI
    # assert set(tool_bundles[0].parameters[0].options) == {"option1", "option2", "option3"}
|
||||
|
||||
|
||||
def test_parse_openapi_to_tool_bundle_default_value_type_casting(app):
    """
    Test that default values are properly cast to match parameter types.
    This addresses the issue where array default values like [] cause validation errors
    when parameter type is inferred as string/number/boolean.
    """
    openapi = {
        "openapi": "3.0.0",
        "info": {"title": "Test API", "version": "1.0.0"},
        "servers": [{"url": "https://example.com"}],
        "paths": {
            "/product/create": {
                "post": {
                    "operationId": "createProduct",
                    "summary": "Create a product",
                    "requestBody": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "type": "object",
                                    # One property per default-value type under test.
                                    "properties": {
                                        "categories": {
                                            "description": "List of category identifiers",
                                            "default": [],
                                            "type": "array",
                                            "items": {"type": "string"},
                                        },
                                        "name": {
                                            "description": "Product name",
                                            "default": "Default Product",
                                            "type": "string",
                                        },
                                        "price": {"description": "Product price", "default": 0.0, "type": "number"},
                                        "available": {
                                            "description": "Product availability",
                                            "default": True,
                                            "type": "boolean",
                                        },
                                    },
                                }
                            }
                        }
                    },
                    "responses": {"200": {"description": "Default Response"}},
                }
            }
        },
    }

    with app.test_request_context():
        tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)

    assert len(tool_bundles) == 1
    bundle = tool_bundles[0]
    assert len(bundle.parameters) == 4

    # Find parameters by name
    params_by_name = {param.name: param for param in bundle.parameters}

    # Check categories parameter (array type with [] default)
    categories_param = params_by_name["categories"]
    assert categories_param.type == "array"  # Will be detected by _get_tool_parameter_type
    assert categories_param.default is None  # Array default [] is converted to None

    # Check name parameter (string type with string default)
    name_param = params_by_name["name"]
    assert name_param.type == "string"
    assert name_param.default == "Default Product"

    # Check price parameter (number type with number default)
    price_param = params_by_name["price"]
    assert price_param.type == "number"
    assert price_param.default == 0.0

    # Check available parameter (boolean type with boolean default)
    available_param = params_by_name["available"]
    assert available_param.type == "boolean"
    assert available_param.default is True
|
||||
@@ -0,0 +1,481 @@
|
||||
import json
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytz
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_dict, safe_json_value
|
||||
|
||||
|
||||
class TestSafeJsonValue:
    """Test suite for safe_json_value function to ensure proper serialization of complex types"""

    def test_datetime_conversion(self):
        """Test datetime conversion with timezone handling"""
        # Test datetime with UTC timezone
        dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC)
        result = safe_json_value(dt)
        assert isinstance(result, str)
        assert "2024-01-01T12:00:00+00:00" in result

        # Test datetime without timezone (should default to UTC)
        dt_no_tz = datetime(2024, 1, 1, 12, 0, 0)
        result = safe_json_value(dt_no_tz)
        assert isinstance(result, str)
        # The exact time will depend on the system's timezone, so we check the format
        assert "T" in result  # ISO format separator
        # Check that it's a valid ISO format datetime string
        assert len(result) >= 19  # At least YYYY-MM-DDTHH:MM:SS

    def test_date_conversion(self):
        """Test date conversion to ISO format"""
        test_date = date(2024, 1, 1)
        result = safe_json_value(test_date)
        assert result == "2024-01-01"

    def test_uuid_conversion(self):
        """Test UUID conversion to string"""
        test_uuid = uuid4()
        result = safe_json_value(test_uuid)
        assert isinstance(result, str)
        assert result == str(test_uuid)

    def test_decimal_conversion(self):
        """Test Decimal conversion to float"""
        test_decimal = Decimal("123.456")
        result = safe_json_value(test_decimal)
        assert result == 123.456
        assert isinstance(result, float)

    def test_bytes_conversion(self):
        """Test bytes conversion with UTF-8 decoding"""
        # Test valid UTF-8 bytes
        test_bytes = b"Hello, World!"
        result = safe_json_value(test_bytes)
        assert result == "Hello, World!"

        # Test invalid UTF-8 bytes (should fall back to hex)
        invalid_bytes = b"\xff\xfe\xfd"
        result = safe_json_value(invalid_bytes)
        assert result == "fffefd"

    def test_memoryview_conversion(self):
        """Test memoryview conversion to hex string"""
        test_bytes = b"test data"
        test_memoryview = memoryview(test_bytes)
        result = safe_json_value(test_memoryview)
        assert result == "746573742064617461"  # hex of "test data"

    def test_numpy_ndarray_conversion(self):
        """Test numpy ndarray conversion to list"""
        # Test 1D array
        test_array = np.array([1, 2, 3, 4])
        result = safe_json_value(test_array)
        assert result == [1, 2, 3, 4]

        # Test 2D array
        test_2d_array = np.array([[1, 2], [3, 4]])
        result = safe_json_value(test_2d_array)
        assert result == [[1, 2], [3, 4]]

        # Test array with float values
        test_float_array = np.array([1.5, 2.7, 3.14])
        result = safe_json_value(test_float_array)
        assert result == [1.5, 2.7, 3.14]

    def test_dict_conversion(self):
        """Test dictionary conversion using safe_json_dict"""
        # All values here are already JSON-safe, so the dict round-trips unchanged.
        test_dict = {
            "string": "value",
            "number": 42,
            "float": 3.14,
            "boolean": True,
            "list": [1, 2, 3],
            "nested": {"key": "value"},
        }
        result = safe_json_value(test_dict)
        assert isinstance(result, dict)
        assert result == test_dict

    def test_list_conversion(self):
        """Test list conversion with mixed types"""
        test_list = [
            "string",
            42,
            3.14,
            True,
            [1, 2, 3],
            {"key": "value"},
            datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
            Decimal("123.456"),
            uuid4(),
        ]
        result = safe_json_value(test_list)
        assert isinstance(result, list)
        assert len(result) == len(test_list)
        assert isinstance(result[6], str)  # datetime should be converted to string
        assert isinstance(result[7], float)  # Decimal should be converted to float
        assert isinstance(result[8], str)  # UUID should be converted to string

    def test_tuple_conversion(self):
        """Test tuple conversion to list"""
        test_tuple = (1, "string", 3.14)
        result = safe_json_value(test_tuple)
        assert isinstance(result, list)
        assert result == [1, "string", 3.14]

    def test_set_conversion(self):
        """Test set conversion to list"""
        test_set = {1, "string", 3.14}
        result = safe_json_value(test_set)
        assert isinstance(result, list)
        # Note: set order is not guaranteed, so we check length and content
        assert len(result) == 3
        assert 1 in result
        assert "string" in result
        assert 3.14 in result

    def test_basic_types_passthrough(self):
        """Test that basic types are passed through unchanged"""
        assert safe_json_value("string") == "string"
        assert safe_json_value(42) == 42
        assert safe_json_value(3.14) == 3.14
        assert safe_json_value(True) is True
        assert safe_json_value(False) is False
        assert safe_json_value(None) is None

    def test_nested_complex_structure(self):
        """Test complex nested structure with all types"""
        complex_data = {
            "dates": [date(2024, 1, 1), date(2024, 1, 2)],
            "timestamps": [
                datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
                datetime(2024, 1, 2, 12, 0, 0, tzinfo=pytz.UTC),
            ],
            "numbers": [Decimal("123.456"), Decimal("789.012")],
            "identifiers": [uuid4(), uuid4()],
            "binary_data": [b"hello", b"world"],
            "arrays": [np.array([1, 2, 3]), np.array([4, 5, 6])],
        }

        result = safe_json_value(complex_data)

        # Verify structure is maintained
        assert isinstance(result, dict)
        assert "dates" in result
        assert "timestamps" in result
        assert "numbers" in result
        assert "identifiers" in result
        assert "binary_data" in result
        assert "arrays" in result

        # Verify conversions
        assert all(isinstance(d, str) for d in result["dates"])
        assert all(isinstance(t, str) for t in result["timestamps"])
        assert all(isinstance(n, float) for n in result["numbers"])
        assert all(isinstance(i, str) for i in result["identifiers"])
        assert all(isinstance(b, str) for b in result["binary_data"])
        assert all(isinstance(a, list) for a in result["arrays"])
|
||||
|
||||
|
||||
class TestSafeJsonDict:
    """Tests for safe_json_dict, the dict-only entry point over safe_json_value."""

    def test_valid_dict_conversion(self):
        """Mixed-type values are converted while plain values pass through."""
        payload = {
            "string": "value",
            "number": 42,
            "datetime": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
            "decimal": Decimal("123.456"),
        }
        converted = safe_json_dict(payload)
        assert isinstance(converted, dict)
        assert converted["string"] == "value"
        assert converted["number"] == 42
        assert isinstance(converted["datetime"], str)
        assert isinstance(converted["decimal"], float)

    def test_invalid_input_type(self):
        """Non-dict inputs are rejected with a TypeError."""
        expected = r"safe_json_dict\(\) expects a dictionary \(dict\) as input"
        for bad_input in ("not a dict", [1, 2, 3], 42):
            with pytest.raises(TypeError, match=expected):
                safe_json_dict(bad_input)

    def test_empty_dict(self):
        """An empty dict comes back as an empty dict."""
        assert safe_json_dict({}) == {}

    def test_nested_dict_conversion(self):
        """Values several levels deep are still converted."""
        payload = {
            "level1": {
                "level2": {"datetime": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), "decimal": Decimal("123.456")}
            }
        }
        converted = safe_json_dict(payload)
        inner = converted["level1"]["level2"]
        assert isinstance(inner["datetime"], str)
        assert isinstance(inner["decimal"], float)
|
||||
|
||||
|
||||
class TestToolInvokeMessageJsonSerialization:
    """Test suite for ToolInvokeMessage JSON serialization through safe_json_value"""

    def test_json_message_serialization(self):
        """Test JSON message serialization with complex data"""
        complex_data = {
            "timestamp": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
            "amount": Decimal("123.45"),
            "id": uuid4(),
            "binary": b"test data",
            "array": np.array([1, 2, 3]),
        }

        # Create JSON message
        json_message = ToolInvokeMessage.JsonMessage(json_object=complex_data)
        message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message)

        # Apply safe_json_value transformation
        transformed_data = safe_json_value(message.message.json_object)

        # Verify transformations
        assert isinstance(transformed_data["timestamp"], str)
        assert isinstance(transformed_data["amount"], float)
        assert isinstance(transformed_data["id"], str)
        assert isinstance(transformed_data["binary"], str)
        assert isinstance(transformed_data["array"], list)

        # Verify JSON serialization works
        json_string = json.dumps(transformed_data, ensure_ascii=False)
        assert isinstance(json_string, str)

        # Verify we can deserialize back
        deserialized = json.loads(json_string)
        assert deserialized["amount"] == 123.45
        assert deserialized["array"] == [1, 2, 3]

    def test_json_message_with_nested_structures(self):
        """Test JSON message with deeply nested complex structures"""
        nested_data = {
            "level1": {
                "level2": {
                    "level3": {
                        "dates": [date(2024, 1, 1), date(2024, 1, 2)],
                        "timestamps": [datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC)],
                        "numbers": [Decimal("1.1"), Decimal("2.2")],
                        "arrays": [np.array([1, 2]), np.array([3, 4])],
                    }
                }
            }
        }

        json_message = ToolInvokeMessage.JsonMessage(json_object=nested_data)
        message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message)

        # Transform the data
        transformed_data = safe_json_value(message.message.json_object)

        # Verify nested transformations
        level3 = transformed_data["level1"]["level2"]["level3"]
        assert all(isinstance(d, str) for d in level3["dates"])
        assert all(isinstance(t, str) for t in level3["timestamps"])
        assert all(isinstance(n, float) for n in level3["numbers"])
        assert all(isinstance(a, list) for a in level3["arrays"])

        # Test JSON serialization
        json_string = json.dumps(transformed_data, ensure_ascii=False)
        assert isinstance(json_string, str)

        # Verify deserialization
        deserialized = json.loads(json_string)
        assert deserialized["level1"]["level2"]["level3"]["numbers"] == [1.1, 2.2]

    def test_json_message_transformer_integration(self):
        """Test integration with ToolFileMessageTransformer for JSON messages"""
        complex_data = {
            "metadata": {
                "created_at": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
                "version": Decimal("1.0"),
                "tags": ["tag1", "tag2"],
            },
            "data": {"values": np.array([1.1, 2.2, 3.3]), "binary": b"binary content"},
        }

        # Create message generator
        def message_generator():
            json_message = ToolInvokeMessage.JsonMessage(json_object=complex_data)
            message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message)
            yield message

        # Transform messages
        transformed_messages = list(
            ToolFileMessageTransformer.transform_tool_invoke_messages(
                message_generator(), user_id="test_user", tenant_id="test_tenant"
            )
        )

        assert len(transformed_messages) == 1
        transformed_message = transformed_messages[0]
        assert transformed_message.type == ToolInvokeMessage.MessageType.JSON

        # Verify the JSON object was transformed
        json_obj = transformed_message.message.json_object
        assert isinstance(json_obj["metadata"]["created_at"], str)
        assert isinstance(json_obj["metadata"]["version"], float)
        assert isinstance(json_obj["data"]["values"], list)
        assert isinstance(json_obj["data"]["binary"], str)

        # Test final JSON serialization
        final_json = json.dumps(json_obj, ensure_ascii=False)
        assert isinstance(final_json, str)

        # Verify we can deserialize
        deserialized = json.loads(final_json)
        assert deserialized["metadata"]["version"] == 1.0
        assert deserialized["data"]["values"] == [1.1, 2.2, 3.3]

    def test_edge_cases_and_error_handling(self):
        """Test edge cases and error handling in JSON serialization"""
        # Test with None values
        data_with_none = {"null_value": None, "empty_string": "", "zero": 0, "false_value": False}

        json_message = ToolInvokeMessage.JsonMessage(json_object=data_with_none)
        message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message)

        transformed_data = safe_json_value(message.message.json_object)
        json_string = json.dumps(transformed_data, ensure_ascii=False)

        # Verify serialization works with edge cases
        assert json_string is not None
        deserialized = json.loads(json_string)
        assert deserialized["null_value"] is None
        assert deserialized["empty_string"] == ""
        assert deserialized["zero"] == 0
        assert deserialized["false_value"] is False

        # Test with very large numbers
        large_data = {
            "large_int": 2**63 - 1,
            "large_float": 1.7976931348623157e308,
            "small_float": 2.2250738585072014e-308,
        }

        json_message = ToolInvokeMessage.JsonMessage(json_object=large_data)
        message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message)

        transformed_data = safe_json_value(message.message.json_object)
        json_string = json.dumps(transformed_data, ensure_ascii=False)

        # Verify large numbers are handled correctly
        deserialized = json.loads(json_string)
        assert deserialized["large_int"] == 2**63 - 1
        assert deserialized["large_float"] == 1.7976931348623157e308
        assert deserialized["small_float"] == 2.2250738585072014e-308
|
||||
|
||||
|
||||
class TestEndToEndSerialization:
    """Test suite for end-to-end serialization workflow"""

    def test_complete_workflow_with_real_data(self):
        """Test complete workflow from complex data to JSON string and back"""
        # Simulate real-world complex data structure
        real_world_data = {
            "user_profile": {
                "id": uuid4(),
                "name": "John Doe",
                "email": "john@example.com",
                "created_at": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC),
                "last_login": datetime(2024, 1, 15, 14, 30, 0, tzinfo=pytz.UTC),
                "preferences": {"theme": "dark", "language": "en", "timezone": "UTC"},
            },
            "analytics": {
                "session_count": 42,
                "total_time": Decimal("123.45"),
                "metrics": np.array([1.1, 2.2, 3.3, 4.4, 5.5]),
                "events": [
                    {
                        "timestamp": datetime(2024, 1, 1, 10, 0, 0, tzinfo=pytz.UTC),
                        "action": "login",
                        "duration": Decimal("5.67"),
                    },
                    {
                        "timestamp": datetime(2024, 1, 1, 11, 0, 0, tzinfo=pytz.UTC),
                        "action": "logout",
                        "duration": Decimal("3600.0"),
                    },
                ],
            },
            "files": [
                {
                    "id": uuid4(),
                    "name": "document.pdf",
                    "size": 1024,
                    "uploaded_at": datetime(2024, 1, 1, 9, 0, 0, tzinfo=pytz.UTC),
                    "checksum": b"abc123def456",
                }
            ],
        }

        # Step 1: Create ToolInvokeMessage
        json_message = ToolInvokeMessage.JsonMessage(json_object=real_world_data)
        message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message)

        # Step 2: Apply safe_json_value transformation
        transformed_data = safe_json_value(message.message.json_object)

        # Step 3: Serialize to JSON string
        json_string = json.dumps(transformed_data, ensure_ascii=False)

        # Step 4: Verify the string is valid JSON
        assert isinstance(json_string, str)
        assert json_string.startswith("{")
        assert json_string.endswith("}")

        # Step 5: Deserialize back to Python object
        deserialized_data = json.loads(json_string)

        # Step 6: Verify data integrity
        assert deserialized_data["user_profile"]["name"] == "John Doe"
        assert deserialized_data["user_profile"]["email"] == "john@example.com"
        assert isinstance(deserialized_data["user_profile"]["created_at"], str)
        assert isinstance(deserialized_data["analytics"]["total_time"], float)
        assert deserialized_data["analytics"]["total_time"] == 123.45
        assert isinstance(deserialized_data["analytics"]["metrics"], list)
        assert deserialized_data["analytics"]["metrics"] == [1.1, 2.2, 3.3, 4.4, 5.5]
        assert isinstance(deserialized_data["files"][0]["checksum"], str)

        # Step 7: Verify all complex types were properly converted
        self._verify_all_complex_types_converted(deserialized_data)

    def _verify_all_complex_types_converted(self, data):
        """Helper method to verify all complex types were properly converted"""
        # NOTE: keys are dispatched by name, so the branch order matters only for
        # readability; each key matches exactly one list below.
        if isinstance(data, dict):
            for key, value in data.items():
                if key in ["id", "checksum"]:
                    # These should be strings (UUID/bytes converted)
                    assert isinstance(value, str)
                elif key in ["created_at", "last_login", "timestamp", "uploaded_at"]:
                    # These should be strings (datetime converted)
                    assert isinstance(value, str)
                elif key in ["total_time", "duration"]:
                    # These should be floats (Decimal converted)
                    assert isinstance(value, float)
                elif key == "metrics":
                    # This should be a list (ndarray converted)
                    assert isinstance(value, list)
                else:
                    # Recursively check nested structures
                    self._verify_all_complex_types_converted(value)
        elif isinstance(data, list):
            for item in data:
                self._verify_all_complex_types_converted(item)
|
||||
@@ -0,0 +1,312 @@
|
||||
import pytest
|
||||
|
||||
from core.tools.utils.web_reader_tool import (
|
||||
extract_using_readabilipy,
|
||||
get_image_upload_file_ids,
|
||||
get_url,
|
||||
page_result,
|
||||
)
|
||||
|
||||
|
||||
class FakeResponse:
    """Minimal fake response object for ssrf_proxy / cloudscraper."""

    def __init__(self, *, status_code=200, headers=None, content=b"", text=""):
        # Only the attributes the web reader actually touches are mirrored here.
        self.status_code = status_code
        self.content = content
        self.headers = headers or {}
        # When no explicit text is supplied, derive it from the raw body.
        self.text = text if text else content.decode("utf-8", errors="ignore")
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Tests: page_result
|
||||
# ---------------------------
|
||||
@pytest.mark.parametrize(
    ("text", "cursor", "maxlen", "expected"),
    [
        ("abcdef", 0, 3, "abc"),
        ("abcdef", 2, 10, "cdef"),  # maxlen beyond end
        ("abcdef", 6, 5, ""),  # cursor at end
        ("abcdef", 7, 5, ""),  # cursor beyond end
        ("", 0, 5, ""),  # empty text
    ],
)
def test_page_result(text, cursor, maxlen, expected):
    """page_result returns the slice of *text* starting at *cursor*, capped at *maxlen* chars."""
    assert page_result(text, cursor, maxlen) == expected
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Tests: get_url
|
||||
# ---------------------------
|
||||
@pytest.fixture
def stub_support_types(monkeypatch: pytest.MonkeyPatch):
    """Stub supported content types list.

    Patches SUPPORT_URL_CONTENT_TYPES on the module's extract_processor so tests
    control which binary content types route to ExtractProcessor, and returns the
    patched web_reader_tool module for further monkeypatching.
    """
    import core.tools.utils.web_reader_tool as mod

    # e.g. binary types supported by ExtractProcessor
    monkeypatch.setattr(mod.extract_processor, "SUPPORT_URL_CONTENT_TYPES", ["application/pdf", "text/plain"])
    return mod
|
||||
|
||||
|
||||
def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """HEAD succeeds but reports a content-type the reader cannot handle."""

    def head_stub(url, headers=None, follow_redirects=True, timeout=None):
        # 200 OK, yet neither text/html nor one of the stubbed supported types.
        return FakeResponse(status_code=200, headers={"Content-Type": "image/png"})

    monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", head_stub)

    assert get_url("https://x.test/file.png") == "Unsupported content-type [image/png] of URL."
|
||||
|
||||
|
||||
def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """
    When content-type is in SUPPORT_URL_CONTENT_TYPES,
    should call ExtractProcessor.load_from_url and return its text.
    """
    # Call counter so we can assert the extractor ran exactly once.
    calls = {"load": 0}

    def fake_head(url, headers=None, follow_redirects=True, timeout=None):
        return FakeResponse(
            status_code=200,
            headers={"Content-Type": "application/pdf"},
        )

    def fake_load_from_url(url, return_text=False):
        calls["load"] += 1
        # get_url is expected to request plain text from the extractor.
        assert return_text is True
        return "PDF extracted text"

    monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head)
    monkeypatch.setattr(stub_support_types.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url))

    result = get_url("https://x.test/doc.pdf")
    assert calls["load"] == 1
    assert result == "PDF extracted text"
|
||||
|
||||
|
||||
def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """200 + text/html → GET, chardet detects encoding, readability returns article which is templated."""

    def fake_head(url, headers=None, follow_redirects=True, timeout=None):
        return FakeResponse(status_code=200, headers={"Content-Type": "text/html"})

    def fake_get(url, headers=None, follow_redirects=True, timeout=None):
        html = b"<html><head><title>x</title></head><body>hello</body></html>"
        return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}, content=html)

    # chardet.detect returns utf-8
    import core.tools.utils.web_reader_tool as mod

    monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
    monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
    monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})

    # readability → a dict that maps to Article, then FULL_TEMPLATE
    def fake_simple_json_from_html_string(html, use_readability=True):
        return {
            "title": "My Title",
            "byline": "Bob",
            "plain_text": [{"type": "text", "text": "Hello world"}],
        }

    monkeypatch.setattr(mod, "simple_json_from_html_string", fake_simple_json_from_html_string)

    out = get_url("https://x.test/page")
    # The rendered template carries title, author, and the article text.
    assert "TITLE: My Title" in out
    assert "AUTHOR: Bob" in out
    assert "Hello world" in out
|
||||
|
||||
|
||||
def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """If readability returns no text, should return empty string."""
    import core.tools.utils.web_reader_tool as mod

    def head_stub(url, headers=None, follow_redirects=True, timeout=None):
        return FakeResponse(status_code=200, headers={"Content-Type": "text/html"})

    def get_stub(url, headers=None, follow_redirects=True, timeout=None):
        return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}, content=b"<html/>")

    monkeypatch.setattr(mod.ssrf_proxy, "head", head_stub)
    monkeypatch.setattr(mod.ssrf_proxy, "get", get_stub)
    monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
    # Readability produces an article with empty plain_text → nothing to render.
    monkeypatch.setattr(mod, "simple_json_from_html_string", lambda html, use_readability=True: {"plain_text": []})

    assert get_url("https://x.test/empty") == ""
|
||||
|
||||
|
||||
def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed."""

    def fake_head(url, headers=None, follow_redirects=True, timeout=None):
        return FakeResponse(status_code=403, headers={})

    # cloudscraper.create_scraper() → object with .get(); the fake holds no state,
    # so the previous no-op __init__ (just `pass`) has been removed as dead code.
    class FakeScraper:
        def get(self, url, headers=None, follow_redirects=True, timeout=None):
            # mimic html 200
            html = b"<html><body>hi</body></html>"
            return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}, content=html)

    import core.tools.utils.web_reader_tool as mod

    monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
    monkeypatch.setattr(mod.cloudscraper, "create_scraper", lambda: FakeScraper())
    monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
    monkeypatch.setattr(
        mod,
        "simple_json_from_html_string",
        lambda html, use_readability=True: {"title": "T", "byline": "A", "plain_text": [{"type": "text", "text": "X"}]},
    )

    out = get_url("https://x.test/403")
    assert "TITLE: T" in out
    assert "AUTHOR: A" in out
    assert "X" in out
|
||||
|
||||
|
||||
def test_get_url_head_non_200_returns_status(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """HEAD returns non-200 and non-403 → should directly return code message."""
    import core.tools.utils.web_reader_tool as mod

    def head_stub(url, headers=None, follow_redirects=True, timeout=None):
        # Any status other than 200/403 short-circuits the reader.
        return FakeResponse(status_code=500)

    monkeypatch.setattr(mod.ssrf_proxy, "head", head_stub)

    assert get_url("https://x.test/fail") == "URL returned status code 500."
|
||||
|
||||
|
||||
def test_get_url_content_disposition_filename_detection(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """
    If HEAD 200 with no Content-Type but Content-Disposition filename suggests a supported type,
    it should route to ExtractProcessor.load_from_url.
    """
    # Call counter so we can assert the extractor ran exactly once.
    calls = {"load": 0}

    def fake_head(url, headers=None, follow_redirects=True, timeout=None):
        # No Content-Type header; only the filename hints at a PDF.
        return FakeResponse(status_code=200, headers={"Content-Disposition": 'attachment; filename="doc.pdf"'})

    def fake_load_from_url(url, return_text=False):
        calls["load"] += 1
        return "From ExtractProcessor via filename"

    import core.tools.utils.web_reader_tool as mod

    monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
    monkeypatch.setattr(mod.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url))

    out = get_url("https://x.test/fname")
    assert calls["load"] == 1
    assert out == "From ExtractProcessor via filename"
|
||||
|
||||
|
||||
def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.MonkeyPatch, stub_support_types):
    """
    If chardet returns an encoding but content.decode raises, should fallback to response.text.
    """

    def fake_head(url, headers=None, follow_redirects=True, timeout=None):
        return FakeResponse(status_code=200, headers={"Content-Type": "text/html"})

    # Return bytes that will raise with the chosen encoding
    def fake_get(url, headers=None, follow_redirects=True, timeout=None):
        return FakeResponse(
            status_code=200,
            headers={"Content-Type": "text/html"},
            content=b"\xff\xfe\xfa",  # likely to fail under utf-8
            text="<html>fallback text</html>",
        )

    import core.tools.utils.web_reader_tool as mod

    monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head)
    monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get)
    monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"})
    monkeypatch.setattr(
        mod,
        "simple_json_from_html_string",
        lambda html, use_readability=True: {"title": "", "byline": "", "plain_text": [{"type": "text", "text": "ok"}]},
    )

    out = get_url("https://x.test/enc-fallback")
    # Article text made it through, proving the fallback path produced usable HTML.
    assert "ok" in out
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Tests: extract_using_readabilipy
|
||||
# ---------------------------
|
||||
|
||||
|
||||
def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch: pytest.MonkeyPatch):
    """readabilipy fields map onto the article: title→title, byline→author, plain_text→text."""
    import core.tools.utils.web_reader_tool as mod

    # Stub readabilipy so no real HTML parsing happens.
    def stub_simple_json(html, use_readability=True):
        return {
            "title": "Hello",
            "byline": "Alice",
            "plain_text": [{"type": "text", "text": "world"}],
        }

    monkeypatch.setattr(mod, "simple_json_from_html_string", stub_simple_json)

    article = extract_using_readabilipy("<html>...</html>")

    assert article.title == "Hello"
    assert article.author == "Alice"
    assert isinstance(article.text, list)
    assert article.text
    assert article.text[0]["text"] == "world"
|
||||
|
||||
|
||||
def test_extract_using_readabilipy_defaults_when_missing(monkeypatch: pytest.MonkeyPatch):
    """A completely empty readabilipy result falls back to empty defaults."""
    import core.tools.utils.web_reader_tool as mod

    monkeypatch.setattr(mod, "simple_json_from_html_string", lambda html, use_readability=True: {})

    article = extract_using_readabilipy("<html>...</html>")

    assert article.title == ""
    assert article.author == ""
    assert article.text == []
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Tests: get_image_upload_file_ids
|
||||
# ---------------------------
|
||||
def test_get_image_upload_file_ids():
    """get_image_upload_file_ids extracts upload ids from markdown image links."""
    # NOTE(review): the markdown image links in this test were lost in transit —
    # the `content` strings were empty while the assertions expected ids, so the
    # test could never pass as written. Reconstructed below from the surviving
    # URLs of the multi-id case; confirm against the regex in
    # get_image_upload_file_ids.

    # should extract id from https + file-preview
    content = ""
    assert get_image_upload_file_ids(content) == ["abc123"]

    # should extract id from http + image-preview
    content = ""
    assert get_image_upload_file_ids(content) == ["xyz789"]

    # should not match invalid scheme 'htt://'
    content = ""
    assert get_image_upload_file_ids(content) == []

    # should extract multiple ids in order
    content = """
some text

middle

end
"""
    assert get_image_upload_file_ids(content) == ["id1", "id2"]
|
||||
@@ -0,0 +1,53 @@
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
|
||||
|
||||
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch):
    """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
    `WorkflowAppGenerator.generate` returns a result with `error` key inside
    the `data` element.
    """
    from unittest.mock import Mock

    tool_identity = ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test")
    entity = ToolEntity(
        identity=tool_identity,
        parameters=[],
        description=None,
        has_runtime_parameters=False,
    )
    tool = WorkflowTool(
        workflow_app_id="",
        workflow_as_tool_id="",
        version="1",
        workflow_entities={},
        workflow_call_depth=1,
        entity=entity,
        runtime=ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE),
    )

    # Patch DB-touching lookups with no-ops so the test stays hermetic.
    monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
    monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
    monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: Mock())

    # Force the generator to report a failed workflow run.
    monkeypatch.setattr(
        "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
        lambda *args, **kwargs: {"data": {"error": "oops"}},
    )

    with pytest.raises(ToolInvokeError) as exc_info:
        # invoke() is lazy; drain the generator so the tool actually runs.
        list(tool.invoke("test_user", {}))
    assert exc_info.value.args == ("oops",)
|
||||
382
dify/api/tests/unit_tests/core/variables/test_segment.py
Normal file
382
dify/api/tests/unit_tests/core/variables/test_segment.py
Normal file
@@ -0,0 +1,382 @@
|
||||
import dataclasses
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.helper import encrypter
|
||||
from core.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
SegmentUnion,
|
||||
StringSegment,
|
||||
get_segment_discriminator,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.variables import (
|
||||
ArrayAnyVariable,
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
NoneVariable,
|
||||
ObjectVariable,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
VariableUnion,
|
||||
)
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def test_segment_group_to_text():
    """Rendering a template: .text shows secrets in full, .log obfuscates them."""
    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="fake-user-id"),
        user_inputs={},
        environment_variables=[SecretVariable(name="secret_key", value="fake-secret-key")],
        conversation_variables=[],
    )
    variable_pool.add(("node_id", "custom_query"), "fake-user-query")

    template = (
        "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
    )
    group = variable_pool.convert_template(template)

    assert group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key."
    expected_log = (
        "Hello, fake-user-id! Your query is fake-user-query."
        f" And your key is {encrypter.obfuscated_token('fake-secret-key')}."
    )
    assert group.log == expected_log
|
||||
|
||||
|
||||
def test_convert_constant_to_segment_group():
    """A template with no variable markers passes through untouched."""
    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )

    group = variable_pool.convert_template("Hello, world!")

    assert group.text == "Hello, world!"
    assert group.log == "Hello, world!"
|
||||
|
||||
|
||||
def test_convert_variable_to_segment_group():
    """A bare variable reference resolves to a single StringVariable segment."""
    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="fake-user-id"),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )

    group = variable_pool.convert_template("{{#sys.user_id#}}")

    assert group.text == "fake-user-id"
    assert group.log == "fake-user-id"
    resolved = group.value[0]
    assert isinstance(resolved, StringVariable)
    assert resolved.value == "fake-user-id"
|
||||
|
||||
|
||||
class _Segments(BaseModel):
    """Pydantic wrapper used to round-trip a list of segments through JSON in tests."""

    segments: list[SegmentUnion]
|
||||
|
||||
|
||||
class _Variables(BaseModel):
    """Pydantic wrapper used to round-trip a list of variables through JSON in tests."""

    variables: list[VariableUnion]
|
||||
|
||||
|
||||
def create_test_file(
    file_type: FileType = FileType.DOCUMENT,
    transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
    filename: str = "test.txt",
    extension: str = ".txt",
    mime_type: str = "text/plain",
    size: int = 1024,
) -> File:
    """Factory function to create File objects for testing"""
    # Remote files carry a URL instead of a local upload id; the two are
    # mutually exclusive.
    is_remote = transfer_method == FileTransferMethod.REMOTE_URL
    return File(
        tenant_id="test-tenant",
        type=file_type,
        transfer_method=transfer_method,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=size,
        related_id=None if is_remote else "test-file-id",
        remote_url="https://example.com/file.txt" if is_remote else None,
        storage_key="test-storage-key",
    )
|
||||
|
||||
|
||||
class TestSegmentDumpAndLoad:
    """Round-trip (dump → JSON → load) coverage for segments and variables."""

    def test_segments(self):
        """A list of basic segments survives a JSON round trip."""
        original = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
        assert _Segments.model_validate_json(original.model_dump_json()) == original

    def test_segment_number(self):
        """Integer and float segments survive a JSON round trip."""
        original = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
        assert _Segments.model_validate_json(original.model_dump_json()) == original

    def test_variables(self):
        """Named variables survive a JSON round trip."""
        original = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
        assert _Variables.model_validate_json(original.model_dump_json()) == original

    def test_all_segments_serialization(self):
        """Every segment type survives serialization/deserialization."""
        test_file = create_test_file()
        all_segments: list[SegmentUnion] = [
            NoneSegment(),
            StringSegment(value="test string"),
            IntegerSegment(value=42),
            FloatSegment(value=3.14),
            ObjectSegment(value={"key": "value", "number": 123}),
            FileSegment(value=test_file),
            ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
            ArrayStringSegment(value=["hello", "world"]),
            ArrayNumberSegment(value=[1, 2.5, 3]),
            ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
            ArrayFileSegment(value=[]),  # empty array sidesteps file equality details
        ]

        loaded = _Segments.model_validate_json(_Segments(segments=all_segments).model_dump_json())
        assert len(loaded.segments) == len(all_segments)

        for original, restored in zip(all_segments, loaded.segments):
            assert type(restored) == type(original)
            assert restored.value_type == original.value_type
            if isinstance(original, FileSegment) and isinstance(restored, FileSegment):
                # Files are compared on key properties rather than full equality.
                orig_file, restored_file = original.value, restored.value
                assert isinstance(orig_file, File)
                assert isinstance(restored_file, File)
                assert restored_file.tenant_id == orig_file.tenant_id
                assert restored_file.type == orig_file.type
                assert restored_file.filename == orig_file.filename
            else:
                assert restored.value == original.value

    def test_all_variables_serialization(self):
        """Every variable type survives serialization/deserialization."""
        test_file = create_test_file()
        all_variables: list[VariableUnion] = [
            NoneVariable(name="none_var"),
            StringVariable(value="test string", name="string_var"),
            IntegerVariable(value=42, name="int_var"),
            FloatVariable(value=3.14, name="float_var"),
            ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
            FileVariable(value=test_file, name="file_var"),
            ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
            ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
            ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
            ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
            ArrayFileVariable(value=[], name="array_file_var"),  # empty array sidesteps file equality details
        ]

        loaded = _Variables.model_validate_json(_Variables(variables=all_variables).model_dump_json())
        assert len(loaded.variables) == len(all_variables)

        for original, restored in zip(all_variables, loaded.variables):
            assert type(restored) == type(original)
            assert restored.value_type == original.value_type
            assert restored.name == original.name
            if isinstance(original, FileVariable) and isinstance(restored, FileVariable):
                # Files are compared on key properties rather than full equality.
                orig_file, restored_file = original.value, restored.value
                assert isinstance(orig_file, File)
                assert isinstance(restored_file, File)
                assert restored_file.tenant_id == orig_file.tenant_id
                assert restored_file.type == orig_file.type
                assert restored_file.filename == orig_file.filename
            else:
                assert restored.value == original.value

    def test_segment_discriminator_function_for_segment_types(self):
        """get_segment_discriminator resolves segment instances and their JSON dumps."""
        file1 = create_test_file()
        file2 = create_test_file(filename="test2.txt")

        expectations: list[tuple[Segment, SegmentType]] = [
            (NoneSegment(), SegmentType.NONE),
            (StringSegment(value=""), SegmentType.STRING),
            (FloatSegment(value=0.0), SegmentType.FLOAT),
            (IntegerSegment(value=0), SegmentType.INTEGER),
            (ObjectSegment(value={}), SegmentType.OBJECT),
            (FileSegment(value=file1), SegmentType.FILE),
            (ArrayAnySegment(value=[0, 0.0, ""]), SegmentType.ARRAY_ANY),
            (ArrayStringSegment(value=[""]), SegmentType.ARRAY_STRING),
            (ArrayNumberSegment(value=[0, 0.0]), SegmentType.ARRAY_NUMBER),
            (ArrayObjectSegment(value=[{}]), SegmentType.ARRAY_OBJECT),
            (ArrayFileSegment(value=[file1, file2]), SegmentType.ARRAY_FILE),
        ]

        for segment, expected in expectations:
            assert get_segment_discriminator(segment) == expected, (
                f"get_segment_discriminator failed for type {type(segment)}"
            )
            assert get_segment_discriminator(segment.model_dump(mode="json")) == expected, (
                f"get_segment_discriminator failed for serialized form of type {type(segment)}"
            )

    def test_variable_discriminator_function_for_variable_types(self):
        """get_segment_discriminator resolves variable instances and their JSON dumps."""
        file1 = create_test_file()
        file2 = create_test_file(filename="test2.txt")

        expectations: list[tuple[Variable, SegmentType]] = [
            (NoneVariable(name="none_var"), SegmentType.NONE),
            (StringVariable(value="test", name="string_var"), SegmentType.STRING),
            (FloatVariable(value=0.0, name="float_var"), SegmentType.FLOAT),
            (IntegerVariable(value=0, name="int_var"), SegmentType.INTEGER),
            (ObjectVariable(value={}, name="object_var"), SegmentType.OBJECT),
            (FileVariable(value=file1, name="file_var"), SegmentType.FILE),
            (SecretVariable(value="secret", name="secret_var"), SegmentType.SECRET),
            (ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"), SegmentType.ARRAY_ANY),
            (ArrayStringVariable(value=[""], name="array_string_var"), SegmentType.ARRAY_STRING),
            (ArrayNumberVariable(value=[0, 0.0], name="array_number_var"), SegmentType.ARRAY_NUMBER),
            (ArrayObjectVariable(value=[{}], name="array_object_var"), SegmentType.ARRAY_OBJECT),
            (ArrayFileVariable(value=[file1, file2], name="array_file_var"), SegmentType.ARRAY_FILE),
        ]

        for variable, expected in expectations:
            assert get_segment_discriminator(variable) == expected, (
                f"get_segment_discriminator failed for type {type(variable)}"
            )
            assert get_segment_discriminator(variable.model_dump(mode="json")) == expected, (
                f"get_segment_discriminator failed for serialized form of type {type(variable)}"
            )

    def test_invalid_value_for_discriminator(self):
        """Unrecognizable inputs yield None rather than raising."""
        for bogus in ({"value_type": "invalid"}, {}, "not_a_dict", 42, object):
            assert get_segment_discriminator(bogus) is None
|
||||
165
dify/api/tests/unit_tests/core/variables/test_segment_type.py
Normal file
165
dify/api/tests/unit_tests/core/variables/test_segment_type.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import pytest
|
||||
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
|
||||
|
||||
class TestSegmentTypeIsArrayType:
    """
    Coverage of SegmentType.is_array_type across every enum member.

    The two class-level sets partition the enum; the partition itself is
    asserted so new members cannot slip through untested.
    """

    ARRAY_TYPES = frozenset(
        {
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
            SegmentType.ARRAY_BOOLEAN,
        }
    )
    NON_ARRAY_TYPES = frozenset(
        {
            SegmentType.INTEGER,
            SegmentType.FLOAT,
            SegmentType.NUMBER,
            SegmentType.STRING,
            SegmentType.OBJECT,
            SegmentType.SECRET,
            SegmentType.FILE,
            SegmentType.NONE,
            SegmentType.GROUP,
            SegmentType.BOOLEAN,
        }
    )

    def test_is_array_type(self):
        """Array members report True, scalar members False, and both sets together cover the enum."""
        for seg_type in self.ARRAY_TYPES:
            assert seg_type.is_array_type()
        for seg_type in self.NON_ARRAY_TYPES:
            assert not seg_type.is_array_type()

        assert self.ARRAY_TYPES | self.NON_ARRAY_TYPES == set(SegmentType), (
            "All SegmentType values should be covered in tests"
        )

    def test_all_enum_values_are_supported(self):
        """is_array_type returns a bool (never raises) for every enum member."""
        for seg_type in SegmentType:
            assert isinstance(seg_type.is_array_type(), bool), (
                f"is_array_type does not return a boolean for segment type {seg_type}"
            )
|
||||
|
||||
|
||||
class TestSegmentTypeIsValidArrayValidation:
    """
    SegmentType.is_valid on array types under each ArrayValidation strategy.
    """

    def test_array_validation_all_success(self):
        # Every element is a string, so ALL-element validation passes.
        assert SegmentType.ARRAY_STRING.is_valid(["hello", "world", "foo"], array_validation=ArrayValidation.ALL)

    def test_array_validation_all_fail(self):
        # 123 is not a string, so ALL-element validation must reject the list.
        assert not SegmentType.ARRAY_STRING.is_valid(["hello", 123, "world"], array_validation=ArrayValidation.ALL)

    def test_array_validation_first(self):
        # Only the first element is inspected; trailing junk is ignored.
        assert SegmentType.ARRAY_STRING.is_valid(["hello", 123, None], array_validation=ArrayValidation.FIRST)

    def test_array_validation_none(self):
        # Element validation disabled: any list passes.
        assert SegmentType.ARRAY_STRING.is_valid([1, 2, 3], array_validation=ArrayValidation.NONE)
|
||||
|
||||
|
||||
class TestSegmentTypeGetZeroValue:
    """
    SegmentType.get_zero_value: the zero value produced for each supported
    type, and the ValueError raised for unsupported ones.
    """

    def test_array_types_return_empty_list(self):
        """Every (non-file) array type zeroes out to an empty list of its own type."""
        for seg_type in (
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_BOOLEAN,
        ):
            zero = SegmentType.get_zero_value(seg_type)
            assert zero.value == []
            assert zero.value_type == seg_type

    def test_object_returns_empty_dict(self):
        """OBJECT zeroes out to an empty dictionary."""
        zero = SegmentType.get_zero_value(SegmentType.OBJECT)
        assert zero.value == {}
        assert zero.value_type == SegmentType.OBJECT

    def test_string_returns_empty_string(self):
        """STRING zeroes out to an empty string."""
        zero = SegmentType.get_zero_value(SegmentType.STRING)
        assert zero.value == ""
        assert zero.value_type == SegmentType.STRING

    def test_integer_returns_zero(self):
        """INTEGER zeroes out to 0."""
        zero = SegmentType.get_zero_value(SegmentType.INTEGER)
        assert zero.value == 0
        assert zero.value_type == SegmentType.INTEGER

    def test_float_returns_zero_point_zero(self):
        """FLOAT zeroes out to 0.0."""
        zero = SegmentType.get_zero_value(SegmentType.FLOAT)
        assert zero.value == 0.0
        assert zero.value_type == SegmentType.FLOAT

    def test_number_returns_zero(self):
        """NUMBER zeroes out to 0, materializing as an INTEGER segment."""
        zero = SegmentType.get_zero_value(SegmentType.NUMBER)
        assert zero.value == 0
        # NUMBER is a union of INTEGER/FLOAT; a zero int materializes as
        # INTEGER, while exposed_type still reports NUMBER for frontend
        # compatibility.
        assert zero.value_type == SegmentType.INTEGER
        assert zero.value_type.exposed_type() == SegmentType.NUMBER

    def test_boolean_returns_false(self):
        """BOOLEAN zeroes out to False."""
        zero = SegmentType.get_zero_value(SegmentType.BOOLEAN)
        assert zero.value is False
        assert zero.value_type == SegmentType.BOOLEAN

    def test_unsupported_types_raise_value_error(self):
        """Types with no sensible zero value raise ValueError."""
        for seg_type in (
            SegmentType.SECRET,
            SegmentType.FILE,
            SegmentType.NONE,
            SegmentType.GROUP,
            SegmentType.ARRAY_FILE,
        ):
            with pytest.raises(ValueError, match="unsupported variable type"):
                SegmentType.get_zero_value(seg_type)
|
||||
@@ -0,0 +1,849 @@
|
||||
"""
|
||||
Comprehensive unit tests for SegmentType.is_valid and SegmentType._validate_array methods.
|
||||
|
||||
This module provides thorough testing of the validation logic for all SegmentType values,
|
||||
including edge cases, error conditions, and different ArrayValidation strategies.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
|
||||
|
||||
def create_test_file(
    file_type: FileType = FileType.DOCUMENT,
    transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
    filename: str = "test.txt",
    extension: str = ".txt",
    mime_type: str = "text/plain",
    size: int = 1024,
) -> File:
    """Factory function to create File objects for testing."""
    # Remote files carry a URL instead of a local upload id; the two are
    # mutually exclusive.
    is_remote = transfer_method == FileTransferMethod.REMOTE_URL
    return File(
        tenant_id="test-tenant",
        type=file_type,
        transfer_method=transfer_method,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=size,
        related_id=None if is_remote else "test-file-id",
        remote_url="https://example.com/file.txt" if is_remote else None,
        storage_key="test-storage-key",
    )
|
||||
|
||||
|
||||
@dataclass
class ValidationTestCase:
    """Test case data structure for validation tests."""

    # Segment type whose is_valid() is exercised.
    segment_type: SegmentType
    # Candidate value passed to is_valid().
    value: Any
    # Expected boolean outcome of the validation.
    expected: bool
    # Human-readable label, used as the pytest parameter id via get_id().
    description: str

    def get_id(self):
        """Return the case's description for use as a test id."""
        return self.description
|
||||
|
||||
|
||||
@dataclass
class ArrayValidationTestCase:
    """Test case data structure for array validation tests."""

    # Array segment type whose is_valid() is exercised.
    segment_type: SegmentType
    # Candidate value passed to is_valid().
    value: Any
    # Element-validation strategy forwarded to is_valid().
    array_validation: ArrayValidation
    # Expected boolean outcome of the validation.
    expected: bool
    # Human-readable label, used as the pytest parameter id via get_id().
    description: str

    def get_id(self):
        """Return the case's description for use as a test id."""
        return self.description
|
||||
|
||||
|
||||
# Test data construction functions
|
||||
def get_boolean_cases() -> list[ValidationTestCase]:
    """Cases for SegmentType.BOOLEAN: only real bools validate."""
    rows: list[tuple[Any, bool, str]] = [
        # valid values
        (True, True, "True boolean"),
        (False, True, "False boolean"),
        # invalid values — truthy/falsy ints and strings are still not booleans
        (1, False, "Integer 1 (not boolean)"),
        (0, False, "Integer 0 (not boolean)"),
        ("true", False, "String 'true'"),
        ("false", False, "String 'false'"),
        (None, False, "None value"),
        ([], False, "Empty list"),
        ({}, False, "Empty dict"),
    ]
    return [ValidationTestCase(SegmentType.BOOLEAN, value, expected, desc) for value, expected, desc in rows]
|
||||
|
||||
|
||||
def get_number_cases() -> list[ValidationTestCase]:
    """Cases for SegmentType.NUMBER: ints and floats (incl. inf/nan) validate; strings don't."""
    rows: list[tuple[Any, bool, str]] = [
        # valid values
        (42, True, "Positive integer"),
        (-42, True, "Negative integer"),
        (0, True, "Zero integer"),
        (3.14, True, "Positive float"),
        (-3.14, True, "Negative float"),
        (0.0, True, "Zero float"),
        (float("inf"), True, "Positive infinity"),
        (float("-inf"), True, "Negative infinity"),
        (float("nan"), True, "float(NaN)"),
        # invalid number values
        ("42", False, "String number"),
        (None, False, "None value"),
        ([], False, "Empty list"),
        ({}, False, "Empty dict"),
        ("3.14", False, "String float"),
    ]
    return [ValidationTestCase(SegmentType.NUMBER, value, expected, desc) for value, expected, desc in rows]
|
||||
|
||||
|
||||
def get_string_cases() -> list[ValidationTestCase]:
    """Cases for SegmentType.STRING: only str instances validate."""
    rows: list[tuple[Any, bool, str]] = [
        # valid values
        ("", True, "Empty string"),
        ("hello", True, "Simple string"),
        ("🚀", True, "Unicode emoji"),
        ("line1\nline2", True, "Multiline string"),
        # invalid values
        (123, False, "Integer"),
        (3.14, False, "Float"),
        (True, False, "Boolean"),
        (None, False, "None value"),
        ([], False, "Empty list"),
        ({}, False, "Empty dict"),
    ]
    return [ValidationTestCase(SegmentType.STRING, value, expected, desc) for value, expected, desc in rows]
|
||||
|
||||
|
||||
def get_object_cases() -> list[ValidationTestCase]:
    """Cases for SegmentType.OBJECT: only dicts (any nesting) validate."""
    rows: list[tuple[Any, bool, str]] = [
        # valid cases
        ({}, True, "Empty dict"),
        ({"key": "value"}, True, "Simple dict"),
        ({"a": 1, "b": 2}, True, "Dict with numbers"),
        ({"nested": {"key": "value"}}, True, "Nested dict"),
        ({"list": [1, 2, 3]}, True, "Dict with list"),
        ({"mixed": [1, "two", {"three": 3}]}, True, "Complex dict"),
        # invalid cases
        ("not a dict", False, "String"),
        (123, False, "Integer"),
        (3.14, False, "Float"),
        (True, False, "Boolean"),
        (None, False, "None value"),
        ([], False, "Empty list"),
        ([1, 2, 3], False, "List with values"),
    ]
    return [ValidationTestCase(SegmentType.OBJECT, value, expected, desc) for value, expected, desc in rows]
|
||||
|
||||
|
||||
def get_secret_cases() -> list[ValidationTestCase]:
    """Return validation cases for SegmentType.SECRET: any string (even empty) is valid."""
    accepted = [
        ValidationTestCase(SegmentType.SECRET, "", True, "Empty secret"),
        ValidationTestCase(SegmentType.SECRET, "secret", True, "Simple secret"),
        ValidationTestCase(SegmentType.SECRET, "api_key_123", True, "API key format"),
        ValidationTestCase(SegmentType.SECRET, "very_long_secret_key_with_special_chars!@#", True, "Complex secret"),
    ]
    rejected = [
        ValidationTestCase(SegmentType.SECRET, 123, False, "Integer"),
        ValidationTestCase(SegmentType.SECRET, 3.14, False, "Float"),
        ValidationTestCase(SegmentType.SECRET, True, False, "Boolean"),
        ValidationTestCase(SegmentType.SECRET, None, False, "None value"),
        ValidationTestCase(SegmentType.SECRET, [], False, "Empty list"),
        ValidationTestCase(SegmentType.SECRET, {}, False, "Empty dict"),
    ]
    return accepted + rejected
|
||||
|
||||
|
||||
def get_file_cases() -> list[ValidationTestCase]:
    """Return validation cases for SegmentType.FILE: only File instances are valid."""
    doc_file = create_test_file()
    img_file = create_test_file(
        file_type=FileType.IMAGE, filename="image.jpg", extension=".jpg", mime_type="image/jpeg"
    )
    url_file = create_test_file(
        transfer_method=FileTransferMethod.REMOTE_URL, filename="remote.pdf", extension=".pdf"
    )

    return [
        # accepted: real File objects of various kinds
        ValidationTestCase(SegmentType.FILE, doc_file, True, "Document file"),
        ValidationTestCase(SegmentType.FILE, img_file, True, "Image file"),
        ValidationTestCase(SegmentType.FILE, url_file, True, "Remote file"),
        # rejected: anything that merely resembles a file, or is another type entirely
        ValidationTestCase(SegmentType.FILE, "not a file", False, "String"),
        ValidationTestCase(SegmentType.FILE, 123, False, "Integer"),
        ValidationTestCase(SegmentType.FILE, {"filename": "test.txt"}, False, "Dict resembling file"),
        ValidationTestCase(SegmentType.FILE, None, False, "None value"),
        ValidationTestCase(SegmentType.FILE, [], False, "Empty list"),
        ValidationTestCase(SegmentType.FILE, True, False, "Boolean"),
    ]
|
||||
|
||||
|
||||
def get_none_cases() -> list[ValidationTestCase]:
    """Return validation cases for SegmentType.NONE: only the literal None is valid."""
    return [
        # accepted
        ValidationTestCase(SegmentType.NONE, None, True, "None value"),
        # rejected: every falsy stand-in and the string "null" must NOT pass
        ValidationTestCase(SegmentType.NONE, "", False, "Empty string"),
        ValidationTestCase(SegmentType.NONE, 0, False, "Zero integer"),
        ValidationTestCase(SegmentType.NONE, 0.0, False, "Zero float"),
        ValidationTestCase(SegmentType.NONE, False, False, "False boolean"),
        ValidationTestCase(SegmentType.NONE, [], False, "Empty list"),
        ValidationTestCase(SegmentType.NONE, {}, False, "Empty dict"),
        ValidationTestCase(SegmentType.NONE, "null", False, "String 'null'"),
    ]
|
||||
|
||||
|
||||
def get_group_cases() -> list[ValidationTestCase]:
    """Return validation cases for SegmentType.GROUP.

    Valid values are SegmentGroup instances or plain lists whose elements are
    all Segment objects; anything else — including lists with non-Segment
    members — is rejected.
    """
    sample_file = create_test_file()
    mixed_segments = [
        StringSegment(value="hello"),
        IntegerSegment(value=42),
        BooleanSegment(value=True),
        ObjectSegment(value={"key": "value"}),
        FileSegment(value=sample_file),
        NoneSegment(value=None),
    ]

    accepted = [
        ValidationTestCase(
            SegmentType.GROUP, SegmentGroup(value=mixed_segments), True, "Valid SegmentGroup with mixed segments"
        ),
        ValidationTestCase(
            SegmentType.GROUP, [StringSegment(value="test"), IntegerSegment(value=123)], True, "List of Segment objects"
        ),
        ValidationTestCase(SegmentType.GROUP, SegmentGroup(value=[]), True, "Empty SegmentGroup"),
        ValidationTestCase(SegmentType.GROUP, [], True, "Empty list"),
    ]
    rejected = [
        ValidationTestCase(SegmentType.GROUP, "not a list", False, "String value"),
        ValidationTestCase(SegmentType.GROUP, 123, False, "Integer value"),
        ValidationTestCase(SegmentType.GROUP, True, False, "Boolean value"),
        ValidationTestCase(SegmentType.GROUP, None, False, "None value"),
        ValidationTestCase(SegmentType.GROUP, {"key": "value"}, False, "Dict value"),
        ValidationTestCase(SegmentType.GROUP, sample_file, False, "File value"),
        ValidationTestCase(SegmentType.GROUP, ["string", 123, True], False, "List with non-Segment objects"),
        ValidationTestCase(
            SegmentType.GROUP,
            [StringSegment(value="test"), "not a segment"],
            False,
            "Mixed list with some non-Segment objects",
        ),
    ]
    return accepted + rejected
|
||||
|
||||
|
||||
def get_array_any_validation_cases() -> list[ArrayValidationTestCase]:
    """Return ARRAY_ANY cases: any list is valid regardless of element types or strategy."""
    arr = SegmentType.ARRAY_ANY
    mixed = [1, "string", 3.14, {"key": "value"}, True]
    return [
        ArrayValidationTestCase(arr, list(mixed), ArrayValidation.NONE, True, "Mixed types with NONE validation"),
        ArrayValidationTestCase(arr, list(mixed), ArrayValidation.FIRST, True, "Mixed types with FIRST validation"),
        ArrayValidationTestCase(arr, list(mixed), ArrayValidation.ALL, True, "Mixed types with ALL validation"),
        ArrayValidationTestCase(arr, [None, None, None], ArrayValidation.ALL, True, "All None values"),
    ]
|
||||
|
||||
|
||||
def get_array_string_validation_none_cases() -> list[ArrayValidationTestCase]:
    """Return ARRAY_STRING cases under the NONE strategy: elements are never inspected."""
    arr = SegmentType.ARRAY_STRING
    none = ArrayValidation.NONE
    return [
        ArrayValidationTestCase(arr, ["hello", "world"], none, True, "Valid strings with NONE validation"),
        ArrayValidationTestCase(arr, [123, 456], none, True, "Invalid elements with NONE validation"),
        ArrayValidationTestCase(arr, ["valid", 123, True], none, True, "Mixed types with NONE validation"),
    ]
|
||||
|
||||
|
||||
def get_array_string_validation_first_cases() -> list[ArrayValidationTestCase]:
    """Return ARRAY_STRING cases under the FIRST strategy: only element 0 decides validity."""
    arr = SegmentType.ARRAY_STRING
    first = ArrayValidation.FIRST
    return [
        ArrayValidationTestCase(arr, ["hello", "world"], first, True, "All valid strings"),
        ArrayValidationTestCase(arr, ["hello", 123, True], first, True, "First valid, others invalid"),
        ArrayValidationTestCase(arr, [123, "hello", "world"], first, False, "First invalid, others valid"),
        ArrayValidationTestCase(arr, [None, "hello"], first, False, "First None"),
    ]
|
||||
|
||||
|
||||
def get_array_string_validation_all_cases() -> list[ArrayValidationTestCase]:
    """Return ARRAY_STRING cases under the ALL strategy: every element must be a string."""
    arr = SegmentType.ARRAY_STRING
    every = ArrayValidation.ALL
    return [
        ArrayValidationTestCase(arr, ["hello", "world", "test"], every, True, "All valid strings"),
        ArrayValidationTestCase(arr, ["hello", 123, "world"], every, False, "One invalid element"),
        ArrayValidationTestCase(arr, [123, 456, 789], every, False, "All invalid elements"),
        ArrayValidationTestCase(arr, ["valid", None, "also_valid"], every, False, "Contains None"),
    ]
|
||||
|
||||
|
||||
def get_array_number_validation_cases() -> list[ArrayValidationTestCase]:
    """Return ARRAY_NUMBER cases across NONE, FIRST and ALL strategies."""
    arr = SegmentType.ARRAY_NUMBER
    return [
        # NONE: element types are ignored entirely
        ArrayValidationTestCase(arr, [1, 2.5, 3], ArrayValidation.NONE, True, "Valid numbers with NONE"),
        ArrayValidationTestCase(arr, ["not", "numbers"], ArrayValidation.NONE, True, "Invalid elements with NONE"),
        # FIRST: only element 0 is checked
        ArrayValidationTestCase(arr, [42, "not a number"], ArrayValidation.FIRST, True, "First valid number"),
        ArrayValidationTestCase(arr, ["not a number", 42], ArrayValidation.FIRST, False, "First invalid"),
        ArrayValidationTestCase(arr, [3.14, 2.71, 1.41], ArrayValidation.FIRST, True, "All valid floats"),
        # ALL: every element is checked
        ArrayValidationTestCase(arr, [1, 2, 3, 4.5], ArrayValidation.ALL, True, "All valid numbers"),
        ArrayValidationTestCase(arr, [1, "invalid", 3], ArrayValidation.ALL, False, "One invalid element"),
        # inf/-inf/nan are still floats and therefore valid numbers
        ArrayValidationTestCase(
            arr,
            [float("inf"), float("-inf"), float("nan")],
            ArrayValidation.ALL,
            True,
            "Special float values",
        ),
    ]
|
||||
|
||||
|
||||
def get_array_object_validation_cases() -> list[ArrayValidationTestCase]:
    """Return ARRAY_OBJECT cases across NONE, FIRST and ALL strategies."""
    arr = SegmentType.ARRAY_OBJECT
    return [
        # NONE: element types are ignored entirely
        ArrayValidationTestCase(arr, [{}, {"key": "value"}], ArrayValidation.NONE, True, "Valid objects with NONE"),
        ArrayValidationTestCase(arr, ["not", "objects"], ArrayValidation.NONE, True, "Invalid elements with NONE"),
        # FIRST: only element 0 is checked
        ArrayValidationTestCase(
            arr,
            [{"valid": "object"}, "not an object"],
            ArrayValidation.FIRST,
            True,
            "First valid object",
        ),
        ArrayValidationTestCase(
            arr,
            ["not an object", {"valid": "object"}],
            ArrayValidation.FIRST,
            False,
            "First invalid",
        ),
        # ALL: every element is checked
        ArrayValidationTestCase(
            arr,
            [{}, {"a": 1}, {"nested": {"key": "value"}}],
            ArrayValidation.ALL,
            True,
            "All valid objects",
        ),
        ArrayValidationTestCase(
            arr,
            [{"valid": "object"}, "invalid", {"another": "object"}],
            ArrayValidation.ALL,
            False,
            "One invalid element",
        ),
    ]
|
||||
|
||||
|
||||
def get_array_file_validation_cases() -> list[ArrayValidationTestCase]:
    """Return ARRAY_FILE cases across NONE, FIRST and ALL strategies."""
    arr = SegmentType.ARRAY_FILE
    first_file = create_test_file(filename="file1.txt")
    second_file = create_test_file(filename="file2.txt")

    return [
        # NONE: element types are ignored entirely
        ArrayValidationTestCase(arr, [first_file, second_file], ArrayValidation.NONE, True, "Valid files with NONE"),
        ArrayValidationTestCase(arr, ["not", "files"], ArrayValidation.NONE, True, "Invalid elements with NONE"),
        # FIRST: only element 0 is checked
        ArrayValidationTestCase(arr, [first_file, "not a file"], ArrayValidation.FIRST, True, "First valid file"),
        ArrayValidationTestCase(arr, ["not a file", first_file], ArrayValidation.FIRST, False, "First invalid"),
        # ALL: every element is checked
        ArrayValidationTestCase(arr, [first_file, second_file], ArrayValidation.ALL, True, "All valid files"),
        ArrayValidationTestCase(
            arr, [first_file, "invalid", second_file], ArrayValidation.ALL, False, "One invalid element"
        ),
    ]
|
||||
|
||||
|
||||
def get_array_boolean_validation_cases() -> list[ArrayValidationTestCase]:
    """Return ARRAY_BOOLEAN cases across NONE, FIRST and ALL strategies.

    Integers 0/1 must NOT count as booleans even though bool subclasses int.
    """
    arr = SegmentType.ARRAY_BOOLEAN
    return [
        # NONE: element types are ignored entirely
        ArrayValidationTestCase(arr, [True, False, True], ArrayValidation.NONE, True, "Valid booleans with NONE"),
        ArrayValidationTestCase(arr, [1, 0, "true"], ArrayValidation.NONE, True, "Invalid elements with NONE"),
        # FIRST: only element 0 is checked
        ArrayValidationTestCase(arr, [True, 1, 0], ArrayValidation.FIRST, True, "First valid boolean"),
        ArrayValidationTestCase(arr, [1, True, False], ArrayValidation.FIRST, False, "First invalid (integer 1)"),
        ArrayValidationTestCase(arr, [0, True, False], ArrayValidation.FIRST, False, "First invalid (integer 0)"),
        # ALL: every element is checked
        ArrayValidationTestCase(arr, [True, False, True, False], ArrayValidation.ALL, True, "All valid booleans"),
        ArrayValidationTestCase(arr, [True, 1, False], ArrayValidation.ALL, False, "One invalid element (integer)"),
        ArrayValidationTestCase(
            arr,
            [True, "false", False],
            ArrayValidation.ALL,
            False,
            "One invalid element (string)",
        ),
    ]
|
||||
|
||||
|
||||
class TestSegmentTypeIsValid:
    """Test suite for SegmentType.is_valid method covering all non-array types."""

    @pytest.mark.parametrize("case", get_boolean_cases(), ids=lambda c: c.description)
    def test_boolean_validation(self, case):
        assert case.segment_type.is_valid(case.value) == case.expected

    @pytest.mark.parametrize("case", get_number_cases(), ids=lambda c: c.description)
    def test_number_validation(self, case: ValidationTestCase):
        assert case.segment_type.is_valid(case.value) == case.expected

    @pytest.mark.parametrize("case", get_string_cases(), ids=lambda c: c.description)
    def test_string_validation(self, case):
        assert case.segment_type.is_valid(case.value) == case.expected

    @pytest.mark.parametrize("case", get_object_cases(), ids=lambda c: c.description)
    def test_object_validation(self, case):
        assert case.segment_type.is_valid(case.value) == case.expected

    @pytest.mark.parametrize("case", get_secret_cases(), ids=lambda c: c.description)
    def test_secret_validation(self, case):
        assert case.segment_type.is_valid(case.value) == case.expected

    @pytest.mark.parametrize("case", get_file_cases(), ids=lambda c: c.description)
    def test_file_validation(self, case):
        assert case.segment_type.is_valid(case.value) == case.expected

    @pytest.mark.parametrize("case", get_none_cases(), ids=lambda c: c.description)
    def test_none_validation_valid_cases(self, case):
        assert case.segment_type.is_valid(case.value) == case.expected

    @pytest.mark.parametrize("case", get_group_cases(), ids=lambda c: c.description)
    def test_group_validation(self, case):
        """Test GROUP type validation with various inputs."""
        assert case.segment_type.is_valid(case.value) == case.expected

    def test_group_validation_edge_cases(self):
        """Test GROUP validation edge cases."""
        sample_file = create_test_file()

        # A SegmentGroup nested inside another SegmentGroup is still a Segment.
        inner = SegmentGroup(value=[StringSegment(value="inner"), IntegerSegment(value=42)])
        outer = SegmentGroup(value=[StringSegment(value="outer"), inner])
        assert SegmentType.GROUP.is_valid(outer) is True

        # ArrayFileSegment is itself a Segment, so it may appear inside a group.
        group_with_arrays = SegmentGroup(
            value=[
                FileSegment(value=sample_file),
                ArrayFileSegment(value=[sample_file, sample_file]),
                StringSegment(value="test"),
            ]
        )
        assert SegmentType.GROUP.is_valid(group_with_arrays) is True

        # A large group exercises per-element validation cost.
        bulk = SegmentGroup(value=[StringSegment(value=f"item_{i}") for i in range(1000)])
        assert SegmentType.GROUP.is_valid(bulk) is True

    def test_no_truly_unsupported_segment_types_exist(self):
        """Test that all SegmentType enum values are properly handled in is_valid method.

        This test ensures there are no SegmentType values that would raise AssertionError.
        If this test fails, it means a new SegmentType was added without proper validation support.
        """
        # One representative valid value per non-array type; array types use [].
        sample_values: dict[SegmentType, Any] = {
            SegmentType.STRING: "test",
            SegmentType.NUMBER: 42,
            SegmentType.INTEGER: 42,
            SegmentType.FLOAT: 3.14,
            SegmentType.BOOLEAN: True,
            SegmentType.OBJECT: {"key": "value"},
            SegmentType.SECRET: "secret",
            SegmentType.FILE: create_test_file(),
            SegmentType.NONE: None,
            SegmentType.GROUP: SegmentGroup(value=[StringSegment(value="test")]),
        }

        for segment_type in set(SegmentType):
            if segment_type in sample_values:
                test_value: Any = sample_values[segment_type]
            elif segment_type.is_array_type():
                test_value = []  # Empty array is valid for all array types
            else:
                # A segment type with no known sample value means validation
                # logic (and a test case) still needs to be written for it.
                pytest.fail(f"Unknown segment type {segment_type} needs validation logic and test case")

            # is_valid must return a bool and must NOT raise AssertionError.
            try:
                result = segment_type.is_valid(test_value)
                assert isinstance(result, bool), f"is_valid should return boolean for {segment_type}"
            except AssertionError as e:
                pytest.fail(
                    f"SegmentType.{segment_type.name}.is_valid() raised AssertionError: {e}. "
                    "This segment type needs to be handled in the is_valid method."
                )
|
||||
|
||||
|
||||
class TestSegmentTypeArrayValidation:
    """Test suite for SegmentType._validate_array method and array type validation."""

    def test_array_validation_non_list_values(self):
        """Test that non-list values return False for all array types."""
        all_array_types = (
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
            SegmentType.ARRAY_BOOLEAN,
        )
        rejected_values = (
            "not a list",
            123,
            3.14,
            True,
            None,
            {"key": "value"},
            create_test_file(),
        )

        for array_type in all_array_types:
            for candidate in rejected_values:
                assert array_type.is_valid(candidate) is False, (
                    f"{array_type} should reject {type(candidate).__name__}"
                )

    def test_empty_array_validation(self):
        """Test that empty arrays are valid for all array types regardless of validation strategy."""
        all_array_types = (
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
            SegmentType.ARRAY_BOOLEAN,
        )
        strategies = (ArrayValidation.NONE, ArrayValidation.FIRST, ArrayValidation.ALL)

        for array_type in all_array_types:
            for strategy in strategies:
                assert array_type.is_valid([], strategy) is True, (
                    f"{array_type} should accept empty array with {strategy}"
                )

    @pytest.mark.parametrize("case", get_array_any_validation_cases(), ids=lambda c: c.description)
    def test_array_any_validation(self, case):
        """Test ARRAY_ANY validation accepts any list regardless of content."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected

    @pytest.mark.parametrize("case", get_array_string_validation_none_cases(), ids=lambda c: c.description)
    def test_array_string_validation_with_none_strategy(self, case):
        """Test ARRAY_STRING validation with NONE strategy (no element validation)."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected

    @pytest.mark.parametrize("case", get_array_string_validation_first_cases(), ids=lambda c: c.description)
    def test_array_string_validation_with_first_strategy(self, case):
        """Test ARRAY_STRING validation with FIRST strategy (validate first element only)."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected

    @pytest.mark.parametrize("case", get_array_string_validation_all_cases(), ids=lambda c: c.description)
    def test_array_string_validation_with_all_strategy(self, case):
        """Test ARRAY_STRING validation with ALL strategy (validate all elements)."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected

    @pytest.mark.parametrize("case", get_array_number_validation_cases(), ids=lambda c: c.description)
    def test_array_number_validation_with_different_strategies(self, case):
        """Test ARRAY_NUMBER validation with different validation strategies."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected

    @pytest.mark.parametrize("case", get_array_object_validation_cases(), ids=lambda c: c.description)
    def test_array_object_validation_with_different_strategies(self, case):
        """Test ARRAY_OBJECT validation with different validation strategies."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected

    @pytest.mark.parametrize("case", get_array_file_validation_cases(), ids=lambda c: c.description)
    def test_array_file_validation_with_different_strategies(self, case):
        """Test ARRAY_FILE validation with different validation strategies."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected

    @pytest.mark.parametrize("case", get_array_boolean_validation_cases(), ids=lambda c: c.description)
    def test_array_boolean_validation_with_different_strategies(self, case):
        """Test ARRAY_BOOLEAN validation with different validation strategies."""
        assert case.segment_type.is_valid(case.value, case.array_validation) == case.expected

    def test_default_array_validation_strategy(self):
        """Exercise is_valid with no explicit array_validation argument.

        NOTE(review): the original docstring claimed the default strategy is
        FIRST, but the expected values below only hold if elements beyond the
        first are also validated — confirm the actual default in SegmentType.
        """
        assert SegmentType.ARRAY_STRING.is_valid(["valid", 123]) is False  # First element valid
        assert SegmentType.ARRAY_STRING.is_valid([123, "valid"]) is False  # First element invalid

        assert SegmentType.ARRAY_NUMBER.is_valid([42, "invalid"]) is False  # First element valid
        assert SegmentType.ARRAY_NUMBER.is_valid(["invalid", 42]) is False  # First element invalid

    def test_array_validation_edge_cases(self):
        """Test edge cases for array validation."""
        # Nested lists are not strings, so typed arrays reject them; ARRAY_ANY does not care.
        nested = [["nested", "array"], ["another", "nested"]]
        assert SegmentType.ARRAY_STRING.is_valid(nested, ArrayValidation.FIRST) is False
        assert SegmentType.ARRAY_STRING.is_valid(nested, ArrayValidation.ALL) is False
        assert SegmentType.ARRAY_ANY.is_valid(nested, ArrayValidation.ALL) is True

        # Large arrays: ALL must scan to the end, FIRST must stop at element 0.
        all_strings = ["string"] * 1000
        tail_invalid = ["string"] * 999 + [123]  # Last element invalid
        assert SegmentType.ARRAY_STRING.is_valid(all_strings, ArrayValidation.ALL) is True
        assert SegmentType.ARRAY_STRING.is_valid(tail_invalid, ArrayValidation.ALL) is False
        assert SegmentType.ARRAY_STRING.is_valid(tail_invalid, ArrayValidation.FIRST) is True
|
||||
|
||||
|
||||
class TestSegmentTypeValidationIntegration:
    """Integration tests for SegmentType validation covering interactions between methods."""

    def test_non_array_types_ignore_array_validation_parameter(self):
        """Test that non-array types ignore the array_validation parameter.

        For every scalar type, is_valid must return the same (True) result
        no matter which ArrayValidation strategy is passed.
        """
        # One representative valid value per non-array type.
        valid_values: dict[SegmentType, Any] = {
            SegmentType.STRING: "test",
            SegmentType.NUMBER: 42,
            SegmentType.BOOLEAN: True,
            SegmentType.OBJECT: {"key": "value"},
            SegmentType.SECRET: "secret",
            SegmentType.FILE: create_test_file(),
            SegmentType.NONE: None,
            SegmentType.GROUP: SegmentGroup(value=[StringSegment(value="test")]),
        }

        for segment_type, valid_value in valid_values.items():
            # All array validation strategies should give the same result.
            result_none = segment_type.is_valid(valid_value, ArrayValidation.NONE)
            result_first = segment_type.is_valid(valid_value, ArrayValidation.FIRST)
            result_all = segment_type.is_valid(valid_value, ArrayValidation.ALL)

            # Identity check with `is True` (PEP 8 / E712) instead of a chained
            # `== True` comparison; the semantics are identical for bool results.
            assert result_none is True and result_first is True and result_all is True, (
                f"{segment_type} should ignore array_validation parameter"
            )

    def test_comprehensive_type_coverage(self):
        """Test that all SegmentType enum values are covered in validation tests."""
        all_segment_types = set(SegmentType)

        # Types that should be handled by is_valid method.
        handled_types = {
            # Non-array types
            SegmentType.STRING,
            SegmentType.NUMBER,
            SegmentType.BOOLEAN,
            SegmentType.OBJECT,
            SegmentType.SECRET,
            SegmentType.FILE,
            SegmentType.NONE,
            SegmentType.GROUP,
            # Array types
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
            SegmentType.ARRAY_BOOLEAN,
        }

        # Numeric alias types that are not exercised directly in this test;
        # they are covered by the NUMBER validation logic instead.
        unhandled_types = {
            SegmentType.INTEGER,  # Handled by NUMBER validation logic
            SegmentType.FLOAT,  # Handled by NUMBER validation logic
        }

        # Verify all types are accounted for — a new enum member must be
        # added to one of the two sets above (and given validation + tests).
        assert handled_types | unhandled_types == all_segment_types, "All SegmentType values should be categorized"

        # Sample valid value per scalar handled type.
        scalar_samples: dict[SegmentType, Any] = {
            SegmentType.STRING: "test",
            SegmentType.NUMBER: 42,
            SegmentType.BOOLEAN: True,
            SegmentType.OBJECT: {},
            SegmentType.SECRET: "secret",
            SegmentType.NONE: None,
        }

        # Test that handled types work correctly.
        for segment_type in handled_types:
            if segment_type.is_array_type():
                # Empty array should always be valid for array types.
                assert segment_type.is_valid([]) is True, f"{segment_type} should accept empty array"
            elif segment_type == SegmentType.FILE:
                assert segment_type.is_valid(create_test_file()) is True
            elif segment_type == SegmentType.GROUP:
                assert segment_type.is_valid(SegmentGroup(value=[StringSegment(value="test")])) is True
            else:
                assert segment_type.is_valid(scalar_samples[segment_type]) is True

    def test_boolean_vs_integer_type_distinction(self):
        """Test the important distinction between boolean and integer types in validation.

        bool is a subclass of int in Python, so BOOLEAN must explicitly
        exclude plain ints while NUMBER deliberately accepts bools.
        """
        # Boolean type should only accept actual booleans, not integers.
        assert SegmentType.BOOLEAN.is_valid(True) is True
        assert SegmentType.BOOLEAN.is_valid(False) is True
        assert SegmentType.BOOLEAN.is_valid(1) is False  # Integer 1, not boolean
        assert SegmentType.BOOLEAN.is_valid(0) is False  # Integer 0, not boolean

        # Number type accepts ints, floats and (since bool subclasses int) booleans.
        assert SegmentType.NUMBER.is_valid(42) is True
        assert SegmentType.NUMBER.is_valid(3.14) is True
        assert SegmentType.NUMBER.is_valid(True) is True  # bool is subclass of int
        assert SegmentType.NUMBER.is_valid(False) is True  # bool is subclass of int

    def test_array_validation_recursive_behavior(self):
        """Test that array validation correctly handles recursive validation calls.

        When validating array elements, _validate_array calls is_valid
        recursively with ArrayValidation.NONE to avoid infinite recursion.
        """
        nested_arrays = [["inner", "array"], ["another", "inner"]]

        # ARRAY_ANY should accept nested arrays.
        assert SegmentType.ARRAY_ANY.is_valid(nested_arrays, ArrayValidation.ALL) is True

        # ARRAY_STRING should reject nested arrays (elements are lists, not strings).
        assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.FIRST) is False
        assert SegmentType.ARRAY_STRING.is_valid(nested_arrays, ArrayValidation.ALL) is False
||||
91
dify/api/tests/unit_tests/core/variables/test_variables.py
Normal file
91
dify/api/tests/unit_tests/core/variables/test_variables.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.variables import (
|
||||
ArrayFileVariable,
|
||||
ArrayVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectVariable,
|
||||
SecretVariable,
|
||||
SegmentType,
|
||||
StringVariable,
|
||||
)
|
||||
from core.variables.variables import Variable
|
||||
|
||||
|
||||
def test_frozen_variables():
    """Every concrete Variable is frozen: reassigning `.value` raises ValidationError."""
    attempts = [
        (StringVariable(name="text", value="text"), "new value"),
        (IntegerVariable(name="integer", value=42), 100),
        (FloatVariable(name="float", value=3.14), 2.718),
        (SecretVariable(name="secret", value="secret_value"), "new_secret_value"),
    ]
    for variable, replacement in attempts:
        with pytest.raises(ValidationError):
            variable.value = replacement
|
||||
|
||||
|
||||
def test_variable_value_type_immutable():
    """``value_type`` is fixed per variable class and cannot be overridden."""
    # Passing a mismatched value_type at construction is rejected.
    with pytest.raises(ValidationError):
        StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text")

    # The same holds when going through pydantic's model_validate.
    with pytest.raises(ValidationError):
        StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"})

    # Re-creating from an existing instance's fields with a wrong value_type
    # fails for every concrete variable class.
    var = IntegerVariable(name="integer", value=42)
    with pytest.raises(ValidationError):
        IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)

    var = FloatVariable(name="float", value=3.14)
    with pytest.raises(ValidationError):
        FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)

    var = SecretVariable(name="secret", value="secret_value")
    with pytest.raises(ValidationError):
        SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
|
||||
|
||||
|
||||
def test_object_variable_to_object():
    """ObjectVariable.to_object() returns the nested mapping unchanged."""
    var = ObjectVariable(
        name="object",
        value={
            "key1": {
                "key2": "value2",
            },
            "key2": ["value5_1", 42, {}],
        },
    )

    # Compare against an independently-built literal so a shared/aliased
    # return value could not mask a mutation.
    expected = {
        "key1": {"key2": "value2"},
        "key2": ["value5_1", 42, {}],
    }
    assert var.to_object() == expected
|
||||
|
||||
|
||||
def test_variable_to_object():
    """Scalar variables unwrap to their raw Python value via to_object()."""
    scalar_cases: list[tuple[Variable, object]] = [
        (StringVariable(name="text", value="text"), "text"),
        (IntegerVariable(name="integer", value=42), 42),
        (FloatVariable(name="float", value=3.14), 3.14),
        (SecretVariable(name="secret", value="secret_value"), "secret_value"),
    ]
    for variable, raw in scalar_cases:
        assert variable.to_object() == raw
|
||||
|
||||
|
||||
def test_array_file_variable_is_array_variable():
    """ArrayFileVariable participates in the ArrayVariable hierarchy."""
    assert isinstance(ArrayFileVariable(name="files", value=[]), ArrayVariable)
|
||||
0
dify/api/tests/unit_tests/core/workflow/__init__.py
Normal file
0
dify/api/tests/unit_tests/core/workflow/__init__.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import json
|
||||
from time import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
|
||||
|
||||
|
||||
class StubCoordinator:
    """Minimal response-coordinator stand-in with JSON (de)serialization.

    Stores a single ``state`` string and round-trips it through
    :meth:`dumps` / :meth:`loads`, mirroring the coordinator interface the
    runtime-state tests patch in.
    """

    def __init__(self) -> None:
        # Every fresh stub starts from the same sentinel value.
        self.state = "initial"

    def dumps(self) -> str:
        """Serialize the current state as a one-key JSON document."""
        return json.dumps({"state": self.state})

    def loads(self, data: str) -> None:
        """Restore ``state`` from a document produced by :meth:`dumps`."""
        self.state = json.loads(data)["state"]
|
||||
|
||||
|
||||
class TestGraphRuntimeState:
    """Behavioral tests for GraphRuntimeState and its read-only wrapper."""

    def test_property_getters_and_setters(self):
        """Simple properties expose defaults and round-trip assignments."""
        # FIXME(-LAN-): Mock VariablePool if needed
        pool = VariablePool()
        started = time()

        state = GraphRuntimeState(variable_pool=pool, start_at=started)

        # variable_pool is read-only and reflects the constructor argument.
        assert state.variable_pool == pool

        # start_at is writable.
        assert state.start_at == started
        later = time() + 100
        state.start_at = later
        assert state.start_at == later

        # total_tokens defaults to zero and accepts updates.
        assert state.total_tokens == 0
        state.total_tokens = 100
        assert state.total_tokens == 100

        # node_run_steps defaults to zero and accepts updates.
        assert state.node_run_steps == 0
        state.node_run_steps = 5
        assert state.node_run_steps == 5

    def test_outputs_immutability(self):
        """``outputs`` hands out defensive copies; mutation goes through setters."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())

        # Each read produces an equal but distinct mapping.
        first_copy = state.outputs
        second_copy = state.outputs
        assert first_copy == second_copy
        assert first_copy is not second_copy

        # Mutating a returned copy never leaks back into the state.
        leaked = state.outputs
        leaked["test"] = "value"
        assert "test" not in state.outputs

        # set_output writes a single key.
        state.set_output("key1", "value1")
        assert state.get_output("key1") == "value1"

        # update_outputs merges a whole mapping.
        state.update_outputs({"key2": "value2", "key3": "value3"})
        assert state.get_output("key2") == "value2"
        assert state.get_output("key3") == "value3"

    def test_llm_usage_immutability(self):
        """``llm_usage`` also returns a fresh copy on every access."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())

        assert state.llm_usage is not state.llm_usage  # distinct objects per read

    def test_type_validation(self):
        """Negative counters are rejected with ValueError."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())

        with pytest.raises(ValueError):
            state.total_tokens = -1

        with pytest.raises(ValueError):
            state.node_run_steps = -1

    def test_helper_methods(self):
        """Increment helpers advance counters and validate their arguments."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())

        # increment_node_run_steps bumps the step counter by one.
        steps_before = state.node_run_steps
        state.increment_node_run_steps()
        assert state.node_run_steps == steps_before + 1

        # add_tokens accumulates onto total_tokens.
        tokens_before = state.total_tokens
        state.add_tokens(50)
        assert state.total_tokens == tokens_before + 50

        # Negative additions are invalid.
        with pytest.raises(ValueError):
            state.add_tokens(-1)

    def test_ready_queue_default_instantiation(self):
        """The ready queue lazily defaults to an in-memory queue and is cached."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())

        queue = state.ready_queue

        from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue

        assert isinstance(queue, InMemoryReadyQueue)
        assert state.ready_queue is queue  # same instance on repeat access

    def test_graph_execution_lazy_instantiation(self):
        """graph_execution is created on first access with an empty workflow id."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())

        execution = state.graph_execution

        from core.workflow.graph_engine.domain.graph_execution import GraphExecution

        assert isinstance(execution, GraphExecution)
        assert execution.workflow_id == ""
        assert state.graph_execution is execution  # cached thereafter

    def test_response_coordinator_configuration(self):
        """The coordinator only exists after configure() and binds one graph."""
        pool = VariablePool()
        state = GraphRuntimeState(variable_pool=pool, start_at=time())

        # Accessing the coordinator before configuration fails.
        with pytest.raises(ValueError):
            _ = state.response_coordinator

        mock_graph = MagicMock()
        with patch("core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator") as coordinator_cls:
            coordinator_instance = MagicMock()
            coordinator_cls.return_value = coordinator_instance

            state.configure(graph=mock_graph)

            assert state.response_coordinator is coordinator_instance
            coordinator_cls.assert_called_once_with(variable_pool=pool, graph=mock_graph)

            # Re-configuring with the same graph is idempotent.
            state.configure(graph=mock_graph)

            # Attaching a different graph afterwards is rejected.
            other_graph = MagicMock()
            with pytest.raises(ValueError):
                state.attach_graph(other_graph)

    def test_read_only_wrapper_exposes_additional_state(self):
        """The read-only wrapper surfaces queue size and exception counts."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
        state.configure()

        wrapper = ReadOnlyGraphRuntimeStateWrapper(state)

        assert wrapper.ready_queue_size == 0
        assert wrapper.exceptions_count == 0

    def test_read_only_wrapper_serializes_runtime_state(self):
        """Wrapper serialization matches the wrapped state's serialization."""
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
        state.total_tokens = 5
        state.set_output("result", {"success": True})
        state.ready_queue.put("node-1")

        wrapper = ReadOnlyGraphRuntimeStateWrapper(state)

        assert json.loads(wrapper.dumps()) == json.loads(state.dumps())

    def test_dumps_and_loads_roundtrip_with_response_coordinator(self):
        """from_snapshot restores counters, outputs, usage, queue, pool and execution."""
        pool = VariablePool()
        pool.add(("node1", "value"), "payload")

        state = GraphRuntimeState(variable_pool=pool, start_at=time())
        state.total_tokens = 10
        state.node_run_steps = 3
        state.set_output("final", {"result": True})
        usage = LLMUsage.from_metadata(
            {
                "prompt_tokens": 2,
                "completion_tokens": 3,
                "total_tokens": 5,
                "total_price": "1.23",
                "currency": "USD",
                "latency": 0.5,
            }
        )
        state.llm_usage = usage
        state.ready_queue.put("node-A")

        graph_execution = state.graph_execution
        graph_execution.workflow_id = "wf-123"
        graph_execution.exceptions_count = 4
        graph_execution.started = True

        # Attach a stub coordinator so its state participates in the snapshot.
        mock_graph = MagicMock()
        stub = StubCoordinator()
        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
            state.attach_graph(mock_graph)

        stub.state = "configured"

        snapshot = state.dumps()

        restored = GraphRuntimeState.from_snapshot(snapshot)

        # Scalar counters and outputs survive the round trip.
        assert restored.total_tokens == 10
        assert restored.node_run_steps == 3
        assert restored.get_output("final") == {"result": True}
        assert restored.llm_usage.total_tokens == usage.total_tokens
        assert restored.ready_queue.qsize() == 1
        assert restored.ready_queue.get(timeout=0.01) == "node-A"

        # Variable pool contents survive as well.
        restored_segment = restored.variable_pool.get(("node1", "value"))
        assert restored_segment is not None
        assert restored_segment.value == "payload"

        # Graph-execution bookkeeping survives.
        restored_execution = restored.graph_execution
        assert restored_execution.workflow_id == "wf-123"
        assert restored_execution.exceptions_count == 4
        assert restored_execution.started is True

        # Attaching a graph to the restored state rehydrates the coordinator.
        new_stub = StubCoordinator()
        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
            restored.attach_graph(mock_graph)

        assert new_stub.state == "configured"

    def test_loads_rehydrates_existing_instance(self):
        """loads() overwrites an already-constructed state with snapshot data."""
        pool = VariablePool()
        pool.add(("node", "key"), "value")

        state = GraphRuntimeState(variable_pool=pool, start_at=time())
        state.total_tokens = 7
        state.node_run_steps = 2
        state.set_output("foo", "bar")
        state.ready_queue.put("node-1")

        execution = state.graph_execution
        execution.workflow_id = "wf-456"
        execution.started = True

        mock_graph = MagicMock()
        original_stub = StubCoordinator()
        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub):
            state.attach_graph(mock_graph)

        original_stub.state = "configured"
        snapshot = state.dumps()

        # Rehydrate a brand-new, empty state from the snapshot.
        new_stub = StubCoordinator()
        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
            restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
            restored.attach_graph(mock_graph)
            restored.loads(snapshot)

        assert restored.total_tokens == 7
        assert restored.node_run_steps == 2
        assert restored.get_output("foo") == "bar"
        assert restored.ready_queue.qsize() == 1
        assert restored.ready_queue.get(timeout=0.01) == "node-1"

        restored_segment = restored.variable_pool.get(("node", "key"))
        assert restored_segment is not None
        assert restored_segment.value == "value"

        restored_execution = restored.graph_execution
        assert restored_execution.workflow_id == "wf-456"
        assert restored_execution.started is True

        # The already-attached coordinator was fed the snapshot's state.
        assert new_stub.state == "configured"
|
||||
@@ -0,0 +1,171 @@
|
||||
"""Tests for _PrivateWorkflowPauseEntity implementation."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity:
    """Test _PrivateWorkflowPauseEntity implementation."""

    @staticmethod
    def _entity_with_model(**model_attrs):
        """Build an entity around a MagicMock pause model carrying *model_attrs*.

        Returns the ``(entity, model)`` pair so tests can assert against both.
        """
        model = MagicMock(spec=WorkflowPauseModel)
        for attr, value in model_attrs.items():
            setattr(model, attr, value)
        return _PrivateWorkflowPauseEntity(pause_model=model), model

    def test_entity_initialization(self):
        """The constructor stores the model and starts with no cached state."""
        entity, model = self._entity_with_model(
            id="pause-123",
            workflow_run_id="execution-456",
            resumed_at=None,
        )

        assert entity._pause_model is model
        assert entity._cached_state is None

    def test_from_models_classmethod(self):
        """from_models wraps the given model in a new entity."""
        _, model = self._entity_with_model(
            id="pause-123",
            workflow_run_id="execution-456",
        )

        entity = _PrivateWorkflowPauseEntity.from_models(
            workflow_pause_model=model,
        )

        assert isinstance(entity, _PrivateWorkflowPauseEntity)
        assert entity._pause_model is model

    def test_id_property(self):
        """``id`` proxies the pause model's id."""
        entity, _ = self._entity_with_model(id="pause-123")

        assert entity.id == "pause-123"

    def test_workflow_execution_id_property(self):
        """``workflow_execution_id`` proxies the model's workflow_run_id."""
        entity, _ = self._entity_with_model(workflow_run_id="execution-456")

        assert entity.workflow_execution_id == "execution-456"

    def test_resumed_at_property(self):
        """``resumed_at`` proxies the model's resumed_at timestamp."""
        resumed_at = datetime(2023, 12, 25, 15, 30, 45)
        entity, _ = self._entity_with_model(resumed_at=resumed_at)

        assert entity.resumed_at == resumed_at

    def test_resumed_at_property_none(self):
        """``resumed_at`` is None when the pause has not been resumed."""
        entity, _ = self._entity_with_model(resumed_at=None)

        assert entity.resumed_at is None

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_first_call(self, mock_storage):
        """The first get_state() call fetches from storage and caches the bytes."""
        state_data = b'{"test": "data", "step": 5}'
        mock_storage.load.return_value = state_data

        entity, _ = self._entity_with_model(state_object_key="test-state-key")

        assert entity.get_state() == state_data
        mock_storage.load.assert_called_once_with("test-state-key")
        assert entity._cached_state == state_data

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_cached_call(self, mock_storage):
        """Subsequent get_state() calls serve the cache; storage is hit once."""
        state_data = b'{"test": "data", "step": 5}'
        mock_storage.load.return_value = state_data

        entity, _ = self._entity_with_model(state_object_key="test-state-key")

        assert entity.get_state() == state_data  # loads from storage
        assert entity.get_state() == state_data  # served from cache
        mock_storage.load.assert_called_once_with("test-state-key")

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_with_pre_cached_data(self, mock_storage):
        """A pre-populated cache short-circuits storage entirely."""
        state_data = b'{"test": "data", "step": 5}'
        entity, _ = self._entity_with_model()

        entity._cached_state = state_data

        assert entity.get_state() == state_data
        mock_storage.load.assert_not_called()

    def test_entity_with_binary_state_data(self):
        """get_state() passes arbitrary (non-JSON) bytes through untouched."""
        binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = binary_data

            entity, _ = self._entity_with_model()

            assert entity.get_state() == binary_data
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for template module."""
|
||||
|
||||
from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
|
||||
|
||||
|
||||
class TestTemplate:
    """Test Template class functionality."""

    def test_from_answer_template_simple(self):
        """One variable embedded in text yields text/variable/text segments."""
        source = "Hello, {{#node1.name#}}!"
        template = Template.from_answer_template(source)

        assert len(template.segments) == 3

        leading, variable, trailing = template.segments
        assert isinstance(leading, TextSegment)
        assert leading.text == "Hello, "
        assert isinstance(variable, VariableSegment)
        assert variable.selector == ["node1", "name"]
        assert isinstance(trailing, TextSegment)
        assert trailing.text == "!"

    def test_from_answer_template_multiple_vars(self):
        """Two variables interleave with the surrounding text into 5 segments."""
        source = "Hello {{#node1.name#}}, your age is {{#node2.age#}}."
        template = Template.from_answer_template(source)

        assert len(template.segments) == 5

        expectations = [
            (TextSegment, "text", "Hello "),
            (VariableSegment, "selector", ["node1", "name"]),
            (TextSegment, "text", ", your age is "),
            (VariableSegment, "selector", ["node2", "age"]),
            (TextSegment, "text", "."),
        ]
        for segment, (seg_type, attr, expected) in zip(template.segments, expectations):
            assert isinstance(segment, seg_type)
            assert getattr(segment, attr) == expected

    def test_from_answer_template_no_vars(self):
        """Plain text parses to a single TextSegment."""
        template = Template.from_answer_template("Hello, world!")

        assert len(template.segments) == 1
        assert isinstance(template.segments[0], TextSegment)
        assert template.segments[0].text == "Hello, world!"

    def test_from_end_outputs_single(self):
        """A single End-node output becomes one VariableSegment."""
        template = Template.from_end_outputs([{"variable": "text", "value_selector": ["node1", "text"]}])

        assert len(template.segments) == 1
        assert isinstance(template.segments[0], VariableSegment)
        assert template.segments[0].selector == ["node1", "text"]

    def test_from_end_outputs_multiple(self):
        """Multiple End-node outputs are joined by newline text segments."""
        template = Template.from_end_outputs(
            [
                {"variable": "text", "value_selector": ["node1", "text"]},
                {"variable": "result", "value_selector": ["node2", "result"]},
            ]
        )

        assert len(template.segments) == 3

        first, separator, second = template.segments
        assert isinstance(first, VariableSegment)
        assert first.selector == ["node1", "text"]
        assert first.variable_name == "text"
        assert isinstance(separator, TextSegment)
        assert separator.text == "\n"
        assert isinstance(second, VariableSegment)
        assert second.selector == ["node2", "result"]
        assert second.variable_name == "result"

    def test_from_end_outputs_empty(self):
        """An empty outputs config produces an empty template."""
        template = Template.from_end_outputs([])

        assert len(template.segments) == 0

    def test_template_str_representation(self):
        """str() reconstructs the original answer-template text."""
        source = "Hello, {{#node1.name#}}!"
        template = Template.from_answer_template(source)

        assert str(template) == source
|
||||
@@ -0,0 +1,136 @@
|
||||
from core.variables.segments import (
|
||||
BooleanSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
||||
class TestVariablePoolGetAndNestedAttribute:
    """Lookups via VariablePool.get and its _get_nested_attribute helper."""

    # ------------------------------------------------------------------
    # _get_nested_attribute tests
    # ------------------------------------------------------------------
    def test__get_nested_attribute_existing_key(self):
        """A present dict key is wrapped into a segment carrying its value."""
        pool = VariablePool.empty()
        segment = pool._get_nested_attribute({"a": 123}, "a")
        assert segment is not None
        assert segment.value == 123

    def test__get_nested_attribute_missing_key(self):
        """A missing key yields None rather than raising."""
        pool = VariablePool.empty()
        assert pool._get_nested_attribute({"a": 123}, "b") is None

    def test__get_nested_attribute_non_dict(self):
        """Non-mapping containers cannot be attribute-indexed."""
        pool = VariablePool.empty()
        assert pool._get_nested_attribute(["not", "a", "dict"], "a") is None

    def test__get_nested_attribute_with_none_value(self):
        """An explicit None value becomes a NoneSegment, not a miss."""
        pool = VariablePool.empty()
        segment = pool._get_nested_attribute({"a": None}, "a")
        assert segment is not None
        assert isinstance(segment, NoneSegment)

    def test__get_nested_attribute_with_empty_string(self):
        """An empty string is a real (falsy) value and must round-trip."""
        pool = VariablePool.empty()
        segment = pool._get_nested_attribute({"a": ""}, "a")
        assert segment is not None
        assert isinstance(segment, StringSegment)
        assert segment.value == ""

    # ------------------------------------------------------------------
    # get tests
    # ------------------------------------------------------------------
    def test_get_simple_variable(self):
        """A (node, name) selector retrieves a previously added value."""
        pool = VariablePool.empty()
        pool.add(("node1", "var1"), "value1")
        segment = pool.get(("node1", "var1"))
        assert segment is not None
        assert segment.value == "value1"

    def test_get_missing_variable(self):
        """Unknown variable names resolve to None."""
        pool = VariablePool.empty()
        assert pool.get(("node1", "unknown")) is None

    def test_get_with_too_short_selector(self):
        """A selector lacking the variable-name part resolves to None."""
        pool = VariablePool.empty()
        assert pool.get(("only_node",)) is None

    def test_get_nested_object_attribute(self):
        """A third selector element drills into an object variable."""
        pool = VariablePool.empty()
        pool.add(("node1", "obj"), {"inner": "hello"})

        segment = pool.get(("node1", "obj", "inner"))
        assert segment is not None
        assert segment.value == "hello"

    def test_get_nested_object_missing_attribute(self):
        """Drilling into a missing attribute resolves to None."""
        pool = VariablePool.empty()
        pool.add(("node1", "obj"), {"inner": "hello"})

        assert pool.get(("node1", "obj", "not_exist")) is None

    def test_get_nested_object_attribute_with_falsy_values(self):
        """Falsy attribute values (None, '', 0, False) are still retrievable."""
        pool = VariablePool.empty()
        pool.add(
            ("node1", "obj"),
            {
                "inner_none": None,
                "inner_empty": "",
                "inner_zero": 0,
                "inner_false": False,
            },
        )

        segment_none = pool.get(("node1", "obj", "inner_none"))
        assert segment_none is not None
        assert isinstance(segment_none, NoneSegment)

        segment_empty = pool.get(("node1", "obj", "inner_empty"))
        assert segment_empty is not None
        assert isinstance(segment_empty, StringSegment)
        assert segment_empty.value == ""

        segment_zero = pool.get(("node1", "obj", "inner_zero"))
        assert segment_zero is not None
        assert isinstance(segment_zero, IntegerSegment)
        assert segment_zero.value == 0

        segment_false = pool.get(("node1", "obj", "inner_false"))
        assert segment_false is not None
        assert isinstance(segment_false, BooleanSegment)
        assert segment_false.value is False
|
||||
|
||||
class TestVariablePoolGetNotModifyVariableDictionary:
    """Read-only operations must not create entries in the variable dictionary."""

    _NODE_ID = "start"
    _VAR_NAME = "name"

    def test_convert_to_template_should_not_introduce_extra_keys(self):
        """Rendering a template must not register spurious node ids."""
        pool = VariablePool.empty()
        pool.add([self._NODE_ID, self._VAR_NAME], 0)
        pool.convert_template("The start.name is {{#start.name#}}")
        # A sloppy parser could register the literal prefix as a node id.
        assert "The start" not in pool.variable_dictionary

    def test_get_should_not_modify_variable_dictionary(self):
        """Missing lookups must not materialize nodes or variables."""
        # Looking up an unknown node leaves the dictionary untouched.
        pool = VariablePool.empty()
        pool.get([self._NODE_ID, self._VAR_NAME])
        assert len(pool.variable_dictionary) == 1  # only contains `sys` node id
        assert "start" not in pool.variable_dictionary

        # Looking up an unknown variable on a known node adds no entry either.
        pool = VariablePool.empty()
        pool.add([self._NODE_ID, self._VAR_NAME], "Joe")
        pool.get([self._NODE_ID, "count"])
        node_vars = pool.variable_dictionary[self._NODE_ID]
        assert "count" not in node_vars
|
||||
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Unit tests for WorkflowNodeExecution domain model, focusing on process_data truncation functionality.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType
|
||||
|
||||
|
||||
class TestWorkflowNodeExecutionProcessDataTruncation:
|
||||
"""Test process_data truncation functionality in WorkflowNodeExecution domain model."""
|
||||
|
||||
def create_workflow_node_execution(
|
||||
self,
|
||||
process_data: dict[str, Any] | None = None,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""Create a WorkflowNodeExecution instance for testing."""
|
||||
return WorkflowNodeExecution(
|
||||
id="test-execution-id",
|
||||
workflow_id="test-workflow-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
process_data=process_data,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
def test_initial_process_data_truncated_state(self):
|
||||
"""Test that process_data_truncated returns False initially."""
|
||||
execution = self.create_workflow_node_execution()
|
||||
|
||||
assert execution.process_data_truncated is False
|
||||
assert execution.get_truncated_process_data() is None
|
||||
|
||||
def test_set_and_get_truncated_process_data(self):
|
||||
"""Test setting and getting truncated process_data."""
|
||||
execution = self.create_workflow_node_execution()
|
||||
test_truncated_data = {"truncated": True, "key": "value"}
|
||||
|
||||
execution.set_truncated_process_data(test_truncated_data)
|
||||
|
||||
assert execution.process_data_truncated is True
|
||||
assert execution.get_truncated_process_data() == test_truncated_data
|
||||
|
||||
def test_set_truncated_process_data_to_none(self):
|
||||
"""Test setting truncated process_data to None."""
|
||||
execution = self.create_workflow_node_execution()
|
||||
|
||||
# First set some data
|
||||
execution.set_truncated_process_data({"key": "value"})
|
||||
assert execution.process_data_truncated is True
|
||||
|
||||
# Then set to None
|
||||
execution.set_truncated_process_data(None)
|
||||
assert execution.process_data_truncated is False
|
||||
assert execution.get_truncated_process_data() is None
|
||||
|
||||
def test_get_response_process_data_with_no_truncation(self):
|
||||
"""Test get_response_process_data when no truncation is set."""
|
||||
original_data = {"original": True, "data": "value"}
|
||||
execution = self.create_workflow_node_execution(process_data=original_data)
|
||||
|
||||
response_data = execution.get_response_process_data()
|
||||
|
||||
assert response_data == original_data
|
||||
assert execution.process_data_truncated is False
|
||||
|
||||
def test_get_response_process_data_with_truncation(self):
|
||||
"""Test get_response_process_data when truncation is set."""
|
||||
original_data = {"original": True, "large_data": "x" * 10000}
|
||||
truncated_data = {"original": True, "large_data": "[TRUNCATED]"}
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data=original_data)
|
||||
execution.set_truncated_process_data(truncated_data)
|
||||
|
||||
response_data = execution.get_response_process_data()
|
||||
|
||||
# Should return truncated data, not original
|
||||
assert response_data == truncated_data
|
||||
assert response_data != original_data
|
||||
assert execution.process_data_truncated is True
|
||||
|
||||
def test_get_response_process_data_with_none_process_data(self):
|
||||
"""Test get_response_process_data when process_data is None."""
|
||||
execution = self.create_workflow_node_execution(process_data=None)
|
||||
|
||||
response_data = execution.get_response_process_data()
|
||||
|
||||
assert response_data is None
|
||||
assert execution.process_data_truncated is False
|
||||
|
||||
def test_consistency_with_inputs_outputs_pattern(self):
    """process_data truncation mirrors the inputs/outputs truncation API."""
    node_execution = self.create_workflow_node_execution()
    sample = {"test": "data"}

    # Every truncatable field exposes the same trio: a set_truncated_* setter,
    # a *_truncated flag, and a get_truncated_* getter. Exercise each in turn,
    # in the same order the inputs/outputs tests use.
    for field in ("inputs", "outputs", "process_data"):
        getattr(node_execution, f"set_truncated_{field}")(sample)
        assert getattr(node_execution, f"{field}_truncated") is True
        assert getattr(node_execution, f"get_truncated_{field}")() == sample
@pytest.mark.parametrize(
    "test_data",
    [
        {"simple": "value"},
        {"nested": {"key": "value"}},
        {"list": [1, 2, 3]},
        {"mixed": {"string": "value", "number": 42, "list": [1, 2]}},
        {},  # empty dict
    ],
)
def test_truncated_process_data_with_various_data_types(self, test_data):
    """Truncated process_data round-trips for a variety of payload shapes."""
    node_execution = self.create_workflow_node_execution()

    node_execution.set_truncated_process_data(test_data)

    # The flag flips, and both getters echo the stored payload back.
    assert node_execution.process_data_truncated is True
    assert node_execution.get_truncated_process_data() == test_data
    assert node_execution.get_response_process_data() == test_data
|
||||
@dataclass
class ProcessDataScenario:
    """Test scenario data for process_data functionality."""

    # Human-readable scenario label; also used as the pytest parametrize id.
    name: str
    # Value assigned to WorkflowNodeExecution.process_data at construction time.
    original_data: dict[str, Any] | None
    # Truncated copy to install via set_truncated_process_data; None means skip truncation.
    truncated_data: dict[str, Any] | None
    # Expected value of execution.process_data_truncated after setup.
    expected_truncated_flag: bool
    # Expected return value of execution.get_response_process_data().
    expected_response_data: dict[str, Any] | None
class TestWorkflowNodeExecutionProcessDataScenarios:
    """Test various scenarios for process_data handling."""

    @staticmethod
    def get_process_data_scenarios() -> list[ProcessDataScenario]:
        """Create test scenarios for process_data functionality.

        Declared as a staticmethod so the class body below can call it when
        building the parametrize arguments (staticmethod objects are directly
        callable since Python 3.10). This replaces the previous
        ``get_process_data_scenarios(None)`` pattern, which passed ``None``
        in place of ``self``.
        """
        return [
            ProcessDataScenario(
                name="no_process_data",
                original_data=None,
                truncated_data=None,
                expected_truncated_flag=False,
                expected_response_data=None,
            ),
            ProcessDataScenario(
                name="process_data_without_truncation",
                original_data={"small": "data"},
                truncated_data=None,
                expected_truncated_flag=False,
                expected_response_data={"small": "data"},
            ),
            ProcessDataScenario(
                name="process_data_with_truncation",
                original_data={"large": "x" * 10000, "metadata": "info"},
                truncated_data={"large": "[TRUNCATED]", "metadata": "info"},
                expected_truncated_flag=True,
                expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
            ),
            ProcessDataScenario(
                name="empty_process_data",
                original_data={},
                truncated_data=None,
                expected_truncated_flag=False,
                expected_response_data={},
            ),
            ProcessDataScenario(
                name="complex_nested_data_with_truncation",
                original_data={
                    "config": {"setting": "value"},
                    "logs": ["log1", "log2"] * 1000,  # Large list
                    "status": "running",
                },
                truncated_data={"config": {"setting": "value"}, "logs": "[TRUNCATED: 2000 items]", "status": "running"},
                expected_truncated_flag=True,
                expected_response_data={
                    "config": {"setting": "value"},
                    "logs": "[TRUNCATED: 2000 items]",
                    "status": "running",
                },
            ),
        ]

    # Build the scenario list once at class-definition time; reused for both
    # the parametrize values and the human-readable test ids.
    _scenarios = get_process_data_scenarios()

    @pytest.mark.parametrize(
        "scenario",
        _scenarios,
        ids=[scenario.name for scenario in _scenarios],
    )
    def test_process_data_scenarios(self, scenario: ProcessDataScenario):
        """Test various process_data scenarios.

        Builds a WorkflowNodeExecution with the scenario's original
        process_data, optionally installs the truncated copy, then checks both
        the truncation flag and the response payload.
        """
        execution = WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=scenario.original_data,
            created_at=datetime.now(),
        )

        if scenario.truncated_data is not None:
            execution.set_truncated_process_data(scenario.truncated_data)

        assert execution.process_data_truncated == scenario.expected_truncated_flag
        assert execution.get_response_process_data() == scenario.expected_response_data
281
dify/api/tests/unit_tests/core/workflow/graph/test_graph.py
Normal file
281
dify/api/tests/unit_tests/core/workflow/graph/test_graph.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Unit tests for Graph class methods."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.graph.edge import Edge
|
||||
from core.workflow.graph.graph import Graph
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
|
||||
def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node:
    """Build a Mock standing in for a Node, pre-populated with the attributes the graph code reads."""
    mock_node = Mock(spec=Node)
    mock_node.node_type = NodeType.START
    mock_node.state = state
    mock_node.execution_type = execution_type
    mock_node.id = node_id
    return mock_node
|
||||
class TestMarkInactiveRootBranches:
    """Test cases for _mark_inactive_root_branches method.

    Each test builds a small graph out of mock nodes and real Edge objects,
    invokes Graph._mark_inactive_root_branches with one active root id, and
    asserts which nodes/edges end up SKIPPED vs left UNKNOWN.
    """

    def test_single_root_no_marking(self):
        """Test that single root graph doesn't mark anything as skipped."""
        # Topology: root1 -> child1 (only one root, so nothing is inactive).
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
        }

        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
        }

        in_edges = {"child1": ["edge1"]}
        out_edges = {"root1": ["edge1"]}

        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["child1"].state == NodeState.UNKNOWN
        assert edges["edge1"].state == NodeState.UNKNOWN

    def test_multiple_roots_mark_inactive(self):
        """Test marking inactive root branches with multiple root nodes."""
        # Topology: root1 -> child1, root2 -> child2; root1 is the active root.
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "root2": create_mock_node("root2", NodeExecutionType.ROOT),
            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
            "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
        }

        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
            "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
        }

        in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
        out_edges = {"root1": ["edge1"], "root2": ["edge2"]}

        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["root2"].state == NodeState.SKIPPED
        assert nodes["child1"].state == NodeState.UNKNOWN
        assert nodes["child2"].state == NodeState.SKIPPED
        assert edges["edge1"].state == NodeState.UNKNOWN
        assert edges["edge2"].state == NodeState.SKIPPED

    def test_shared_downstream_node(self):
        """Test that shared downstream nodes are not skipped if at least one path is active."""
        # Topology: root1 -> child1 -> shared, root2 -> child2 -> shared.
        # "shared" has incoming edges from both branches.
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "root2": create_mock_node("root2", NodeExecutionType.ROOT),
            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
            "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
            "shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE),
        }

        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
            "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
            "edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"),
            "edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"),
        }

        in_edges = {
            "child1": ["edge1"],
            "child2": ["edge2"],
            "shared": ["edge3", "edge4"],
        }
        out_edges = {
            "root1": ["edge1"],
            "root2": ["edge2"],
            "child1": ["edge3"],
            "child2": ["edge4"],
        }

        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["root2"].state == NodeState.SKIPPED
        assert nodes["child1"].state == NodeState.UNKNOWN
        assert nodes["child2"].state == NodeState.SKIPPED
        assert nodes["shared"].state == NodeState.UNKNOWN  # Not skipped because edge3 is active
        assert edges["edge1"].state == NodeState.UNKNOWN
        assert edges["edge2"].state == NodeState.SKIPPED
        assert edges["edge3"].state == NodeState.UNKNOWN
        assert edges["edge4"].state == NodeState.SKIPPED

    def test_deep_branch_marking(self):
        """Test marking deep branches with multiple levels."""
        # Topology: root1 -> level1_a -> level2_a (active chain)
        #           root2 -> level1_b -> level2_b -> level3 (inactive chain)
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "root2": create_mock_node("root2", NodeExecutionType.ROOT),
            "level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE),
            "level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE),
            "level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE),
            "level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE),
            "level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE),
        }

        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"),
            "edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"),
            "edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"),
            "edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"),
            "edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"),
        }

        in_edges = {
            "level1_a": ["edge1"],
            "level1_b": ["edge2"],
            "level2_a": ["edge3"],
            "level2_b": ["edge4"],
            "level3": ["edge5"],
        }
        out_edges = {
            "root1": ["edge1"],
            "root2": ["edge2"],
            "level1_a": ["edge3"],
            "level1_b": ["edge4"],
            "level2_b": ["edge5"],
        }

        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["root2"].state == NodeState.SKIPPED
        assert nodes["level1_a"].state == NodeState.UNKNOWN
        assert nodes["level1_b"].state == NodeState.SKIPPED
        assert nodes["level2_a"].state == NodeState.UNKNOWN
        assert nodes["level2_b"].state == NodeState.SKIPPED
        assert nodes["level3"].state == NodeState.SKIPPED
        assert edges["edge1"].state == NodeState.UNKNOWN
        assert edges["edge2"].state == NodeState.SKIPPED
        assert edges["edge3"].state == NodeState.UNKNOWN
        assert edges["edge4"].state == NodeState.SKIPPED
        assert edges["edge5"].state == NodeState.SKIPPED

    def test_non_root_execution_type(self):
        """Test that nodes with non-ROOT execution type are not treated as root nodes."""
        # "non_root" has an outgoing edge like a root, but its execution_type is
        # EXECUTABLE, so the marking pass must leave its branch alone.
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE),
            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
            "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
        }

        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
            "edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"),
        }

        in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
        out_edges = {"root1": ["edge1"], "non_root": ["edge2"]}

        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["non_root"].state == NodeState.UNKNOWN  # Not marked as skipped
        assert nodes["child1"].state == NodeState.UNKNOWN
        assert nodes["child2"].state == NodeState.UNKNOWN
        assert edges["edge1"].state == NodeState.UNKNOWN
        assert edges["edge2"].state == NodeState.UNKNOWN

    def test_empty_graph(self):
        """Test handling of empty graph structures."""
        nodes = {}
        edges = {}
        in_edges = {}
        out_edges = {}

        # Should not raise any errors
        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent")

    def test_three_roots_mark_two_inactive(self):
        """Test with three root nodes where two should be marked inactive."""
        # Topology: rootN -> childN for N in 1..3; root2 is the active root.
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "root2": create_mock_node("root2", NodeExecutionType.ROOT),
            "root3": create_mock_node("root3", NodeExecutionType.ROOT),
            "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
            "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
            "child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE),
        }

        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
            "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
            "edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"),
        }

        in_edges = {
            "child1": ["edge1"],
            "child2": ["edge2"],
            "child3": ["edge3"],
        }
        out_edges = {
            "root1": ["edge1"],
            "root2": ["edge2"],
            "root3": ["edge3"],
        }

        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2")

        assert nodes["root1"].state == NodeState.SKIPPED
        assert nodes["root2"].state == NodeState.UNKNOWN  # Active root
        assert nodes["root3"].state == NodeState.SKIPPED
        assert nodes["child1"].state == NodeState.SKIPPED
        assert nodes["child2"].state == NodeState.UNKNOWN
        assert nodes["child3"].state == NodeState.SKIPPED
        assert edges["edge1"].state == NodeState.SKIPPED
        assert edges["edge2"].state == NodeState.UNKNOWN
        assert edges["edge3"].state == NodeState.SKIPPED

    def test_convergent_paths(self):
        """Test convergent paths where multiple inactive branches lead to same node."""
        # Topology: root1 -> mid1 -> convergent (active path)
        #           root2 -> mid2 -> convergent, root3 -> convergent (inactive paths)
        nodes = {
            "root1": create_mock_node("root1", NodeExecutionType.ROOT),
            "root2": create_mock_node("root2", NodeExecutionType.ROOT),
            "root3": create_mock_node("root3", NodeExecutionType.ROOT),
            "mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE),
            "mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE),
            "convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE),
        }

        edges = {
            "edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"),
            "edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"),
            "edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"),
            "edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"),
            "edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"),
        }

        in_edges = {
            "mid1": ["edge1"],
            "mid2": ["edge2"],
            "convergent": ["edge3", "edge4", "edge5"],
        }
        out_edges = {
            "root1": ["edge1"],
            "root2": ["edge2"],
            "root3": ["edge3"],
            "mid1": ["edge4"],
            "mid2": ["edge5"],
        }

        Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

        assert nodes["root1"].state == NodeState.UNKNOWN
        assert nodes["root2"].state == NodeState.SKIPPED
        assert nodes["root3"].state == NodeState.SKIPPED
        assert nodes["mid1"].state == NodeState.UNKNOWN
        assert nodes["mid2"].state == NodeState.SKIPPED
        assert nodes["convergent"].state == NodeState.UNKNOWN  # Not skipped due to active path from root1
        assert edges["edge1"].state == NodeState.UNKNOWN
        assert edges["edge2"].state == NodeState.SKIPPED
        assert edges["edge3"].state == NodeState.SKIPPED
        assert edges["edge4"].state == NodeState.UNKNOWN
        assert edges["edge5"].state == NodeState.SKIPPED
@@ -0,0 +1,59 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
|
||||
def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node:
    """Return a MagicMock shaped like a Node with the given id and type."""
    stub = MagicMock(spec=Node)
    stub.execution_type = None  # attribute not used in builder path
    stub.node_type = node_type
    stub.id = node_id
    return stub
||||
|
||||
def test_graph_builder_creates_linear_graph():
    """Chained add_root/add_node calls wire a straight root -> mid -> end pipeline."""
    start_node = _make_node("root", NodeType.START)
    llm_node = _make_node("mid", NodeType.LLM)
    end_node = _make_node("end", NodeType.END)

    graph_builder = Graph.new().add_root(start_node)
    graph_builder = graph_builder.add_node(llm_node)
    graph = graph_builder.add_node(end_node).build()

    assert graph.root_node is start_node
    assert graph.nodes == {"root": start_node, "mid": llm_node, "end": end_node}
    assert len(graph.edges) == 2

    # The first edge registered connects the root to the middle node.
    head_edge = next(iter(graph.edges.values()))
    assert head_edge.tail == "root"
    assert head_edge.head == "mid"

    expected_mid_out = [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"]
    assert graph.out_edges["mid"] == expected_mid_out
||||
|
||||
def test_graph_builder_supports_custom_predecessor():
    """add_node(from_node_id=...) attaches the new node to an explicit predecessor."""
    root_node = _make_node("root")
    branch_node = _make_node("branch")
    sibling_node = _make_node("other")

    graph_builder = Graph.new().add_root(root_node).add_node(branch_node)
    # Attach "other" directly under "root" instead of the previously added node.
    graph = graph_builder.add_node(sibling_node, from_node_id="root").build()

    root_out_edges = graph.out_edges["root"]
    assert len(root_out_edges) == 2

    heads = {graph.edges[edge_id].head for edge_id in root_out_edges}
    assert heads == {"branch", "other"}
||||
|
||||
def test_graph_builder_validates_usage():
    """The builder rejects add_node before a root exists and rejects duplicate node ids."""
    graph_builder = Graph.new()
    first_node = _make_node("node")

    # Adding a non-root node first must fail: the builder needs a root.
    with pytest.raises(ValueError, match="Root node"):
        graph_builder.add_node(first_node)

    graph_builder.add_root(first_node)

    # A second node reusing the same id must be rejected.
    clashing_node = _make_node("node")
    with pytest.raises(ValueError, match="Duplicate"):
        graph_builder.add_node(clashing_node)
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user