This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,885 @@
import copy
import pytest
from faker import Faker
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_CONTEXT,
CHAT_APP_CHAT_PROMPT_CONFIG,
CHAT_APP_COMPLETION_PROMPT_CONFIG,
COMPLETION_APP_CHAT_PROMPT_CONFIG,
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
CONTEXT,
)
from models.model import AppMode
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class TestAdvancedPromptTemplateService:
    """Integration tests for AdvancedPromptTemplateService using testcontainers."""

    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies.

        The service under test has no external dependencies; an empty mapping is
        returned purely so this suite matches the fixture pattern used by the
        other integration-test files.
        """
        return {}
def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful prompt generation for a Baichuan model.

    This test verifies:
    - Proper prompt generation for Baichuan models
    - Correct model detection logic
    - Appropriate prompt template selection
    """
    # Arrange: chat app + completion model mode with a Baichuan model name.
    args = {
        "app_mode": AppMode.CHAT,
        "model_mode": "completion",
        "model_name": "baichuan-13b-chat",
        "has_context": "true",
    }

    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_prompt(args)

    # Assert: the completion template structure is present.
    assert result is not None
    assert "completion_prompt_config" in result
    assert "prompt" in result["completion_prompt_config"]
    assert "text" in result["completion_prompt_config"]["prompt"]

    # Verify the Baichuan-specific context was prepended and placeholders survive.
    prompt_text = result["completion_prompt_config"]["prompt"]["text"]
    assert BAICHUAN_CONTEXT in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
    assert "{{#histories#}}" in prompt_text
    assert "{{#query#}}" in prompt_text
def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful prompt generation for common (non-Baichuan) models.

    This test verifies:
    - Proper prompt generation for non-Baichuan models
    - Correct model detection logic
    - Appropriate prompt template selection
    """
    # Arrange: chat app + completion model mode with a common model name.
    args = {
        "app_mode": AppMode.CHAT,
        "model_mode": "completion",
        "model_name": "gpt-3.5-turbo",
        "has_context": "true",
    }

    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_prompt(args)

    # Assert: the completion template structure is present.
    assert result is not None
    assert "completion_prompt_config" in result
    assert "prompt" in result["completion_prompt_config"]
    assert "text" in result["completion_prompt_config"]["prompt"]

    # Verify the common context was prepended and placeholders survive.
    prompt_text = result["completion_prompt_config"]["prompt"]["text"]
    assert CONTEXT in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
    assert "{{#histories#}}" in prompt_text
    assert "{{#query#}}" in prompt_text
def test_get_prompt_case_insensitive_baichuan_detection(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test that Baichuan model detection is case insensitive.

    This test verifies:
    - Model name detection works regardless of case
    - Proper prompt template selection for different case variations
    """
    # Several case variations that should all be recognized as Baichuan.
    test_cases = ["Baichuan-13B-Chat", "BAICHUAN-13B-CHAT", "baichuan-13b-chat", "BaiChuan-13B-Chat"]

    for model_name in test_cases:
        args = {
            "app_mode": AppMode.CHAT,
            "model_mode": "completion",
            "model_name": model_name,
            "has_context": "true",
        }

        # Act: Execute the method under test
        result = AdvancedPromptTemplateService.get_prompt(args)

        # Assert: the Baichuan template was selected regardless of case.
        assert result is not None
        prompt_text = result["completion_prompt_config"]["prompt"]["text"]
        assert BAICHUAN_CONTEXT in prompt_text
def test_get_common_prompt_chat_app_completion_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test common prompt generation for chat app with completion mode.

    This test verifies:
    - Correct prompt template selection for chat app + completion mode
    - Proper context integration
    - Template structure validation
    """
    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")

    # Assert: completion config carries prompt text, histories role, and stop words.
    assert result is not None
    assert "completion_prompt_config" in result
    assert "prompt" in result["completion_prompt_config"]
    assert "text" in result["completion_prompt_config"]["prompt"]
    assert "conversation_histories_role" in result["completion_prompt_config"]
    assert "stop" in result

    # Verify context is included alongside the standard placeholders.
    prompt_text = result["completion_prompt_config"]["prompt"]["text"]
    assert CONTEXT in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
    assert "{{#histories#}}" in prompt_text
    assert "{{#query#}}" in prompt_text
def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test common prompt generation for chat app with chat mode.

    This test verifies:
    - Correct prompt template selection for chat app + chat mode
    - Proper context integration
    - Template structure validation
    """
    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "true")

    # Assert: chat config carries at least one role/text message.
    assert result is not None
    assert "chat_prompt_config" in result
    assert "prompt" in result["chat_prompt_config"]
    assert len(result["chat_prompt_config"]["prompt"]) > 0
    assert "role" in result["chat_prompt_config"]["prompt"][0]
    assert "text" in result["chat_prompt_config"]["prompt"][0]

    # Verify context is included in the first message.
    prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
    assert CONTEXT in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_completion_app_completion_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test common prompt generation for completion app with completion mode.

    This test verifies:
    - Correct prompt template selection for completion app + completion mode
    - Proper context integration
    - Template structure validation
    """
    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "true")

    # Assert: completion config carries prompt text and stop words.
    assert result is not None
    assert "completion_prompt_config" in result
    assert "prompt" in result["completion_prompt_config"]
    assert "text" in result["completion_prompt_config"]["prompt"]
    assert "stop" in result

    # Verify context is included.
    prompt_text = result["completion_prompt_config"]["prompt"]["text"]
    assert CONTEXT in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_completion_app_chat_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test common prompt generation for completion app with chat mode.

    This test verifies:
    - Correct prompt template selection for completion app + chat mode
    - Proper context integration
    - Template structure validation
    """
    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")

    # Assert: chat config carries at least one role/text message.
    assert result is not None
    assert "chat_prompt_config" in result
    assert "prompt" in result["chat_prompt_config"]
    assert len(result["chat_prompt_config"]["prompt"]) > 0
    assert "role" in result["chat_prompt_config"]["prompt"][0]
    assert "text" in result["chat_prompt_config"]["prompt"][0]

    # Verify context is included in the first message.
    prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
    assert CONTEXT in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test common prompt generation without context.

    This test verifies:
    - Correct handling when has_context is "false"
    - Context is not included in prompt
    - Template structure remains intact
    """
    # Act: Execute the method under test with has_context disabled.
    result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "false")

    # Assert: template structure is intact.
    assert result is not None
    assert "completion_prompt_config" in result
    assert "prompt" in result["completion_prompt_config"]
    assert "text" in result["completion_prompt_config"]["prompt"]

    # Verify context is NOT included while placeholders remain.
    prompt_text = result["completion_prompt_config"]["prompt"]["text"]
    assert CONTEXT not in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
    assert "{{#histories#}}" in prompt_text
    assert "{{#query#}}" in prompt_text
def test_get_common_prompt_unsupported_app_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test common prompt generation with an unsupported app mode.

    This test verifies:
    - Proper handling of unsupported app modes
    - Default empty dict return
    """
    # Act: Execute the method under test with an unknown app mode.
    result = AdvancedPromptTemplateService.get_common_prompt("unsupported_mode", "completion", "true")

    # Assert: the service degrades to an empty dict rather than raising.
    assert result == {}
def test_get_common_prompt_unsupported_model_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test common prompt generation with an unsupported model mode.

    This test verifies:
    - Proper handling of unsupported model modes
    - Default empty dict return
    """
    # Act: Execute the method under test with an unknown model mode.
    result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "unsupported_mode", "true")

    # Assert: the service degrades to an empty dict rather than raising.
    assert result == {}
def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test completion prompt generation with context.

    This test verifies:
    - Proper context integration in completion prompts
    - Template structure preservation
    - Context placement at the beginning
    """
    # Deep-copy the template so mutation of the constant would be visible elsewhere.
    prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
    original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]

    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "true", CONTEXT)

    # Assert: template structure is preserved.
    assert result is not None
    assert "completion_prompt_config" in result
    assert "prompt" in result["completion_prompt_config"]
    assert "text" in result["completion_prompt_config"]["prompt"]

    # Verify context is prepended exactly once to the original text.
    result_text = result["completion_prompt_config"]["prompt"]["text"]
    assert result_text.startswith(CONTEXT)
    assert original_text in result_text
    assert result_text == CONTEXT + original_text
def test_get_completion_prompt_without_context(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test completion prompt generation without context.

    This test verifies:
    - Original template is preserved when no context
    - No modification to prompt text
    """
    # Deep-copy the template so the constant cannot be mutated by the test.
    prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
    original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]

    # Act: Execute the method under test with has_context disabled.
    result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT)

    # Assert: template structure is preserved.
    assert result is not None
    assert "completion_prompt_config" in result
    assert "prompt" in result["completion_prompt_config"]
    assert "text" in result["completion_prompt_config"]["prompt"]

    # Verify the text is unchanged and no context leaked in.
    result_text = result["completion_prompt_config"]["prompt"]["text"]
    assert result_text == original_text
    assert CONTEXT not in result_text
def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test chat prompt generation with context.

    This test verifies:
    - Proper context integration in chat prompts
    - Template structure preservation
    - Context placement at the beginning of the first message
    """
    # Deep-copy the template so the constant cannot be mutated by the test.
    prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
    original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]

    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "true", CONTEXT)

    # Assert: template structure is preserved.
    assert result is not None
    assert "chat_prompt_config" in result
    assert "prompt" in result["chat_prompt_config"]
    assert len(result["chat_prompt_config"]["prompt"]) > 0
    assert "text" in result["chat_prompt_config"]["prompt"][0]

    # Verify context is prepended exactly once to the first message.
    result_text = result["chat_prompt_config"]["prompt"][0]["text"]
    assert result_text.startswith(CONTEXT)
    assert original_text in result_text
    assert result_text == CONTEXT + original_text
def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test chat prompt generation without context.

    This test verifies:
    - Original template is preserved when no context
    - No modification to prompt text
    """
    # Deep-copy the template so the constant cannot be mutated by the test.
    prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
    original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]

    # Act: Execute the method under test with has_context disabled.
    result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT)

    # Assert: template structure is preserved.
    assert result is not None
    assert "chat_prompt_config" in result
    assert "prompt" in result["chat_prompt_config"]
    assert len(result["chat_prompt_config"]["prompt"]) > 0
    assert "text" in result["chat_prompt_config"]["prompt"][0]

    # Verify the first message is unchanged and no context leaked in.
    result_text = result["chat_prompt_config"]["prompt"][0]["text"]
    assert result_text == original_text
    assert CONTEXT not in result_text
def test_get_baichuan_prompt_chat_app_completion_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test Baichuan prompt generation for chat app with completion mode.

    This test verifies:
    - Correct Baichuan prompt template selection for chat app + completion mode
    - Proper Baichuan context integration
    - Template structure validation
    """
    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "true")

    # Assert: completion config carries prompt text, histories role, and stop words.
    assert result is not None
    assert "completion_prompt_config" in result
    assert "prompt" in result["completion_prompt_config"]
    assert "text" in result["completion_prompt_config"]["prompt"]
    assert "conversation_histories_role" in result["completion_prompt_config"]
    assert "stop" in result

    # Verify the Baichuan context is included alongside the placeholders.
    prompt_text = result["completion_prompt_config"]["prompt"]["text"]
    assert BAICHUAN_CONTEXT in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
    assert "{{#histories#}}" in prompt_text
    assert "{{#query#}}" in prompt_text
def test_get_baichuan_prompt_chat_app_chat_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test Baichuan prompt generation for chat app with chat mode.

    This test verifies:
    - Correct Baichuan prompt template selection for chat app + chat mode
    - Proper Baichuan context integration
    - Template structure validation
    """
    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")

    # Assert: chat config carries at least one role/text message.
    assert result is not None
    assert "chat_prompt_config" in result
    assert "prompt" in result["chat_prompt_config"]
    assert len(result["chat_prompt_config"]["prompt"]) > 0
    assert "role" in result["chat_prompt_config"]["prompt"][0]
    assert "text" in result["chat_prompt_config"]["prompt"][0]

    # Verify the Baichuan context is included in the first message.
    prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
    assert BAICHUAN_CONTEXT in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_completion_app_completion_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test Baichuan prompt generation for completion app with completion mode.

    This test verifies:
    - Correct Baichuan prompt template selection for completion app + completion mode
    - Proper Baichuan context integration
    - Template structure validation
    """
    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")

    # Assert: completion config carries prompt text and stop words.
    assert result is not None
    assert "completion_prompt_config" in result
    assert "prompt" in result["completion_prompt_config"]
    assert "text" in result["completion_prompt_config"]["prompt"]
    assert "stop" in result

    # Verify the Baichuan context is included.
    prompt_text = result["completion_prompt_config"]["prompt"]["text"]
    assert BAICHUAN_CONTEXT in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_completion_app_chat_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test Baichuan prompt generation for completion app with chat mode.

    This test verifies:
    - Correct Baichuan prompt template selection for completion app + chat mode
    - Proper Baichuan context integration
    - Template structure validation
    """
    # Act: Execute the method under test
    result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "true")

    # Assert: chat config carries at least one role/text message.
    assert result is not None
    assert "chat_prompt_config" in result
    assert "prompt" in result["chat_prompt_config"]
    assert len(result["chat_prompt_config"]["prompt"]) > 0
    assert "role" in result["chat_prompt_config"]["prompt"][0]
    assert "text" in result["chat_prompt_config"]["prompt"][0]

    # Verify the Baichuan context is included in the first message.
    prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
    assert BAICHUAN_CONTEXT in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test Baichuan prompt generation without context.

    This test verifies:
    - Correct handling when has_context is "false"
    - Baichuan context is not included in prompt
    - Template structure remains intact
    """
    # Act: Execute the method under test with has_context disabled.
    result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")

    # Assert: template structure is intact.
    assert result is not None
    assert "completion_prompt_config" in result
    assert "prompt" in result["completion_prompt_config"]
    assert "text" in result["completion_prompt_config"]["prompt"]

    # Verify the Baichuan context is NOT included while placeholders remain.
    prompt_text = result["completion_prompt_config"]["prompt"]["text"]
    assert BAICHUAN_CONTEXT not in prompt_text
    assert "{{#pre_prompt#}}" in prompt_text
    assert "{{#histories#}}" in prompt_text
    assert "{{#query#}}" in prompt_text
def test_get_baichuan_prompt_unsupported_app_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test Baichuan prompt generation with an unsupported app mode.

    This test verifies:
    - Proper handling of unsupported app modes
    - Default empty dict return
    """
    # Act: Execute the method under test with an unknown app mode.
    result = AdvancedPromptTemplateService.get_baichuan_prompt("unsupported_mode", "completion", "true")

    # Assert: the service degrades to an empty dict rather than raising.
    assert result == {}
def test_get_baichuan_prompt_unsupported_model_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test Baichuan prompt generation with an unsupported model mode.

    This test verifies:
    - Proper handling of unsupported model modes
    - Default empty dict return
    """
    # Act: Execute the method under test with an unknown model mode.
    result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "unsupported_mode", "true")

    # Assert: the service degrades to an empty dict rather than raising.
    assert result == {}
def test_get_prompt_all_app_modes_common_model(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test prompt generation for all app/model mode combinations with a common model.

    This test verifies:
    - All app modes work correctly with common models
    - Proper template selection for each combination
    """
    # Exercise the full cross product of supported app and model modes.
    app_modes = [AppMode.CHAT, AppMode.COMPLETION]
    model_modes = ["completion", "chat"]

    for app_mode in app_modes:
        for model_mode in model_modes:
            args = {
                "app_mode": app_mode,
                "model_mode": model_mode,
                "model_name": "gpt-3.5-turbo",
                "has_context": "true",
            }

            # Act: Execute the method under test
            result = AdvancedPromptTemplateService.get_prompt(args)

            # Assert: every supported combination yields a non-empty template.
            assert result is not None
            assert result != {}
def test_get_prompt_all_app_modes_baichuan_model(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test prompt generation for all app/model mode combinations with a Baichuan model.

    This test verifies:
    - All app modes work correctly with Baichuan models
    - Proper template selection for each combination
    """
    # Exercise the full cross product of supported app and model modes.
    app_modes = [AppMode.CHAT, AppMode.COMPLETION]
    model_modes = ["completion", "chat"]

    for app_mode in app_modes:
        for model_mode in model_modes:
            args = {
                "app_mode": app_mode,
                "model_mode": model_mode,
                "model_name": "baichuan-13b-chat",
                "has_context": "true",
            }

            # Act: Execute the method under test
            result = AdvancedPromptTemplateService.get_prompt(args)

            # Assert: every supported combination yields a non-empty template.
            assert result is not None
            assert result != {}
def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test prompt generation with edge-case inputs.

    This test verifies:
    - Handling of edge case inputs (empty fields)
    - Proper error handling
    - Consistent behavior with unusual inputs
    """
    # Each case blanks out one field of the request.
    edge_cases = [
        {"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"},
        {"app_mode": AppMode.CHAT, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
        {"app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "", "has_context": "true"},
        {
            "app_mode": AppMode.CHAT,
            "model_mode": "completion",
            "model_name": "gpt-3.5-turbo",
            "has_context": "",
        },
    ]

    for args in edge_cases:
        # Act: Execute the method under test
        result = AdvancedPromptTemplateService.get_prompt(args)

        # Assert: the service must not crash; a valid result or an empty dict
        # are both acceptable for degenerate input.
        assert result is not None
def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test that the original template constants are not modified.

    This test verifies:
    - Original template constants are not modified
    - Deep copy is used properly by the service
    - Template immutability is maintained
    """
    # Snapshot the module-level constants before exercising the service.
    original_chat_completion = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
    original_chat_chat = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
    original_completion_completion = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
    original_completion_chat = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG)

    args = {
        "app_mode": AppMode.CHAT,
        "model_mode": "completion",
        "model_name": "gpt-3.5-turbo",
        "has_context": "true",
    }

    # Act: the return value is irrelevant here — only potential mutation of the
    # constants matters.
    AdvancedPromptTemplateService.get_prompt(args)

    # Assert: the constants are byte-for-byte equal to their snapshots.
    assert original_chat_completion == CHAT_APP_COMPLETION_PROMPT_CONFIG
    assert original_chat_chat == CHAT_APP_CHAT_PROMPT_CONFIG
    assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
    assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test that the original Baichuan template constants are not modified.

    This test verifies:
    - Original Baichuan template constants are not modified
    - Deep copy is used properly by the service
    - Template immutability is maintained
    """
    # Snapshot the module-level constants before exercising the service.
    original_baichuan_chat_completion = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG)
    original_baichuan_chat_chat = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG)
    original_baichuan_completion_completion = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
    original_baichuan_completion_chat = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG)

    args = {
        "app_mode": AppMode.CHAT,
        "model_mode": "completion",
        "model_name": "baichuan-13b-chat",
        "has_context": "true",
    }

    # Act: the return value is irrelevant here — only potential mutation of the
    # constants matters.
    AdvancedPromptTemplateService.get_prompt(args)

    # Assert: the constants are byte-for-byte equal to their snapshots.
    assert original_baichuan_chat_completion == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
    assert original_baichuan_chat_chat == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
    assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
    assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test consistency of context integration across different scenarios.

    This test verifies:
    - Context is always prepended correctly
    - Context integration is consistent across different templates
    - No context duplication or corruption
    """
    # All four supported app/model mode combinations for a common model.
    test_scenarios = [
        {
            "app_mode": AppMode.CHAT,
            "model_mode": "completion",
            "model_name": "gpt-3.5-turbo",
            "has_context": "true",
        },
        {
            "app_mode": AppMode.CHAT,
            "model_mode": "chat",
            "model_name": "gpt-3.5-turbo",
            "has_context": "true",
        },
        {
            "app_mode": AppMode.COMPLETION,
            "model_mode": "completion",
            "model_name": "gpt-3.5-turbo",
            "has_context": "true",
        },
        {
            "app_mode": AppMode.COMPLETION,
            "model_mode": "chat",
            "model_name": "gpt-3.5-turbo",
            "has_context": "true",
        },
    ]

    for args in test_scenarios:
        # Act: Execute the method under test
        result = AdvancedPromptTemplateService.get_prompt(args)

        # Assert: every scenario produces a non-empty template.
        assert result is not None
        assert result != {}

        # Whichever template shape comes back, the context must lead the text.
        if "completion_prompt_config" in result:
            prompt_text = result["completion_prompt_config"]["prompt"]["text"]
            assert prompt_text.startswith(CONTEXT)
        elif "chat_prompt_config" in result:
            prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
            assert prompt_text.startswith(CONTEXT)
def test_baichuan_context_integration_consistency(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test consistency of Baichuan context integration across different scenarios.

    This test verifies:
    - Baichuan context is always prepended correctly
    - Context integration is consistent across different templates
    - No context duplication or corruption
    """
    # All four supported app/model mode combinations for a Baichuan model.
    test_scenarios = [
        {
            "app_mode": AppMode.CHAT,
            "model_mode": "completion",
            "model_name": "baichuan-13b-chat",
            "has_context": "true",
        },
        {
            "app_mode": AppMode.CHAT,
            "model_mode": "chat",
            "model_name": "baichuan-13b-chat",
            "has_context": "true",
        },
        {
            "app_mode": AppMode.COMPLETION,
            "model_mode": "completion",
            "model_name": "baichuan-13b-chat",
            "has_context": "true",
        },
        {
            "app_mode": AppMode.COMPLETION,
            "model_mode": "chat",
            "model_name": "baichuan-13b-chat",
            "has_context": "true",
        },
    ]

    for args in test_scenarios:
        # Act: Execute the method under test
        result = AdvancedPromptTemplateService.get_prompt(args)

        # Assert: every scenario produces a non-empty template.
        assert result is not None
        assert result != {}

        # Whichever template shape comes back, the Baichuan context must lead the text.
        if "completion_prompt_config" in result:
            prompt_text = result["completion_prompt_config"]["prompt"]["text"]
            assert prompt_text.startswith(BAICHUAN_CONTEXT)
        elif "chat_prompt_config" in result:
            prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
            assert prompt_text.startswith(BAICHUAN_CONTEXT)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,505 @@
from unittest.mock import patch
import pytest
from faker import Faker
from models.api_based_extension import APIBasedExtension
from services.account_service import AccountService, TenantService
from services.api_based_extension_service import APIBasedExtensionService
class TestAPIBasedExtensionService:
    """Integration tests for APIBasedExtensionService using testcontainers."""

    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies."""
        with (
            patch("services.account_service.FeatureService") as mock_account_feature_service,
            patch("services.api_based_extension_service.APIBasedExtensionRequestor") as mock_requestor,
        ):
            # Billing is disabled by default so account creation is unrestricted.
            mock_account_feature_service.get_features.return_value.billing.enabled = False

            # The extension connectivity check ("ping") succeeds by default.
            mock_requestor_instance = mock_requestor.return_value
            mock_requestor_instance.request.return_value = {"result": "pong"}

            yield {
                "account_feature_service": mock_account_feature_service,
                "requestor": mock_requestor,
                "requestor_instance": mock_requestor_instance,
            }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Helper method to create a test account and tenant for testing.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies

    Returns:
        tuple: (account, tenant) - Created account and tenant instances
    """
    fake = Faker()

    # Allow self-registration so AccountService.create_account succeeds.
    feature_service_mock = mock_external_service_dependencies["account_feature_service"]
    feature_service_mock.get_system_features.return_value.is_allow_register = True

    # Create the account, then attach an owner tenant to it.
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())

    return account, account.current_tenant
def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful saving of an API-based extension.
    """
    fake = Faker()
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    assert tenant is not None

    # Build the extension to persist.
    extension_data = APIBasedExtension(
        tenant_id=tenant.id,
        name=fake.company(),
        api_endpoint=f"https://{fake.domain_name()}/api",
        api_key=fake.password(length=20),
    )

    saved_extension = APIBasedExtensionService.save(extension_data)

    # The saved record reflects the input and has DB-generated fields populated.
    assert saved_extension.id is not None
    assert saved_extension.tenant_id == tenant.id
    assert saved_extension.name == extension_data.name
    assert saved_extension.api_endpoint == extension_data.api_endpoint
    assert saved_extension.api_key == extension_data.api_key  # Should be decrypted when retrieved
    assert saved_extension.created_at is not None

    # Refreshing from the database confirms the row was actually persisted.
    from extensions.ext_database import db

    db.session.refresh(saved_extension)
    assert saved_extension.id is not None

    # Saving must have pinged the extension endpoint exactly once.
    mock_external_service_dependencies["requestor_instance"].request.assert_called_once()
def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test validation errors when saving extension with invalid data.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Test empty name
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name="",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="name must not be empty"):
APIBasedExtensionService.save(extension_data)
# Test empty api_endpoint
extension_data.name = fake.company()
extension_data.api_endpoint = ""
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
APIBasedExtensionService.save(extension_data)
# Test empty api_key
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = ""
with pytest.raises(ValueError, match="api_key must not be empty"):
APIBasedExtensionService.save(extension_data)
    def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful retrieval of all extensions by tenant ID.

        Creates three extensions and checks they all come back for the
        tenant, with api_key populated and ordered newest-first.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Create multiple extensions
        extensions = []
        assert tenant is not None
        for i in range(3):
            extension_data = APIBasedExtension(
                tenant_id=tenant.id,
                name=f"Extension {i}: {fake.company()}",
                api_endpoint=f"https://{fake.domain_name()}/api",
                api_key=fake.password(length=20),
            )
            saved_extension = APIBasedExtensionService.save(extension_data)
            extensions.append(saved_extension)
        # Get all extensions for tenant
        extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
        # Verify results
        assert len(extension_list) == 3
        # Verify all extensions belong to the correct tenant and are ordered by created_at desc
        for i, extension in enumerate(extension_list):
            assert extension.tenant_id == tenant.id
            assert extension.api_key is not None  # Should be decrypted
            if i > 0:
                # Verify descending order (newer first)
                assert extension.created_at <= extension_list[i - 1].created_at
    def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful retrieval of extension by tenant ID and extension ID.

        Saves one extension, fetches it back by (tenant_id, id) and checks
        every field survived the round trip.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        assert tenant is not None
        # Create an extension
        extension_data = APIBasedExtension(
            tenant_id=tenant.id,
            name=fake.company(),
            api_endpoint=f"https://{fake.domain_name()}/api",
            api_key=fake.password(length=20),
        )
        created_extension = APIBasedExtensionService.save(extension_data)
        # Get extension by ID
        retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
        # Verify extension was retrieved correctly
        assert retrieved_extension is not None
        assert retrieved_extension.id == created_extension.id
        assert retrieved_extension.tenant_id == tenant.id
        assert retrieved_extension.name == extension_data.name
        assert retrieved_extension.api_endpoint == extension_data.api_endpoint
        assert retrieved_extension.api_key == extension_data.api_key  # Should be decrypted
        assert retrieved_extension.created_at is not None
def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test retrieval of extension when extension is not found.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
non_existent_extension_id = fake.uuid4()
# Try to get non-existent extension
with pytest.raises(ValueError, match="API based extension is not found"):
APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id)
    def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful deletion of extension.

        Creates an extension, deletes it through the service, then queries
        the table directly to prove the row is gone.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        assert tenant is not None
        # Create an extension first
        extension_data = APIBasedExtension(
            tenant_id=tenant.id,
            name=fake.company(),
            api_endpoint=f"https://{fake.domain_name()}/api",
            api_key=fake.password(length=20),
        )
        created_extension = APIBasedExtensionService.save(extension_data)
        # Remember the id before the ORM object is deleted.
        extension_id = created_extension.id
        # Delete the extension
        APIBasedExtensionService.delete(created_extension)
        # Verify extension was deleted by querying the table directly
        from extensions.ext_database import db

        deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first()
        assert deleted_extension is None
def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test validation error when saving extension with duplicate name.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Create first extension
extension_data1 = APIBasedExtension(
tenant_id=tenant.id,
name="Test Extension",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
APIBasedExtensionService.save(extension_data1)
# Try to create second extension with same name
extension_data2 = APIBasedExtension(
tenant_id=tenant.id,
name="Test Extension", # Same name
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
APIBasedExtensionService.save(extension_data2)
def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful update of existing extension.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Create initial extension
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
created_extension = APIBasedExtensionService.save(extension_data)
# Save original values for later comparison
original_name = created_extension.name
original_endpoint = created_extension.api_endpoint
# Update the extension with guaranteed different values
new_name = fake.company()
# Ensure new endpoint is different from original
new_endpoint = f"https://{fake.domain_name()}/api"
# If by chance they're the same, generate a new one
while new_endpoint == original_endpoint:
new_endpoint = f"https://{fake.domain_name()}/api"
new_api_key = fake.password(length=20)
created_extension.name = new_name
created_extension.api_endpoint = new_endpoint
created_extension.api_key = new_api_key
updated_extension = APIBasedExtensionService.save(created_extension)
# Verify extension was updated correctly
assert updated_extension.id == created_extension.id
assert updated_extension.tenant_id == tenant.id
assert updated_extension.name == new_name
assert updated_extension.api_endpoint == new_endpoint
# Verify original values were changed
assert updated_extension.name != original_name
assert updated_extension.api_endpoint != original_endpoint
# Verify ping connection was called for both create and update
assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2
# Verify the update by retrieving the extension again
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
assert retrieved_extension.name == new_name
assert retrieved_extension.api_endpoint == new_endpoint
assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved
    def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test connection error when saving extension with invalid endpoint.

        The mocked requestor raises on ping, so save() must propagate the
        connection error instead of persisting the record.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Mock connection error raised by the ping request
        mock_external_service_dependencies["requestor_instance"].request.side_effect = ValueError(
            "connection error: request timeout"
        )
        assert tenant is not None
        # Setup extension data
        extension_data = APIBasedExtension(
            tenant_id=tenant.id,
            name=fake.company(),
            api_endpoint="https://invalid-endpoint.com/api",
            api_key=fake.password(length=20),
        )
        # Try to save extension with connection error
        with pytest.raises(ValueError, match="connection error: request timeout"):
            APIBasedExtensionService.save(extension_data)
def test_save_extension_invalid_api_key_length(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test validation error when saving extension with API key that is too short.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
assert tenant is not None
# Setup extension data with short API key
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key="1234", # Less than 5 characters
)
# Try to save extension with short API key
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
APIBasedExtensionService.save(extension_data)
    def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test validation errors when saving extension with empty required fields.

        Same checks as the empty-string test, but with None values, to prove
        the service treats None and "" alike for name/api_endpoint/api_key.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        assert tenant is not None
        # Test with None values
        extension_data = APIBasedExtension(
            tenant_id=tenant.id,
            name=None,  # type: ignore  # deliberately None to exercise the empty-name validation
            api_endpoint=f"https://{fake.domain_name()}/api",
            api_key=fake.password(length=20),
        )
        with pytest.raises(ValueError, match="name must not be empty"):
            APIBasedExtensionService.save(extension_data)
        # Test with None api_endpoint
        extension_data.name = fake.company()
        extension_data.api_endpoint = None
        with pytest.raises(ValueError, match="api_endpoint must not be empty"):
            APIBasedExtensionService.save(extension_data)
        # Test with None api_key
        extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
        extension_data.api_key = None
        with pytest.raises(ValueError, match="api_key must not be empty"):
            APIBasedExtensionService.save(extension_data)
def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test retrieval of extensions when no extensions exist for tenant.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Get all extensions for tenant (none exist)
extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
# Verify empty list is returned
assert len(extension_list) == 0
assert extension_list == []
    def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test validation error when ping response is invalid.

        The endpoint answers, but with {"result": "invalid"} instead of the
        expected success marker, so save() must raise with the raw payload.
        """
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Mock invalid ping response
        mock_external_service_dependencies["requestor_instance"].request.return_value = {"result": "invalid"}
        assert tenant is not None
        # Setup extension data
        extension_data = APIBasedExtension(
            tenant_id=tenant.id,
            name=fake.company(),
            api_endpoint=f"https://{fake.domain_name()}/api",
            api_key=fake.password(length=20),
        )
        # Try to save extension with invalid ping response; the error message
        # echoes the bad payload verbatim.
        with pytest.raises(ValueError, match="{'result': 'invalid'}"):
            APIBasedExtensionService.save(extension_data)
def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test validation error when ping response is missing result field.
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Mock ping response without result field
mock_external_service_dependencies["requestor_instance"].request.return_value = {"status": "ok"}
assert tenant is not None
# Setup extension data
extension_data = APIBasedExtension(
tenant_id=tenant.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
# Try to save extension with missing ping result
with pytest.raises(ValueError, match="{'status': 'ok'}"):
APIBasedExtensionService.save(extension_data)
    def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test retrieval of extension when tenant ID doesn't match.

        An extension created under tenant1 must not be visible when queried
        with tenant2's id — tenant isolation.
        """
        fake = Faker()
        account1, tenant1 = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Create second account and tenant
        account2, tenant2 = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        assert tenant1 is not None
        # Create extension in first tenant
        extension_data = APIBasedExtension(
            tenant_id=tenant1.id,
            name=fake.company(),
            api_endpoint=f"https://{fake.domain_name()}/api",
            api_key=fake.password(length=20),
        )
        created_extension = APIBasedExtensionService.save(extension_data)
        # Try to get extension with wrong tenant ID
        with pytest.raises(ValueError, match="API based extension is not found"):
            APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)

View File

@@ -0,0 +1,432 @@
import json
from unittest.mock import MagicMock, patch
import pytest
import yaml
from faker import Faker
from models.model import App, AppModelConfig
from services.account_service import AccountService, TenantService
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
from services.app_service import AppService
class TestAppDslService:
"""Integration tests for AppDslService using testcontainers."""
    @pytest.fixture
    def mock_external_service_dependencies(self):
        """
        Mock setup for external service dependencies.

        Patches every collaborator that AppDslService/AppService reach out
        to (workflow service, dependency analysis, draft variables, ssrf
        proxy, redis, app signals, model manager, feature/enterprise
        services) so tests run without network, redis, or model providers.
        Yields a dict of the patch objects keyed by a short name.
        """
        with (
            patch("services.app_dsl_service.WorkflowService") as mock_workflow_service,
            patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service,
            patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service,
            patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy,
            patch("services.app_dsl_service.redis_client") as mock_redis_client,
            patch("services.app_dsl_service.app_was_created") as mock_app_was_created,
            patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated,
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.app_service.FeatureService") as mock_feature_service,
            patch("services.app_service.EnterpriseService") as mock_enterprise_service,
        ):
            # Setup default mock returns
            mock_workflow_service.return_value.get_draft_workflow.return_value = None
            mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock()
            mock_dependencies_service.generate_latest_dependencies.return_value = []
            mock_dependencies_service.get_leaked_dependencies.return_value = []
            mock_dependencies_service.generate_dependencies.return_value = []
            mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None
            mock_ssrf_proxy.get.return_value.content = b"test content"
            mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None
            mock_redis_client.setex.return_value = None
            mock_redis_client.get.return_value = None
            mock_redis_client.delete.return_value = None
            mock_app_was_created.send.return_value = None
            mock_app_model_config_was_updated.send.return_value = None
            # Mock ModelManager for app service
            mock_model_instance = mock_model_manager.return_value
            mock_model_instance.get_default_model_instance.return_value = None
            mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
            # Mock FeatureService and EnterpriseService
            mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
            mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
            mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
            yield {
                "workflow_service": mock_workflow_service,
                "dependencies_service": mock_dependencies_service,
                "draft_variable_service": mock_draft_variable_service,
                "ssrf_proxy": mock_ssrf_proxy,
                "redis_client": mock_redis_client,
                "app_was_created": mock_app_was_created,
                "app_model_config_was_updated": mock_app_model_config_was_updated,
                "model_manager": mock_model_manager,
                "feature_service": mock_feature_service,
                "enterprise_service": mock_enterprise_service,
            }
    def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test app and account for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies
        Returns:
            tuple: (app, account) - Created app and account instances
        """
        fake = Faker()
        # Setup mocks for account creation; registration must be allowed for
        # AccountService.create_account to succeed.
        with patch("services.account_service.FeatureService") as mock_account_feature_service:
            mock_account_feature_service.get_system_features.return_value.is_allow_register = True
            # Create account and tenant first
            account = AccountService.create_account(
                email=fake.email(),
                name=fake.name(),
                interface_language="en-US",
                password=fake.password(length=12),
            )
            TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
            tenant = account.current_tenant
            # Setup app creation arguments
            app_args = {
                "name": fake.company(),
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            }
            # Create app (still inside the patch so FeatureService stays mocked)
            app_service = AppService()
            app = app_service.create_app(tenant.id, app_args, account)
            return app, account
def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"):
"""
Helper method to create simple YAML content for testing.
"""
yaml_data = {
"version": "0.3.0",
"kind": "app",
"app": {
"name": app_name,
"mode": app_mode,
"icon": "🤖",
"icon_background": "#FFEAD5",
"description": "Test app description",
"use_icon_as_answer_icon": False,
},
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 1000,
"temperature": 0.7,
"top_p": 1.0,
},
},
"pre_prompt": "You are a helpful assistant.",
"prompt_type": "simple",
},
}
return yaml.dump(yaml_data, allow_unicode=True)
def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app import with missing YAML content.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Import app without YAML content
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT,
name="Missing Content App",
)
# Verify import failed
assert result.status == ImportStatus.FAILED
assert result.app_id is None
assert "yaml_content is required" in result.error
assert result.imported_dsl_version == ""
# Verify no app was created in database
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
assert apps_count == 1 # Only the original test app
def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app import with missing YAML URL.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Import app without YAML URL
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_URL,
name="Missing URL App",
)
# Verify import failed
assert result.status == ImportStatus.FAILED
assert result.app_id is None
assert "yaml_url is required" in result.error
assert result.imported_dsl_version == ""
# Verify no app was created in database
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
assert apps_count == 1 # Only the original test app
    def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test app import with invalid import mode.

        An unknown import_mode must raise ValueError before any app rows
        are written.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create YAML content
        yaml_content = self._create_simple_yaml_content(fake.company(), "chat")
        # Import app with invalid mode should raise ValueError
        dsl_service = AppDslService(db_session_with_containers)
        with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"):
            dsl_service.import_app(
                account=account,
                import_mode="invalid-mode",
                yaml_content=yaml_content,
                name="Invalid Mode App",
            )
        # Verify no app was created in database
        apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
        assert apps_count == 1  # Only the original test app
    def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful DSL export for chat app.

        Builds an AppModelConfig, links it to the app, exports the DSL and
        checks the YAML structure (app metadata, model_config, dependencies).
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create model config for the app
        model_config = AppModelConfig()
        model_config.id = fake.uuid4()
        model_config.app_id = app.id
        model_config.provider = "openai"
        model_config.model_id = "gpt-3.5-turbo"
        model_config.model = json.dumps(
            {
                "provider": "openai",
                "name": "gpt-3.5-turbo",
                "mode": "chat",
                "completion_params": {
                    "max_tokens": 1000,
                    "temperature": 0.7,
                },
            }
        )
        model_config.pre_prompt = "You are a helpful assistant."
        model_config.prompt_type = "simple"
        model_config.created_by = account.id
        model_config.updated_by = account.id
        # Set the app_model_config_id to link the config
        app.app_model_config_id = model_config.id
        db_session_with_containers.add(model_config)
        db_session_with_containers.commit()
        # Export DSL
        exported_dsl = AppDslService.export_dsl(app, include_secret=False)
        # Parse exported YAML
        exported_data = yaml.safe_load(exported_dsl)
        # Verify exported data structure
        assert exported_data["kind"] == "app"
        assert exported_data["app"]["name"] == app.name
        assert exported_data["app"]["mode"] == app.mode
        assert exported_data["app"]["icon"] == app.icon
        assert exported_data["app"]["icon_background"] == app.icon_background
        assert exported_data["app"]["description"] == app.description
        # Verify model config was exported
        assert "model_config" in exported_data
        # The exported model_config structure may be different from the database structure
        # Check that the model config exists and has the expected content
        assert exported_data["model_config"] is not None
        # Verify dependencies were exported
        assert "dependencies" in exported_data
        assert isinstance(exported_data["dependencies"], list)
def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful DSL export for workflow app.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Update app to workflow mode
app.mode = "workflow"
db_session_with_containers.commit()
# Mock workflow service to return a workflow
mock_workflow = MagicMock()
mock_workflow.to_dict.return_value = {
"graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []},
"features": {},
"environment_variables": [],
"conversation_variables": [],
}
mock_external_service_dependencies[
"workflow_service"
].return_value.get_draft_workflow.return_value = mock_workflow
# Export DSL
exported_dsl = AppDslService.export_dsl(app, include_secret=False)
# Parse exported YAML
exported_data = yaml.safe_load(exported_dsl)
# Verify exported data structure
assert exported_data["kind"] == "app"
assert exported_data["app"]["name"] == app.name
assert exported_data["app"]["mode"] == "workflow"
# Verify workflow was exported
assert "workflow" in exported_data
assert "graph" in exported_data["workflow"]
assert "nodes" in exported_data["workflow"]["graph"]
# Verify dependencies were exported
assert "dependencies" in exported_data
assert isinstance(exported_data["dependencies"], list)
# Verify workflow service was called
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
app, None
)
def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful DSL export with specific workflow ID.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Update app to workflow mode
app.mode = "workflow"
db_session_with_containers.commit()
# Mock workflow service to return a workflow when specific workflow_id is provided
mock_workflow = MagicMock()
mock_workflow.to_dict.return_value = {
"graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []},
"features": {},
"environment_variables": [],
"conversation_variables": [],
}
# Mock the get_draft_workflow method to return different workflows based on workflow_id
def mock_get_draft_workflow(app_model, workflow_id=None):
if workflow_id == "specific-workflow-id":
return mock_workflow
return None
mock_external_service_dependencies[
"workflow_service"
].return_value.get_draft_workflow.side_effect = mock_get_draft_workflow
# Export DSL with specific workflow ID
exported_dsl = AppDslService.export_dsl(app, include_secret=False, workflow_id="specific-workflow-id")
# Parse exported YAML
exported_data = yaml.safe_load(exported_dsl)
# Verify exported data structure
assert exported_data["kind"] == "app"
assert exported_data["app"]["name"] == app.name
assert exported_data["app"]["mode"] == "workflow"
# Verify workflow was exported
assert "workflow" in exported_data
assert "graph" in exported_data["workflow"]
assert "nodes" in exported_data["workflow"]["graph"]
# Verify dependencies were exported
assert "dependencies" in exported_data
assert isinstance(exported_data["dependencies"], list)
# Verify workflow service was called with specific workflow ID
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
app, "specific-workflow-id"
)
def test_export_dsl_with_invalid_workflow_id_raises_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that export_dsl raises error when invalid workflow ID is provided.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Update app to workflow mode
app.mode = "workflow"
db_session_with_containers.commit()
# Mock workflow service to return None when invalid workflow ID is provided
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.return_value = None
# Export DSL with invalid workflow ID should raise ValueError
with pytest.raises(ValueError, match="Missing draft workflow configuration, please check."):
AppDslService.export_dsl(app, include_secret=False, workflow_id="invalid-workflow-id")
# Verify workflow service was called with the invalid workflow ID
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
app, "invalid-workflow-id"
)
def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful dependency checking.
"""
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Mock Redis to return dependencies
mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}'
mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json
# Check dependencies
dsl_service = AppDslService(db_session_with_containers)
result = dsl_service.check_dependencies(app_model=app)
# Verify result
assert result.leaked_dependencies == []
# Verify Redis was queried
mock_external_service_dependencies["redis_client"].get.assert_called_once_with(
f"app_check_dependencies:{app.id}"
)
# Verify dependencies service was called
mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once()

View File

@@ -0,0 +1,982 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.app.entities.app_invoke_entities import InvokeFrom
from models.model import EndUser
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
class TestAppGenerateService:
"""Integration tests for AppGenerateService using testcontainers."""
    @pytest.fixture
    def mock_external_service_dependencies(self):
        """
        Mock setup for external service dependencies.

        Patches billing, workflow lookup, rate limiting, every app-mode
        generator, the account feature service, and dify_config (both the
        service-local and the global import) so AppGenerateService can run
        without providers, redis, or billing. Yields the patch objects
        keyed by short names.
        """
        with (
            patch("services.billing_service.BillingService") as mock_billing_service,
            patch("services.app_generate_service.WorkflowService") as mock_workflow_service,
            patch("services.app_generate_service.RateLimit") as mock_rate_limit,
            patch("services.app_generate_service.CompletionAppGenerator") as mock_completion_generator,
            patch("services.app_generate_service.ChatAppGenerator") as mock_chat_generator,
            patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator,
            patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator,
            patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator,
            patch("services.account_service.FeatureService") as mock_account_feature_service,
            patch("services.app_generate_service.dify_config") as mock_dify_config,
            patch("configs.dify_config") as mock_global_dify_config,
        ):
            # Setup default mock returns for billing service
            mock_billing_service.update_tenant_feature_plan_usage.return_value = {
                "result": "success",
                "history_id": "test_history_id",
            }
            # Setup default mock returns for workflow service
            mock_workflow_service_instance = mock_workflow_service.return_value
            mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow)
            mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow)
            mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow)
            # Setup default mock returns for rate limiting
            mock_rate_limit_instance = mock_rate_limit.return_value
            mock_rate_limit_instance.enter.return_value = "test_request_id"
            mock_rate_limit_instance.generate.return_value = ["test_response"]
            mock_rate_limit_instance.exit.return_value = None
            # Setup default mock returns for app generators (one per app mode)
            mock_completion_generator_instance = mock_completion_generator.return_value
            mock_completion_generator_instance.generate.return_value = ["completion_response"]
            mock_completion_generator_instance.generate_more_like_this.return_value = ["more_like_this_response"]
            mock_completion_generator.convert_to_event_stream.return_value = ["completion_stream"]
            mock_chat_generator_instance = mock_chat_generator.return_value
            mock_chat_generator_instance.generate.return_value = ["chat_response"]
            mock_chat_generator.convert_to_event_stream.return_value = ["chat_stream"]
            mock_agent_chat_generator_instance = mock_agent_chat_generator.return_value
            mock_agent_chat_generator_instance.generate.return_value = ["agent_chat_response"]
            mock_agent_chat_generator.convert_to_event_stream.return_value = ["agent_chat_stream"]
            mock_advanced_chat_generator_instance = mock_advanced_chat_generator.return_value
            mock_advanced_chat_generator_instance.generate.return_value = ["advanced_chat_response"]
            mock_advanced_chat_generator_instance.single_iteration_generate.return_value = ["single_iteration_response"]
            mock_advanced_chat_generator_instance.single_loop_generate.return_value = ["single_loop_response"]
            mock_advanced_chat_generator.convert_to_event_stream.return_value = ["advanced_chat_stream"]
            mock_workflow_generator_instance = mock_workflow_generator.return_value
            mock_workflow_generator_instance.generate.return_value = ["workflow_response"]
            mock_workflow_generator_instance.single_iteration_generate.return_value = [
                "workflow_single_iteration_response"
            ]
            mock_workflow_generator_instance.single_loop_generate.return_value = ["workflow_single_loop_response"]
            mock_workflow_generator.convert_to_event_stream.return_value = ["workflow_stream"]
            # Setup default mock returns for account service
            mock_account_feature_service.get_system_features.return_value.is_allow_register = True
            # Setup dify_config mock returns (billing off, generous limits)
            mock_dify_config.BILLING_ENABLED = False
            mock_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
            mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
            mock_global_dify_config.BILLING_ENABLED = False
            mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
            mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
            yield {
                "billing_service": mock_billing_service,
                "workflow_service": mock_workflow_service,
                "rate_limit": mock_rate_limit,
                "completion_generator": mock_completion_generator,
                "chat_generator": mock_chat_generator,
                "agent_chat_generator": mock_agent_chat_generator,
                "advanced_chat_generator": mock_advanced_chat_generator,
                "workflow_generator": mock_workflow_generator,
                "account_feature_service": mock_account_feature_service,
                "dify_config": mock_dify_config,
                "global_dify_config": mock_global_dify_config,
            }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies, mode="chat"):
    """
    Build an (app, account) pair for a generation test.

    Creates a fresh account with its owner tenant, then creates an app of
    the requested mode inside that tenant via AppService.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies
        mode: App mode to create

    Returns:
        tuple: (app, account) - Created app and account instances
    """
    faker = Faker()
    # Registration must be allowed for AccountService.create_account to succeed.
    feature_mock = mock_external_service_dependencies["account_feature_service"]
    feature_mock.get_system_features.return_value.is_allow_register = True
    from services.account_service import AccountService, TenantService

    account = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=faker.company())
    from services.app_service import AppService

    # Realistic app payload; rate limits are small so limit tests stay cheap.
    creation_args = {
        "name": faker.company(),
        "description": faker.text(max_nb_chars=100),
        "mode": mode,
        "icon_type": "emoji",
        "icon": "🤖",
        "icon_background": "#FF6B6B",
        "api_rph": 100,
        "api_rpm": 10,
        "max_active_requests": 5,
    }
    app = AppService().create_app(account.current_tenant.id, creation_args, account)
    return app, account
def _create_test_workflow(self, db_session_with_containers, app):
    """
    Persist a published workflow row attached to *app*.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        app: App instance

    Returns:
        Workflow: Created workflow instance
    """
    faker = Faker()
    from extensions.ext_database import db

    record = Workflow(
        id=str(uuid.uuid4()),
        app_id=app.id,
        type="workflow",
        status="published",
        name=faker.company(),
        description=faker.text(max_nb_chars=100),
    )
    db.session.add(record)
    db.session.commit()
    return record
def test_generate_completion_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Generation for a completion-mode app routes to CompletionAppGenerator,
    passes through rate limiting, and returns the mocked event stream.
    """
    faker = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    request_args = {"inputs": {"query": faker.text(max_nb_chars=50)}, "response_mode": "streaming"}

    outcome = AppGenerateService.generate(
        app_model=app, user=account, args=request_args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )

    assert outcome == ["test_response"]
    # Rate limiting must wrap every generation call.
    rate_limit = mock_external_service_dependencies["rate_limit"].return_value
    rate_limit.enter.assert_called_once()
    rate_limit.generate.assert_called_once()
    # Completion-mode apps must be served by the completion generator.
    completion = mock_external_service_dependencies["completion_generator"]
    completion.return_value.generate.assert_called_once()
    completion.convert_to_event_stream.assert_called_once()
def test_generate_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful generation for chat mode app.

    Verifies that a "chat" mode app is routed to the mocked ChatAppGenerator
    and that its converted event stream is returned unchanged.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="chat"
    )
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
    # Execute the method under test
    result = AppGenerateService.generate(
        app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    # Verify the result (stream payload configured in the shared fixture)
    assert result == ["test_response"]
    # Verify chat generator was called
    mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once()
    mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once()
def test_generate_agent_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful generation for agent chat mode app.

    Verifies that an "agent-chat" mode app is routed to the mocked
    AgentChatAppGenerator rather than the plain chat generator.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="agent-chat"
    )
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
    # Execute the method under test
    result = AppGenerateService.generate(
        app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    # Verify the result
    assert result == ["test_response"]
    # Verify agent chat generator was called
    mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once()
    mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once()
def test_generate_advanced_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful generation for advanced chat mode app.

    Verifies that an "advanced-chat" mode app is routed to the mocked
    AdvancedChatAppGenerator (which requires a published workflow; the
    workflow service is mocked in the fixture).
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
    )
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
    # Execute the method under test
    result = AppGenerateService.generate(
        app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    # Verify the result
    assert result == ["test_response"]
    # Verify advanced chat generator was called
    mock_external_service_dependencies["advanced_chat_generator"].return_value.generate.assert_called_once()
    mock_external_service_dependencies["advanced_chat_generator"].convert_to_event_stream.assert_called_once()
def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful generation for workflow mode app.

    Verifies that a "workflow" mode app is routed to the mocked
    WorkflowAppGenerator and its converted stream is returned.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="workflow"
    )
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
    # Execute the method under test
    result = AppGenerateService.generate(
        app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    # Verify the result
    assert result == ["test_response"]
    # Verify workflow generator was called
    mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once()
    mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once()
def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test generation with a specific workflow ID.

    When args carry a "workflow_id", the service must fetch that exact
    published workflow instead of the app's default one.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
    )
    workflow_id = str(uuid.uuid4())
    # Setup test arguments
    args = {
        "inputs": {"query": fake.text(max_nb_chars=50)},
        "workflow_id": workflow_id,
        "response_mode": "streaming",
    }
    # Execute the method under test
    result = AppGenerateService.generate(
        app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    # Verify the result
    assert result == ["test_response"]
    # Verify workflow service was called with specific workflow ID
    mock_external_service_dependencies[
        "workflow_service"
    ].return_value.get_published_workflow_by_id.assert_called_once()
def test_generate_with_debugger_invoke_from(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test generation with debugger invoke from.

    InvokeFrom.DEBUGGER must resolve the DRAFT workflow, not the published
    one, for advanced-chat apps.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
    )
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
    # Execute the method under test
    result = AppGenerateService.generate(
        app_model=app, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
    )
    # Verify the result
    assert result == ["test_response"]
    # Verify draft workflow was fetched for debugger
    mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once()
def test_generate_with_non_streaming_mode(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test generation with non-streaming mode.

    In blocking mode the rate-limit slot must be released immediately
    (RateLimit.exit) instead of being tied to stream consumption.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "blocking"}
    # Execute the method under test
    result = AppGenerateService.generate(
        app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=False
    )
    # Verify the result
    assert result == ["test_response"]
    # Verify rate limit exit was called for non-streaming mode
    mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once()
def test_generate_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Generation must accept an EndUser anywhere an Account is accepted:
    a persisted EndUser drives a completion-mode generation successfully.
    """
    faker = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    # Persist an end user bound to the app's tenant.
    end_user = EndUser(
        tenant_id=account.current_tenant.id,
        app_id=app.id,
        type="normal",
        external_user_id=faker.uuid4(),
        name=faker.name(),
        is_anonymous=False,
        session_id=faker.uuid4(),
    )
    from extensions.ext_database import db

    db.session.add(end_user)
    db.session.commit()

    request_args = {"inputs": {"query": faker.text(max_nb_chars=50)}, "response_mode": "streaming"}
    outcome = AppGenerateService.generate(
        app_model=app, user=end_user, args=request_args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    assert outcome == ["test_response"]
def test_generate_with_billing_enabled_sandbox_plan(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test generation with billing enabled and sandbox plan.

    NOTE(review): nothing in the setup pins the tenant to a sandbox plan;
    the test only flips BILLING_ENABLED and checks that quota consumption
    is recorded — confirm the "sandbox" naming against the service logic.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    # Set BILLING_ENABLED to True for this test (both config patch points)
    mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
    mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
    # Execute the method under test
    result = AppGenerateService.generate(
        app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    # Verify the result
    assert result == ["test_response"]
    # Verify billing service was called to consume quota
    mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies):
    """
    An unknown app mode must make AppGenerateService.generate raise a
    ValueError mentioning the invalid mode.
    """
    faker = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="chat"
    )
    # Corrupt the mode after creation so AppService-level validation is bypassed.
    app.mode = "invalid_mode"
    request_args = {"inputs": {"query": faker.text(max_nb_chars=50)}, "response_mode": "streaming"}

    with pytest.raises(ValueError) as exc_info:
        AppGenerateService.generate(
            app_model=app,
            user=account,
            args=request_args,
            invoke_from=InvokeFrom.SERVICE_API,
            streaming=True,
        )
    assert "Invalid app mode" in str(exc_info.value)
def test_generate_with_workflow_id_format_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test generation with invalid workflow ID format.

    A non-UUID "workflow_id" must raise WorkflowIdFormatError before any
    generator is invoked.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
    )
    # Setup test arguments with invalid workflow ID
    args = {
        "inputs": {"query": fake.text(max_nb_chars=50)},
        "workflow_id": "invalid_uuid",
        "response_mode": "streaming",
    }
    # Execute the method under test and expect WorkflowIdFormatError
    with pytest.raises(WorkflowIdFormatError) as exc_info:
        AppGenerateService.generate(
            app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
        )
    # Verify error message
    assert "Invalid workflow_id format" in str(exc_info.value)
def test_generate_with_workflow_not_found_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test generation when workflow is not found.

    A well-formed workflow_id that resolves to no published workflow must
    raise WorkflowNotFoundError carrying the id in its message.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
    )
    workflow_id = str(uuid.uuid4())
    # Setup workflow service to return None (workflow not found)
    mock_external_service_dependencies[
        "workflow_service"
    ].return_value.get_published_workflow_by_id.return_value = None
    # Setup test arguments
    args = {
        "inputs": {"query": fake.text(max_nb_chars=50)},
        "workflow_id": workflow_id,
        "response_mode": "streaming",
    }
    # Execute the method under test and expect WorkflowNotFoundError
    with pytest.raises(WorkflowNotFoundError) as exc_info:
        AppGenerateService.generate(
            app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
        )
    # Verify error message
    assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value)
def test_generate_with_workflow_not_initialized_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test generation when workflow is not initialized for debugger.

    With InvokeFrom.DEBUGGER, a missing draft workflow must surface as a
    ValueError("Workflow not initialized").
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
    )
    # Setup workflow service to return None (workflow not initialized)
    mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.return_value = None
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
    # Execute the method under test and expect ValueError
    with pytest.raises(ValueError) as exc_info:
        AppGenerateService.generate(
            app_model=app, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
        )
    # Verify error message
    assert "Workflow not initialized" in str(exc_info.value)
def test_generate_with_workflow_not_published_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test generation when workflow is not published for non-debugger.

    For non-debugger invocations, a missing published workflow must surface
    as a ValueError("Workflow not published").
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
    )
    # Setup workflow service to return None (workflow not published)
    mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
    # Execute the method under test and expect ValueError
    with pytest.raises(ValueError) as exc_info:
        AppGenerateService.generate(
            app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
        )
    # Verify error message
    assert "Workflow not published" in str(exc_info.value)
def test_generate_single_iteration_advanced_chat_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test successful single iteration generation for advanced chat mode.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
    )
    node_id = fake.uuid4()
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}}
    # Execute the method under test
    result = AppGenerateService.generate_single_iteration(
        app_model=app, user=account, node_id=node_id, args=args, streaming=True
    )
    # Verify the result (advanced-chat converter output from the fixture)
    assert result == ["advanced_chat_stream"]
    # Verify advanced chat generator was called
    mock_external_service_dependencies[
        "advanced_chat_generator"
    ].return_value.single_iteration_generate.assert_called_once()
def test_generate_single_iteration_workflow_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test successful single iteration generation for workflow mode.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="workflow"
    )
    node_id = fake.uuid4()
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}}
    # Execute the method under test
    result = AppGenerateService.generate_single_iteration(
        app_model=app, user=account, node_id=node_id, args=args, streaming=True
    )
    # Verify the result
    # NOTE(review): the expectation is "advanced_chat_stream", not
    # "workflow_stream" — the service apparently funnels workflow
    # single-iteration output through the advanced-chat converter; confirm
    # against AppGenerateService.generate_single_iteration.
    assert result == ["advanced_chat_stream"]
    # Verify workflow generator was called
    mock_external_service_dependencies[
        "workflow_generator"
    ].return_value.single_iteration_generate.assert_called_once()
def test_generate_single_iteration_invalid_mode(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Single-iteration generation only supports workflow-backed modes; a
    completion-mode app must trigger a ValueError("Invalid app mode").
    """
    faker = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    target_node = faker.uuid4()
    request_args = {"inputs": {"query": faker.text(max_nb_chars=50)}}

    with pytest.raises(ValueError) as exc_info:
        AppGenerateService.generate_single_iteration(
            app_model=app, user=account, node_id=target_node, args=request_args, streaming=True
        )
    assert "Invalid app mode" in str(exc_info.value)
def test_generate_single_loop_advanced_chat_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test successful single loop generation for advanced chat mode.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
    )
    node_id = fake.uuid4()
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}}
    # Execute the method under test
    result = AppGenerateService.generate_single_loop(
        app_model=app, user=account, node_id=node_id, args=args, streaming=True
    )
    # Verify the result (advanced-chat converter output from the fixture)
    assert result == ["advanced_chat_stream"]
    # Verify advanced chat generator was called
    mock_external_service_dependencies[
        "advanced_chat_generator"
    ].return_value.single_loop_generate.assert_called_once()
def test_generate_single_loop_workflow_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test successful single loop generation for workflow mode.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="workflow"
    )
    node_id = fake.uuid4()
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}}
    # Execute the method under test
    result = AppGenerateService.generate_single_loop(
        app_model=app, user=account, node_id=node_id, args=args, streaming=True
    )
    # Verify the result
    # NOTE(review): expects the advanced-chat converter's stream even for a
    # workflow app — mirrors the single-iteration workflow test; confirm
    # against AppGenerateService.generate_single_loop.
    assert result == ["advanced_chat_stream"]
    # Verify workflow generator was called
    mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once()
def test_generate_single_loop_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Single-loop generation only supports workflow-backed modes; a
    completion-mode app must trigger a ValueError("Invalid app mode").
    """
    faker = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    target_node = faker.uuid4()
    request_args = {"inputs": {"query": faker.text(max_nb_chars=50)}}

    with pytest.raises(ValueError) as exc_info:
        AppGenerateService.generate_single_loop(
            app_model=app, user=account, node_id=target_node, args=request_args, streaming=True
        )
    assert "Invalid app mode" in str(exc_info.value)
def test_generate_more_like_this(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful more like this generation.

    Verifies delegation to CompletionAppGenerator.generate_more_like_this
    for an existing message id.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    message_id = fake.uuid4()
    # Execute the method under test
    result = AppGenerateService.generate_more_like_this(
        app_model=app, user=account, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    # Verify the result
    assert result == ["more_like_this_response"]
    # Verify completion generator was called
    mock_external_service_dependencies[
        "completion_generator"
    ].return_value.generate_more_like_this.assert_called_once()
def test_generate_more_like_this_with_end_user(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test more like this generation with EndUser.

    The API must accept an EndUser in place of an Account for
    generate_more_like_this as well.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    # Create end user bound to the app's tenant
    end_user = EndUser(
        tenant_id=account.current_tenant.id,
        app_id=app.id,
        type="normal",
        external_user_id=fake.uuid4(),
        name=fake.name(),
        is_anonymous=False,
        session_id=fake.uuid4(),
    )
    from extensions.ext_database import db
    db.session.add(end_user)
    db.session.commit()
    message_id = fake.uuid4()
    # Execute the method under test
    result = AppGenerateService.generate_more_like_this(
        app_model=app, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    # Verify the result
    assert result == ["more_like_this_response"]
def test_get_max_active_requests_with_app_limit(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    When the app-level limit (10) is below the configured global limit
    (100 in the fixture), _get_max_active_requests must return the app's.
    """
    faker = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    app.max_active_requests = 10
    # min(app limit, config limit) → the app's tighter limit wins.
    assert AppGenerateService._get_max_active_requests(app) == 10
def test_get_max_active_requests_with_config_limit(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test getting max active requests with config limit being smaller.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    # Set app-specific limit higher than config
    app.max_active_requests = 100
    # Execute the method under test
    result = AppGenerateService._get_max_active_requests(app)
    # Verify the result (should return the smaller value)
    # Assuming config limit is smaller than 100
    # NOTE(review): the fixture sets APP_MAX_ACTIVE_REQUESTS to exactly 100,
    # so this only asserts <= 100; an exact-equality assertion would be
    # stronger once the intended config value is pinned down.
    assert result <= 100
def test_get_max_active_requests_with_zero_limits(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test getting max active requests with zero limits (infinite).

    An app-level limit of 0 means "no app limit"; the configured global
    limit must then apply.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    # Set app-specific limit to 0 (infinite)
    app.max_active_requests = 0
    # Execute the method under test
    result = AppGenerateService._get_max_active_requests(app)
    # Verify the result (should return config limit when app limit is 0)
    assert result == 100  # dify_config.APP_MAX_ACTIVE_REQUESTS
def test_generate_with_exception_cleanup(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test that rate limit exit is called when an exception occurs.

    Even when the generator raises, the acquired rate-limit slot must be
    released so the app does not leak concurrency capacity.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="completion"
    )
    # Setup completion generator to raise an exception
    mock_external_service_dependencies["completion_generator"].return_value.generate.side_effect = Exception(
        "Test exception"
    )
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
    # Execute the method under test and expect exception
    with pytest.raises(Exception) as exc_info:
        AppGenerateService.generate(
            app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
        )
    # Verify exception message
    assert "Test exception" in str(exc_info.value)
    # Verify rate limit exit was called for cleanup
    mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once()
def test_generate_with_agent_mode_detection(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test generation with agent mode detection based on app configuration.

    NOTE(review): despite the name, this does not exercise config-based
    detection — it overwrites app.mode directly, so it largely duplicates
    test_generate_agent_chat_mode_success; confirm the intended coverage.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="chat"
    )
    # Mock app to have agent mode enabled by setting the mode directly
    app.mode = "agent-chat"
    # Setup test arguments
    args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
    # Execute the method under test
    result = AppGenerateService.generate(
        app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    # Verify the result
    assert result == ["test_response"]
    # Verify agent chat generator was called instead of regular chat generator
    mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once()
    mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once()
def test_generate_with_different_invoke_from_values(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Every supported InvokeFrom source must drive an advanced-chat
    generation to the same successful mocked stream.
    """
    faker = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
    )
    sources = [
        InvokeFrom.SERVICE_API,
        InvokeFrom.WEB_APP,
        InvokeFrom.EXPLORE,
        InvokeFrom.DEBUGGER,
    ]
    for source in sources:
        request_args = {"inputs": {"query": faker.text(max_nb_chars=50)}, "response_mode": "streaming"}
        outcome = AppGenerateService.generate(
            app_model=app, user=account, args=request_args, invoke_from=source, streaming=True
        )
        assert outcome == ["test_response"]
def test_generate_with_complex_args(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test generation with complex arguments including files and external trace ID.

    Verifies that nested inputs, file descriptors, and external_trace_id
    are forwarded to the workflow generator untouched.
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies, mode="workflow"
    )
    # Setup complex test arguments
    args = {
        "inputs": {
            "query": fake.text(max_nb_chars=50),
            "context": fake.text(max_nb_chars=100),
            "parameters": {"temperature": 0.7, "max_tokens": 1000},
        },
        "files": [
            {"id": fake.uuid4(), "name": "test_file.txt", "size": 1024},
            {"id": fake.uuid4(), "name": "test_image.jpg", "size": 2048},
        ],
        "external_trace_id": fake.uuid4(),
        "response_mode": "streaming",
    }
    # Execute the method under test
    result = AppGenerateService.generate(
        app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
    )
    # Verify the result
    assert result == ["test_response"]
    # Verify workflow generator was called with complex args, passed by reference
    mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once()
    call_args = mock_external_service_dependencies["workflow_generator"].return_value.generate.call_args
    assert call_args[1]["args"] == args

View File

@@ -0,0 +1,954 @@
from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from constants.model_template import default_app_templates
from models import Account
from models.model import App, Site
from services.account_service import AccountService, TenantService
from services.app_service import AppService
class TestAppService:
"""Integration tests for AppService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
    """Mock setup for external service dependencies.

    Patches the feature/enterprise/model-manager collaborators of
    AppService plus the account-service feature check so tests can create
    apps without real provider or enterprise backends.
    """
    with (
        patch("services.app_service.FeatureService") as mock_feature_service,
        patch("services.app_service.EnterpriseService") as mock_enterprise_service,
        patch("services.app_service.ModelManager") as mock_model_manager,
        patch("services.account_service.FeatureService") as mock_account_feature_service,
    ):
        # Setup default mock returns for app service
        mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
        mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
        mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
        # Setup default mock returns for account service
        mock_account_feature_service.get_system_features.return_value.is_allow_register = True
        # Mock ModelManager for model configuration
        mock_model_instance = mock_model_manager.return_value
        mock_model_instance.get_default_model_instance.return_value = None
        mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
        yield {
            "feature_service": mock_feature_service,
            "enterprise_service": mock_enterprise_service,
            "model_manager": mock_model_manager,
            "account_feature_service": mock_account_feature_service,
        }
def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful app creation with basic parameters.

    Checks that every field supplied in the creation args is persisted and
    that server-side defaults (status, enable_site/api, demo/public flags)
    are applied.
    """
    fake = Faker()
    # Create account and tenant first
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant
    # Setup app creation arguments
    app_args = {
        "name": fake.company(),
        "description": fake.text(max_nb_chars=100),
        "mode": "chat",
        "icon_type": "emoji",
        "icon": "🤖",
        "icon_background": "#FF6B6B",
        "api_rph": 100,
        "api_rpm": 10,
    }
    # Create app
    app_service = AppService()
    app = app_service.create_app(tenant.id, app_args, account)
    # Verify app was created correctly
    assert app.name == app_args["name"]
    assert app.description == app_args["description"]
    assert app.mode == app_args["mode"]
    assert app.icon_type == app_args["icon_type"]
    assert app.icon == app_args["icon"]
    assert app.icon_background == app_args["icon_background"]
    assert app.tenant_id == tenant.id
    assert app.api_rph == app_args["api_rph"]
    assert app.api_rpm == app_args["api_rpm"]
    assert app.created_by == account.id
    assert app.updated_by == account.id
    # Server-side defaults
    assert app.status == "normal"
    assert app.enable_site is True
    assert app.enable_api is True
    assert app.is_demo is False
    assert app.is_public is False
    assert app.is_universal is False
def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Every app mode that ships with a default template must be creatable,
    and the created app must carry the requested mode and ownership.
    """
    faker = Faker()
    account = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=faker.company())
    tenant = account.current_tenant
    service = AppService()

    # Modes come from the AppMode keys of default_app_templates.
    for mode in (template.value for template in default_app_templates):
        creation_args = {
            "name": f"{faker.company()} {mode}",
            "description": f"Test app for {mode} mode",
            "mode": mode,
            "icon_type": "emoji",
            "icon": "🚀",
            "icon_background": "#4ECDC4",
        }
        created = service.create_app(tenant.id, creation_args, account)
        assert created.mode == mode
        assert created.name == creation_args["name"]
        assert created.tenant_id == tenant.id
        assert created.created_by == account.id
def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies):
    """Retrieving an app should return the same record that was created."""
    faker = Faker()
    # Owning account and workspace.
    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    workspace = owner.current_tenant
    service = AppService()
    created = service.create_app(
        workspace.id,
        {
            "name": faker.company(),
            "description": faker.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🎯",
            "icon_background": "#45B7D1",
        },
        owner,
    )
    # get_app reads flask-login's current_user, so substitute a spec'd Account.
    user_stub = create_autospec(Account, instance=True)
    user_stub.id = owner.id
    user_stub.current_tenant_id = owner.current_tenant_id
    with patch("services.app_service.current_user", user_stub):
        fetched = service.get_app(created)
    # The fetched record must mirror the created one field by field.
    for field in ("id", "name", "description", "mode", "tenant_id", "created_by"):
        assert getattr(fetched, field) == getattr(created, field)
def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies):
    """Paginated listing should return every chat app owned by the tenant."""
    faker = Faker()
    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    workspace = owner.current_tenant
    service = AppService()
    # Seed five chat apps with random names.
    for _ in range(5):
        service.create_app(
            workspace.id,
            {
                "name": faker.company(),
                "description": faker.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "📱",
                "icon_background": "#96CEB4",
            },
            owner,
        )
    page = service.get_paginate_apps(owner.id, workspace.id, {"page": 1, "limit": 10, "mode": "chat"})
    # Pagination metadata and membership checks.
    assert page is not None
    assert len(page.items) >= 5  # at least the five apps seeded above
    assert page.page == 1
    assert page.per_page == 10
    # Every returned app must be a chat app in this workspace.
    for item in page.items:
        assert item.tenant_id == workspace.id
        assert item.mode == "chat"
def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test paginated app list with various filters.

    Covers filtering by mode, by name substring, and by creator identity.
    """
    fake = Faker()
    # Create account and tenant first
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant
    app_service = AppService()
    # Create apps with different modes
    chat_app_args = {
        "name": "Chat App",
        "description": "A chat application",
        "mode": "chat",
        "icon_type": "emoji",
        "icon": "💬",
        "icon_background": "#FF6B6B",
    }
    completion_app_args = {
        "name": "Completion App",
        "description": "A completion application",
        "mode": "completion",
        "icon_type": "emoji",
        "icon": "✍️",
        "icon_background": "#4ECDC4",
    }
    # The returned instances were previously bound to unused locals; only the
    # persisted rows matter for the filter assertions below.
    app_service.create_app(tenant.id, chat_app_args, account)
    app_service.create_app(tenant.id, completion_app_args, account)
    # Test filter by mode
    chat_args = {
        "page": 1,
        "limit": 10,
        "mode": "chat",
    }
    chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args)
    assert len(chat_apps.items) == 1
    assert chat_apps.items[0].mode == "chat"
    # Test filter by name substring
    name_args = {
        "page": 1,
        "limit": 10,
        "mode": "chat",
        "name": "Chat",
    }
    filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args)
    assert len(filtered_apps.items) == 1
    assert "Chat" in filtered_apps.items[0].name
    # Test filter by created_by_me
    created_by_me_args = {
        "page": 1,
        "limit": 10,
        "mode": "completion",
        "is_created_by_me": True,
    }
    my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args)
    assert len(my_apps.items) == 1
def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies):
    """Tag filtering should narrow the listing via TagService lookups."""
    faker = Faker()
    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    workspace = owner.current_tenant
    service = AppService()
    app = service.create_app(
        workspace.id,
        {
            "name": faker.company(),
            "description": faker.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🏷️",
            "icon_background": "#FFEAA7",
        },
        owner,
    )
    # Case 1: the tag service resolves the requested tags to our app's id.
    with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as tag_lookup:
        tag_lookup.return_value = [app.id]
        page = service.get_paginate_apps(
            owner.id,
            workspace.id,
            {"page": 1, "limit": 10, "mode": "chat", "tag_ids": ["tag1", "tag2"]},
        )
        # The lookup must be forwarded with the tenant scope intact.
        tag_lookup.assert_called_once_with("app", workspace.id, ["tag1", "tag2"])
        assert page is not None
        assert len(page.items) == 1
        assert page.items[0].id == app.id
    # Case 2: no app matches the tags, so the service yields None.
    with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as tag_lookup:
        tag_lookup.return_value = []
        page = service.get_paginate_apps(
            owner.id,
            workspace.id,
            {"page": 1, "limit": 10, "mode": "chat", "tag_ids": ["nonexistent_tag"]},
        )
        assert page is None
def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful app update with all fields.

    Updates every mutable field and then verifies both the changed columns
    and the columns that must remain untouched.
    """
    fake = Faker()
    # Create account and tenant first
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant
    # Create app first
    app_args = {
        "name": fake.company(),
        "description": fake.text(max_nb_chars=100),
        "mode": "chat",
        "icon_type": "emoji",
        "icon": "🎯",
        "icon_background": "#45B7D1",
    }
    app_service = AppService()
    app = app_service.create_app(tenant.id, app_args, account)
    # Update app (the previously stored "original_*" locals were unused and
    # have been removed; unchanged fields are asserted against `app` below)
    update_args = {
        "name": "Updated App Name",
        "description": "Updated app description",
        "icon_type": "emoji",
        "icon": "🔄",
        "icon_background": "#FF8C42",
        "use_icon_as_answer_icon": True,
    }
    # update_app reads flask-login's current_user for the audit column
    mock_current_user = create_autospec(Account, instance=True)
    mock_current_user.id = account.id
    mock_current_user.current_tenant_id = account.current_tenant_id
    with patch("services.app_service.current_user", mock_current_user):
        updated_app = app_service.update_app(app, update_args)
    # Verify updated fields
    assert updated_app.name == update_args["name"]
    assert updated_app.description == update_args["description"]
    assert updated_app.icon == update_args["icon"]
    assert updated_app.icon_background == update_args["icon_background"]
    assert updated_app.use_icon_as_answer_icon is True
    assert updated_app.updated_by == account.id
    # Verify other fields remain unchanged
    assert updated_app.mode == app.mode
    assert updated_app.tenant_id == app.tenant_id
    assert updated_app.created_by == app.created_by
def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful app name update.

    Only the name (and the audit column) may change; everything else must
    stay as created.
    """
    fake = Faker()
    # Create account and tenant first
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant
    # Create app first
    app_args = {
        "name": fake.company(),
        "description": fake.text(max_nb_chars=100),
        "mode": "chat",
        "icon_type": "emoji",
        "icon": "🎯",
        "icon_background": "#45B7D1",
    }
    app_service = AppService()
    app = app_service.create_app(tenant.id, app_args, account)
    # Update app name (the unused "original_name" local was removed)
    new_name = "New App Name"
    # update_app_name reads flask-login's current_user for the audit column
    mock_current_user = create_autospec(Account, instance=True)
    mock_current_user.id = account.id
    mock_current_user.current_tenant_id = account.current_tenant_id
    with patch("services.app_service.current_user", mock_current_user):
        updated_app = app_service.update_app_name(app, new_name)
    assert updated_app.name == new_name
    assert updated_app.updated_by == account.id
    # Verify other fields remain unchanged
    assert updated_app.description == app.description
    assert updated_app.mode == app.mode
    assert updated_app.tenant_id == app.tenant_id
    assert updated_app.created_by == app.created_by
def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful app icon update.

    Only the icon and icon_background (plus the audit column) may change.
    """
    fake = Faker()
    # Create account and tenant first
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant
    # Create app first
    app_args = {
        "name": fake.company(),
        "description": fake.text(max_nb_chars=100),
        "mode": "chat",
        "icon_type": "emoji",
        "icon": "🎯",
        "icon_background": "#45B7D1",
    }
    app_service = AppService()
    app = app_service.create_app(tenant.id, app_args, account)
    # Update app icon (the unused "original_icon*" locals were removed)
    new_icon = "🌟"
    new_icon_background = "#FFD93D"
    # update_app_icon reads flask-login's current_user for the audit column
    mock_current_user = create_autospec(Account, instance=True)
    mock_current_user.id = account.id
    mock_current_user.current_tenant_id = account.current_tenant_id
    with patch("services.app_service.current_user", mock_current_user):
        updated_app = app_service.update_app_icon(app, new_icon, new_icon_background)
    assert updated_app.icon == new_icon
    assert updated_app.icon_background == new_icon_background
    assert updated_app.updated_by == account.id
    # Verify other fields remain unchanged
    assert updated_app.name == app.name
    assert updated_app.description == app.description
    assert updated_app.mode == app.mode
    assert updated_app.tenant_id == app.tenant_id
    assert updated_app.created_by == app.created_by
def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful app site status update (disable, then re-enable).
    """
    fake = Faker()
    # Create account and tenant first
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant
    # Create app first
    app_args = {
        "name": fake.company(),
        "description": fake.text(max_nb_chars=100),
        "mode": "chat",
        "icon_type": "emoji",
        "icon": "🌐",
        "icon_background": "#74B9FF",
    }
    app_service = AppService()
    app = app_service.create_app(tenant.id, app_args, account)
    # (the unused "original_site_status" local was removed)
    # update_app_site_status reads flask-login's current_user for auditing
    mock_current_user = create_autospec(Account, instance=True)
    mock_current_user.id = account.id
    mock_current_user.current_tenant_id = account.current_tenant_id
    # Update site status to disabled
    with patch("services.app_service.current_user", mock_current_user):
        updated_app = app_service.update_app_site_status(app, False)
    assert updated_app.enable_site is False
    assert updated_app.updated_by == account.id
    # Update site status back to enabled
    with patch("services.app_service.current_user", mock_current_user):
        updated_app = app_service.update_app_site_status(updated_app, True)
    assert updated_app.enable_site is True
    assert updated_app.updated_by == account.id
    # Verify other fields remain unchanged
    assert updated_app.name == app.name
    assert updated_app.description == app.description
    assert updated_app.mode == app.mode
    assert updated_app.tenant_id == app.tenant_id
    assert updated_app.created_by == app.created_by
def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful app API status update (disable, then re-enable).
    """
    fake = Faker()
    # Create account and tenant first
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant
    # Create app first
    app_args = {
        "name": fake.company(),
        "description": fake.text(max_nb_chars=100),
        "mode": "chat",
        "icon_type": "emoji",
        "icon": "🔌",
        "icon_background": "#A29BFE",
    }
    app_service = AppService()
    app = app_service.create_app(tenant.id, app_args, account)
    # (the unused "original_api_status" local was removed)
    # update_app_api_status reads flask-login's current_user for auditing
    mock_current_user = create_autospec(Account, instance=True)
    mock_current_user.id = account.id
    mock_current_user.current_tenant_id = account.current_tenant_id
    # Update API status to disabled
    with patch("services.app_service.current_user", mock_current_user):
        updated_app = app_service.update_app_api_status(app, False)
    assert updated_app.enable_api is False
    assert updated_app.updated_by == account.id
    # Update API status back to enabled
    with patch("services.app_service.current_user", mock_current_user):
        updated_app = app_service.update_app_api_status(updated_app, True)
    assert updated_app.enable_api is True
    assert updated_app.updated_by == account.id
    # Verify other fields remain unchanged
    assert updated_app.name == app.name
    assert updated_app.description == app.description
    assert updated_app.mode == app.mode
    assert updated_app.tenant_id == app.tenant_id
    assert updated_app.created_by == app.created_by
def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies):
    """Setting the site status to its current value should be a no-op."""
    faker = Faker()
    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    workspace = owner.current_tenant
    service = AppService()
    app = service.create_app(
        workspace.id,
        {
            "name": faker.company(),
            "description": faker.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🔄",
            "icon_background": "#FD79A8",
        },
        owner,
    )
    # Capture state that must survive the no-op update.
    status_before = app.enable_site
    updated_at_before = app.updated_at
    # NOTE: no current_user patch here — the service is expected to return
    # early when the status is unchanged, so the audit path is never taken.
    result = service.update_app_site_status(app, status_before)
    # The same record comes back untouched.
    assert result.id == app.id
    assert result.enable_site == status_before
    assert result.updated_at == updated_at_before
    for field in ("name", "description", "mode", "tenant_id", "created_by"):
        assert getattr(result, field) == getattr(app, field)
def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies):
    """Deleting an app removes the row and schedules the async cleanup task."""
    faker = Faker()
    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    workspace = owner.current_tenant
    service = AppService()
    app = service.create_app(
        workspace.id,
        {
            "name": faker.company(),
            "description": faker.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🗑️",
            "icon_background": "#E17055",
        },
        owner,
    )
    doomed_id = app.id
    # The heavy related-data cleanup runs in an async task; stub it out.
    with patch("services.app_service.remove_app_and_related_data_task") as cleanup_task:
        cleanup_task.delay.return_value = None
        service.delete_app(app)
        cleanup_task.delay.assert_called_once_with(tenant_id=workspace.id, app_id=doomed_id)
    # The app row itself is removed synchronously.
    from extensions.ext_database import db
    assert db.session.query(App).filter_by(id=doomed_id).first() is None
def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies):
    """Deletion should also trigger webapp-auth cleanup when it is enabled."""
    faker = Faker()
    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    workspace = owner.current_tenant
    service = AppService()
    app = service.create_app(
        workspace.id,
        {
            "name": faker.company(),
            "description": faker.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🧹",
            "icon_background": "#00B894",
        },
        owner,
    )
    doomed_id = app.id
    # Pretend enterprise webapp auth is switched on so cleanup is required.
    mock_external_service_dependencies[
        "feature_service"
    ].get_system_features.return_value.webapp_auth.enabled = True
    with patch("services.app_service.remove_app_and_related_data_task") as cleanup_task:
        cleanup_task.delay.return_value = None
        service.delete_app(app)
        # Enterprise-side webapp auth state must be purged for this app.
        mock_external_service_dependencies["enterprise_service"].WebAppAuth.cleanup_webapp.assert_called_once_with(
            doomed_id
        )
        cleanup_task.delay.assert_called_once_with(tenant_id=workspace.id, app_id=doomed_id)
    # The app row itself is removed synchronously.
    from extensions.ext_database import db
    assert db.session.query(App).filter_by(id=doomed_id).first() is None
def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies):
    """App metadata retrieval should expose the tool icon mapping."""
    faker = Faker()
    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    workspace = owner.current_tenant
    service = AppService()
    app = service.create_app(
        workspace.id,
        {
            "name": faker.company(),
            "description": faker.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "📊",
            "icon_background": "#6C5CE7",
        },
        owner,
    )
    meta = service.get_app_meta(app)
    # get_app_meta currently returns only the tool icon mapping.
    assert "tool_icons" in meta
def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies):
    """Looking up a site code by app id should return a non-empty code."""
    faker = Faker()
    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    workspace = owner.current_tenant
    app = AppService().create_app(
        workspace.id,
        {
            "name": faker.company(),
            "description": faker.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🔗",
            "icon_background": "#FDCB6E",
        },
        owner,
    )
    # A Site row (with an auto-generated code) is created alongside the app,
    # so the lookup should yield a non-empty code.
    code = AppService.get_app_code_by_id(app.id)
    assert code is not None
    assert len(code) > 0
def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies):
    """Looking up an app id by its site code should round-trip correctly."""
    faker = Faker()
    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    workspace = owner.current_tenant
    app = AppService().create_app(
        workspace.id,
        {
            "name": faker.company(),
            "description": faker.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🆔",
            "icon_background": "#E84393",
        },
        owner,
    )
    # Attach a site with a known code to the app.
    site = Site()
    site.app_id = app.id
    site.code = faker.postalcode()
    site.title = faker.company()
    site.status = "normal"
    site.default_language = "en-US"
    site.customize_token_strategy = "uuid"
    from extensions.ext_database import db
    db.session.add(site)
    db.session.commit()
    # The code must resolve back to the owning app.
    assert AppService.get_app_id_by_code(site.code) == app.id
def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
    """An unrecognized app mode must be rejected with a ValueError."""
    faker = Faker()
    owner = AccountService.create_account(
        email=faker.email(),
        name=faker.name(),
        interface_language="en-US",
        password=faker.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(owner, name=faker.company())
    workspace = owner.current_tenant
    bad_args = {
        "name": faker.company(),
        "description": faker.text(max_nb_chars=100),
        "mode": "invalid_mode",  # not a valid AppMode value
        "icon_type": "emoji",
        "icon": "",
        "icon_background": "#D63031",
    }
    # Creation must fail before anything is persisted.
    with pytest.raises(ValueError, match="invalid mode value"):
        AppService().create_app(workspace.id, bad_args, owner)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,386 @@
"""Unit tests for FeedbackService."""
import json
from datetime import datetime
from types import SimpleNamespace
from unittest import mock
import pytest
from extensions.ext_database import db
from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService
class TestFeedbackService:
"""Test FeedbackService methods."""
@pytest.fixture
def mock_db_session(self, monkeypatch):
    """Swap the global SQLAlchemy session for a Mock for the duration of a test."""
    session_stub = mock.Mock()
    monkeypatch.setattr(db, "session", session_stub)
    return session_stub
@pytest.fixture
def sample_data(self):
    """Create sample data for testing.

    Builds one app/conversation/message triple plus two feedback entries:
    one from an end user ("like") and one from an admin ("dislike").
    Feedback objects are plain SimpleNamespace instances so the tests never
    have to satisfy ORM constructor rules or touch the database.
    """
    app_id = "test-app-id"
    # Create mock models (ORM instances, but never persisted)
    app = App(id=app_id, name="Test App")
    conversation = Conversation(id="test-conversation-id", app_id=app_id, name="Test Conversation")
    message = Message(
        id="test-message-id",
        conversation_id="test-conversation-id",
        query="What is AI?",
        answer="AI is artificial intelligence.",
        inputs={"query": "What is AI?"},
        created_at=datetime(2024, 1, 1, 10, 0, 0),
    )
    # Use SimpleNamespace to avoid ORM model constructor issues
    user_feedback = SimpleNamespace(
        id="user-feedback-id",
        app_id=app_id,
        conversation_id="test-conversation-id",
        message_id="test-message-id",
        rating="like",
        from_source="user",
        content="Great answer!",
        from_end_user_id="user-123",
        from_account_id=None,
        from_account=None,  # Mock account object (end users carry no account)
        created_at=datetime(2024, 1, 1, 10, 5, 0),
    )
    admin_feedback = SimpleNamespace(
        id="admin-feedback-id",
        app_id=app_id,
        conversation_id="test-conversation-id",
        message_id="test-message-id",
        rating="dislike",
        from_source="admin",
        content="Could be more detailed",
        from_end_user_id=None,
        from_account_id="admin-456",
        from_account=SimpleNamespace(name="Admin User"),  # Mock account object
        created_at=datetime(2024, 1, 1, 10, 10, 0),
    )
    # Keyed bundle consumed by the individual tests.
    return {
        "app": app,
        "conversation": conversation,
        "message": message,
        "user_feedback": user_feedback,
        "admin_feedback": admin_feedback,
    }
def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
    """CSV export should produce a downloadable attachment containing the rows."""
    # One Mock stands in for the whole chained query: every chaining method
    # returns the stub itself so arbitrary call orders resolve.
    query_stub = mock.Mock()
    for chained in ("join", "outerjoin", "where", "filter", "order_by"):
        getattr(query_stub, chained).return_value = query_stub
    query_stub.all.return_value = [
        (
            sample_data["user_feedback"],
            sample_data["message"],
            sample_data["conversation"],
            sample_data["app"],
            sample_data["user_feedback"].from_account,
        )
    ]
    mock_db_session.query.return_value = query_stub
    response = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
    # The response must look like a CSV file attachment.
    assert hasattr(response, "headers")
    assert "text/csv" in response.headers["Content-Type"]
    assert "attachment" in response.headers["Content-Disposition"]
    body = response.get_data(as_text=True)
    # Essential columns and row values must appear (column order may include
    # additional columns).
    for expected in (
        "feedback_id",
        "app_name",
        "conversation_id",
        sample_data["app"].name,
        sample_data["message"].query,
    ):
        assert expected in body
def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
    """JSON export should wrap rows in an export_info/feedback_data envelope."""
    query_stub = mock.Mock()
    # Every chaining method returns the stub itself.
    for chained in ("join", "outerjoin", "where", "filter", "order_by"):
        getattr(query_stub, chained).return_value = query_stub
    query_stub.all.return_value = [
        (
            sample_data["admin_feedback"],
            sample_data["message"],
            sample_data["conversation"],
            sample_data["app"],
            sample_data["admin_feedback"].from_account,
        )
    ]
    mock_db_session.query.return_value = query_stub
    response = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
    # The response must look like a JSON file attachment.
    assert hasattr(response, "headers")
    assert "application/json" in response.headers["Content-Type"]
    assert "attachment" in response.headers["Content-Disposition"]
    payload = json.loads(response.get_data(as_text=True))
    # Envelope structure and metadata checks.
    assert "export_info" in payload
    assert "feedback_data" in payload
    assert payload["export_info"]["app_id"] == sample_data["app"].id
    assert payload["export_info"]["total_records"] == 1
def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
    """Test exporting feedback with various filters.

    Verifies that source/rating/comment/date filters all reach the query;
    the response itself is not inspected here (the unused ``result`` local
    was removed).
    """
    # Setup mock query result
    mock_query = mock.Mock()
    mock_query.join.return_value = mock_query
    mock_query.outerjoin.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.filter.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    mock_query.all.return_value = [
        (
            sample_data["admin_feedback"],
            sample_data["message"],
            sample_data["conversation"],
            sample_data["app"],
            sample_data["admin_feedback"].from_account,
        )
    ]
    mock_db_session.query.return_value = mock_query
    # Export with every filter supplied
    FeedbackService.export_feedbacks(
        app_id=sample_data["app"].id,
        from_source="admin",
        rating="dislike",
        has_comment=True,
        start_date="2024-01-01",
        end_date="2024-12-31",
        format_type="csv",
    )
    # Verify filters were applied
    assert mock_query.filter.called
    filter_calls = mock_query.filter.call_args_list
    # At least three filter invocations are expected (source, rating, comment)
    assert len(filter_calls) >= 3
def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
    """An empty result set should still yield a header-only CSV."""
    query_stub = mock.Mock()
    # Every chaining method returns the stub; the final .all() is empty.
    for chained in ("join", "outerjoin", "where", "filter", "order_by"):
        getattr(query_stub, chained).return_value = query_stub
    query_stub.all.return_value = []
    mock_db_session.query.return_value = query_stub
    response = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
    # Still a CSV attachment even with no rows.
    assert hasattr(response, "headers")
    assert "text/csv" in response.headers["Content-Type"]
    body = response.get_data(as_text=True)
    # Header columns must exist (order can include additional columns).
    for header in ("feedback_id", "app_name", "conversation_id"):
        assert header in body
    # Exactly one non-blank line: the header row, no data rows.
    non_blank = [line for line in body.strip().splitlines() if line.strip()]
    assert len(non_blank) == 1
def test_export_feedbacks_invalid_date_format(self, mock_db_session, sample_data):
    """Malformed date strings must raise ValueError for both boundaries."""
    bad_date = "invalid-date-format"
    # start_date is validated first …
    with pytest.raises(ValueError, match="Invalid start_date format"):
        FeedbackService.export_feedbacks(app_id=sample_data["app"].id, start_date=bad_date)
    # … and end_date gets its own distinct error message.
    with pytest.raises(ValueError, match="Invalid end_date format"):
        FeedbackService.export_feedbacks(app_id=sample_data["app"].id, end_date=bad_date)
def test_export_feedbacks_invalid_format(self, mock_db_session, sample_data):
    """Unsupported export formats must be rejected with ValueError."""
    with pytest.raises(ValueError, match="Unsupported format"):
        # "xml" is not among the supported export formats.
        FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="xml")
def test_export_feedbacks_long_response_truncation(self, mock_db_session, sample_data):
    """AI responses longer than 500 characters are truncated with an ellipsis."""
    # A message whose answer deliberately exceeds the export limit.
    oversized_message = Message(
        id="long-message-id",
        conversation_id="test-conversation-id",
        query="What is AI?",
        answer="A" * 600,  # 600 character response
        inputs={"query": "What is AI?"},
        created_at=datetime(2024, 1, 1, 10, 0, 0),
    )
    query_stub = mock.Mock()
    for chained in ("join", "outerjoin", "where", "filter", "order_by"):
        getattr(query_stub, chained).return_value = query_stub
    query_stub.all.return_value = [
        (
            sample_data["user_feedback"],
            oversized_message,
            sample_data["conversation"],
            sample_data["app"],
            sample_data["user_feedback"].from_account,
        )
    ]
    mock_db_session.query.return_value = query_stub
    response = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
    payload = json.loads(response.get_data(as_text=True))
    answer = payload["feedback_data"][0]["ai_response"]
    # Truncated to at most 500 characters plus the "..." marker.
    assert answer.endswith("...")
    assert len(answer) <= 503
    assert len(answer) > 500  # should sit close to the limit
    def test_export_feedbacks_unicode_content(self, mock_db_session, sample_data):
        """Test exporting feedback with unicode content (Chinese characters).

        Verifies that non-ASCII text survives the CSV export round trip
        without being mangled or escaped away.
        """
        # Create feedback with Chinese content (use SimpleNamespace to avoid ORM constructor constraints)
        chinese_feedback = SimpleNamespace(
            id="chinese-feedback-id",
            app_id=sample_data["app"].id,
            conversation_id="test-conversation-id",
            message_id="test-message-id",
            rating="dislike",
            from_source="user",
            content="回答不够详细,需要更多信息",
            from_end_user_id="user-123",
            from_account_id=None,
            created_at=datetime(2024, 1, 1, 10, 5, 0),
        )
        # Create Chinese message
        chinese_message = Message(
            id="chinese-message-id",
            conversation_id="test-conversation-id",
            query="什么是人工智能?",
            answer="人工智能是模拟人类智能的技术。",
            inputs={"query": "什么是人工智能?"},
            created_at=datetime(2024, 1, 1, 10, 0, 0),
        )
        # Setup mock query result: the chained query methods all resolve to the
        # same mock so .all() yields our single hand-built row.
        mock_query = mock.Mock()
        mock_query.join.return_value = mock_query
        mock_query.outerjoin.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = [
            (
                chinese_feedback,
                chinese_message,
                sample_data["conversation"],
                sample_data["app"],
                None,  # No account for user feedback
            )
        ]
        mock_db_session.query.return_value = mock_query
        # Test export
        result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
        # Check that unicode content is preserved verbatim in the CSV body.
        csv_content = result.get_data(as_text=True)
        assert "什么是人工智能?" in csv_content
        assert "回答不够详细,需要更多信息" in csv_content
        assert "人工智能是模拟人类智能的技术" in csv_content
    def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
        """Test that rating emojis are properly formatted in export.

        Exports one "like" and one "dislike" record and checks the service
        maps the raw rating to the thumbs-up / thumbs-down emoji.
        """
        # Setup mock query result with both like and dislike feedback; every
        # chained query method returns the same mock so .all() yields our rows.
        mock_query = mock.Mock()
        mock_query.join.return_value = mock_query
        mock_query.outerjoin.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = [
            (
                sample_data["user_feedback"],
                sample_data["message"],
                sample_data["conversation"],
                sample_data["app"],
                sample_data["user_feedback"].from_account,
            ),
            (
                sample_data["admin_feedback"],
                sample_data["message"],
                sample_data["conversation"],
                sample_data["app"],
                sample_data["admin_feedback"].from_account,
            ),
        ]
        mock_db_session.query.return_value = mock_query
        # Test export
        result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
        # Check JSON content for emoji ratings
        json_content = json.loads(result.get_data(as_text=True))
        feedback_data = json_content["feedback_data"]
        # Should have both feedback records
        assert len(feedback_data) == 2
        # Check that emojis are properly set; look each record up by its raw
        # rating so the assertion is independent of export ordering.
        like_feedback = next(f for f in feedback_data if f["feedback_rating_raw"] == "like")
        dislike_feedback = next(f for f in feedback_data if f["feedback_rating_raw"] == "dislike")
        assert like_feedback["feedback_rating"] == "👍"
        assert dislike_feedback["feedback_rating"] == "👎"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,775 @@
from unittest.mock import patch
import pytest
from faker import Faker
from models.model import MessageFeedback
from services.app_service import AppService
from services.errors.message import (
FirstMessageNotExistsError,
LastMessageNotExistsError,
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
class TestMessageService:
    """Integration tests for MessageService using testcontainers.

    The database comes from the ``db_session_with_containers`` fixture;
    external collaborators (model manager, workflow service, LLM generator,
    tracing, token-buffer memory) are patched out in the
    ``mock_external_service_dependencies`` fixture below.
    """
    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies.

        Yields a dict of the patched mocks so individual tests can adjust
        return values or assert on calls.
        """
        with (
            patch("services.account_service.FeatureService") as mock_account_feature_service,
            patch("services.message_service.ModelManager") as mock_model_manager,
            patch("services.message_service.WorkflowService") as mock_workflow_service,
            patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager,
            patch("services.message_service.LLMGenerator") as mock_llm_generator,
            patch("services.message_service.TraceQueueManager") as mock_trace_manager_class,
            patch("services.message_service.TokenBufferMemory") as mock_token_buffer_memory,
        ):
            # Setup default mock returns
            mock_account_feature_service.get_features.return_value.billing.enabled = False
            # Mock ModelManager
            mock_model_instance = mock_model_manager.return_value.get_default_model_instance.return_value
            mock_model_instance.get_tts_voices.return_value = [{"value": "test-voice"}]
            # Mock get_model_instance method as well
            mock_model_manager.return_value.get_model_instance.return_value = mock_model_instance
            # Mock WorkflowService: draft workflow resolves to the same mock as
            # the published one so both code paths behave identically by default.
            mock_workflow = mock_workflow_service.return_value.get_published_workflow.return_value
            mock_workflow_service.return_value.get_draft_workflow.return_value = mock_workflow
            # Mock AdvancedChatAppConfigManager: feature enabled by default.
            mock_app_config = mock_app_config_manager.get_app_config.return_value
            mock_app_config.additional_features.suggested_questions_after_answer = True
            # Mock LLMGenerator
            mock_llm_generator.generate_suggested_questions_after_answer.return_value = ["Question 1", "Question 2"]
            # Mock TraceQueueManager
            mock_trace_manager_instance = mock_trace_manager_class.return_value
            # Mock TokenBufferMemory
            mock_memory_instance = mock_token_buffer_memory.return_value
            mock_memory_instance.get_history_prompt_text.return_value = "Mocked history prompt"
            yield {
                "account_feature_service": mock_account_feature_service,
                "model_manager": mock_model_manager,
                "workflow_service": mock_workflow_service,
                "app_config_manager": mock_app_config_manager,
                "llm_generator": mock_llm_generator,
                "trace_manager_class": mock_trace_manager_class,
                "trace_manager_instance": mock_trace_manager_instance,
                "token_buffer_memory": mock_token_buffer_memory,
                # "current_user": mock_current_user,
            }
    def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test app and account for testing.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies
        Returns:
            tuple: (app, account) - Created app and account instances
        """
        fake = Faker()
        # Setup mocks for account creation
        mock_external_service_dependencies[
            "account_feature_service"
        ].get_system_features.return_value.is_allow_register = True
        # Create account and tenant first
        from services.account_service import AccountService, TenantService
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant
        # Setup app creation arguments
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "advanced-chat",  # Use advanced-chat mode to use mocked workflow
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        # Create app
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
        # Setup current_user mock
        self._mock_current_user(mock_external_service_dependencies, account.id, tenant.id)
        return app, account
    def _mock_current_user(self, mock_external_service_dependencies, account_id, tenant_id):
        """
        Helper method to mock the current user for testing.

        Currently a deliberate no-op: the current_user patching below is
        disabled and kept only as a reference for re-enabling it.
        """
        # mock_external_service_dependencies["current_user"].id = account_id
        # mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
    def _create_test_conversation(self, app, account, fake):
        """
        Helper method to create a test conversation with all required fields.

        The row is flushed (not committed) so it gets an id while remaining
        part of the surrounding transaction.
        """
        from extensions.ext_database import db
        from models.model import Conversation
        conversation = Conversation(
            app_id=app.id,
            app_model_config_id=None,
            model_provider=None,
            model_id="",
            override_model_configs=None,
            mode=app.mode,
            name=fake.sentence(),
            inputs={},
            introduction="",
            system_instruction="",
            system_instruction_tokens=0,
            status="normal",
            invoke_from="console",
            from_source="console",
            from_end_user_id=None,
            from_account_id=account.id,
        )
        db.session.add(conversation)
        db.session.flush()
        return conversation
    def _create_test_message(self, app, conversation, account, fake):
        """
        Helper method to create a test message with all required fields.

        Commits so that subsequent queries in the test see the row.
        """
        import json
        from extensions.ext_database import db
        from models.model import Message
        message = Message(
            app_id=app.id,
            model_provider=None,
            model_id="",
            override_model_configs=None,
            conversation_id=conversation.id,
            inputs={},
            query=fake.sentence(),
            message=json.dumps([{"role": "user", "text": fake.sentence()}]),
            message_tokens=0,
            message_unit_price=0,
            message_price_unit=0.001,
            answer=fake.text(max_nb_chars=200),
            answer_tokens=0,
            answer_unit_price=0,
            answer_price_unit=0.001,
            parent_message_id=None,
            provider_response_latency=0,
            total_price=0,
            currency="USD",
            invoke_from="console",
            from_source="console",
            from_end_user_id=None,
            from_account_id=account.id,
        )
        db.session.add(message)
        db.session.commit()
        return message
    def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful pagination by first ID.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and multiple messages
        conversation = self._create_test_conversation(app, account, fake)
        messages = []
        for i in range(5):
            message = self._create_test_message(app, conversation, account, fake)
            messages.append(message)
        # Test pagination by first ID
        result = MessageService.pagination_by_first_id(
            app_model=app,
            user=account,
            conversation_id=conversation.id,
            first_id=messages[2].id,  # Use middle message as first_id
            limit=2,
            order="asc",
        )
        # Verify results
        assert result.limit == 2
        assert len(result.data) == 2
        # total 5, from the middle, no more
        assert result.has_more is False
        # Verify messages are in ascending order
        assert result.data[0].created_at <= result.data[1].created_at
    def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test pagination by first ID when no user is provided.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Test pagination with no user
        result = MessageService.pagination_by_first_id(
            app_model=app, user=None, conversation_id=fake.uuid4(), first_id=None, limit=10
        )
        # Verify empty result
        assert result.limit == 10
        assert len(result.data) == 0
        assert result.has_more is False
    def test_pagination_by_first_id_no_conversation_id(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test pagination by first ID when no conversation ID is provided.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Test pagination with no conversation ID
        result = MessageService.pagination_by_first_id(
            app_model=app, user=account, conversation_id="", first_id=None, limit=10
        )
        # Verify empty result
        assert result.limit == 10
        assert len(result.data) == 0
        assert result.has_more is False
    def test_pagination_by_first_id_invalid_first_id(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test pagination by first ID with invalid first_id.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        self._create_test_message(app, conversation, account, fake)
        # Test pagination with invalid first_id
        with pytest.raises(FirstMessageNotExistsError):
            MessageService.pagination_by_first_id(
                app_model=app,
                user=account,
                conversation_id=conversation.id,
                first_id=fake.uuid4(),  # Non-existent message ID
                limit=10,
            )
    def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful pagination by last ID.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and multiple messages
        conversation = self._create_test_conversation(app, account, fake)
        messages = []
        for i in range(5):
            message = self._create_test_message(app, conversation, account, fake)
            messages.append(message)
        # Test pagination by last ID
        result = MessageService.pagination_by_last_id(
            app_model=app,
            user=account,
            last_id=messages[2].id,  # Use middle message as last_id
            limit=2,
            conversation_id=conversation.id,
        )
        # Verify results
        assert result.limit == 2
        assert len(result.data) == 2
        # total 5, from the middle, no more
        assert result.has_more is False
        # Verify messages are in descending order
        assert result.data[0].created_at >= result.data[1].created_at
    def test_pagination_by_last_id_with_include_ids(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test pagination by last ID with include_ids filter.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and multiple messages
        conversation = self._create_test_conversation(app, account, fake)
        messages = []
        for i in range(5):
            message = self._create_test_message(app, conversation, account, fake)
            messages.append(message)
        # Test pagination with include_ids
        include_ids = [messages[0].id, messages[1].id, messages[2].id]
        result = MessageService.pagination_by_last_id(
            app_model=app, user=account, last_id=messages[1].id, limit=2, include_ids=include_ids
        )
        # Verify results
        assert result.limit == 2
        assert len(result.data) <= 2
        # Verify all returned messages are in include_ids
        for message in result.data:
            assert message.id in include_ids
    def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test pagination by last ID when no user is provided.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Test pagination with no user
        result = MessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
        # Verify empty result
        assert result.limit == 10
        assert len(result.data) == 0
        assert result.has_more is False
    def test_pagination_by_last_id_invalid_last_id(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test pagination by last ID with invalid last_id.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        self._create_test_message(app, conversation, account, fake)
        # Test pagination with invalid last_id
        with pytest.raises(LastMessageNotExistsError):
            MessageService.pagination_by_last_id(
                app_model=app,
                user=account,
                last_id=fake.uuid4(),  # Non-existent message ID
                limit=10,
                conversation_id=conversation.id,
            )
    def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful creation of feedback.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Create feedback
        rating = "like"
        content = fake.text(max_nb_chars=100)
        feedback = MessageService.create_feedback(
            app_model=app, message_id=message.id, user=account, rating=rating, content=content
        )
        # Verify feedback was created correctly
        assert feedback.app_id == app.id
        assert feedback.conversation_id == conversation.id
        assert feedback.message_id == message.id
        assert feedback.rating == rating
        assert feedback.content == content
        assert feedback.from_source == "admin"
        assert feedback.from_account_id == account.id
        assert feedback.from_end_user_id is None
    def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test creating feedback when no user is provided.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Test creating feedback with no user
        with pytest.raises(ValueError, match="user cannot be None"):
            MessageService.create_feedback(
                app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
            )
    def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test updating existing feedback.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Create initial feedback
        initial_rating = "like"
        initial_content = fake.text(max_nb_chars=100)
        feedback = MessageService.create_feedback(
            app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content
        )
        # Update feedback: a second create_feedback call for the same message
        # must update the existing row instead of inserting a new one.
        updated_rating = "dislike"
        updated_content = fake.text(max_nb_chars=100)
        updated_feedback = MessageService.create_feedback(
            app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content
        )
        # Verify feedback was updated correctly (same row id, new values).
        assert updated_feedback.id == feedback.id
        assert updated_feedback.rating == updated_rating
        assert updated_feedback.content == updated_content
        assert updated_feedback.rating != initial_rating
        assert updated_feedback.content != initial_content
    def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test deleting existing feedback by setting rating to None.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Create initial feedback
        feedback = MessageService.create_feedback(
            app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100)
        )
        # Delete feedback by setting rating to None
        MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None)
        # Verify feedback was deleted
        from extensions.ext_database import db
        deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first()
        assert deleted_feedback is None
    def test_create_feedback_no_rating_when_not_exists(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test creating feedback with no rating when feedback doesn't exist.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Test creating feedback with no rating when no feedback exists
        with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"):
            MessageService.create_feedback(
                app_model=app, message_id=message.id, user=account, rating=None, content=None
            )
    def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful retrieval of all message feedbacks.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create multiple conversations and messages with feedbacks
        feedbacks = []
        for i in range(3):
            conversation = self._create_test_conversation(app, account, fake)
            message = self._create_test_message(app, conversation, account, fake)
            feedback = MessageService.create_feedback(
                app_model=app,
                message_id=message.id,
                user=account,
                rating="like" if i % 2 == 0 else "dislike",
                content=f"Feedback {i}: {fake.text(max_nb_chars=50)}",
            )
            feedbacks.append(feedback)
        # Get all feedbacks
        result = MessageService.get_all_messages_feedbacks(app, page=1, limit=10)
        # Verify results
        assert len(result) == 3
        # Verify feedbacks are ordered by created_at desc
        for i in range(len(result) - 1):
            assert result[i]["created_at"] >= result[i + 1]["created_at"]
    def test_get_all_messages_feedbacks_pagination(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test pagination of message feedbacks.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create multiple conversations and messages with feedbacks
        for i in range(5):
            conversation = self._create_test_conversation(app, account, fake)
            message = self._create_test_message(app, conversation, account, fake)
            MessageService.create_feedback(
                app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
            )
        # Get feedbacks with pagination: 5 records at page size 3 -> 3 + 2.
        result_page_1 = MessageService.get_all_messages_feedbacks(app, page=1, limit=3)
        result_page_2 = MessageService.get_all_messages_feedbacks(app, page=2, limit=3)
        # Verify pagination results
        assert len(result_page_1) == 3
        assert len(result_page_2) == 2
        # Verify no overlap between pages
        page_1_ids = {feedback["id"] for feedback in result_page_1}
        page_2_ids = {feedback["id"] for feedback in result_page_2}
        assert len(page_1_ids.intersection(page_2_ids)) == 0
    def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful retrieval of message.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Get message
        retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id)
        # Verify message was retrieved correctly
        assert retrieved_message.id == message.id
        assert retrieved_message.app_id == app.id
        assert retrieved_message.conversation_id == conversation.id
        assert retrieved_message.from_source == "console"
        assert retrieved_message.from_account_id == account.id
    def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test getting message that doesn't exist.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Test getting non-existent message
        with pytest.raises(MessageNotExistsError):
            MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4())
    def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test getting message with wrong user (different account).

        A message created by one account must not be visible to another.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Create another account
        from services.account_service import AccountService, TenantService
        other_account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company())
        # Test getting message with different user
        with pytest.raises(MessageNotExistsError):
            MessageService.get_message(app_model=app, user=other_account, message_id=message.id)
    def test_get_suggested_questions_after_answer_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful generation of suggested questions after answer.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Mock the LLMGenerator to return specific questions
        mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"]
        mock_external_service_dependencies[
            "llm_generator"
        ].generate_suggested_questions_after_answer.return_value = mock_questions
        # Get suggested questions
        from core.app.entities.app_invoke_entities import InvokeFrom
        result = MessageService.get_suggested_questions_after_answer(
            app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
        )
        # Verify results
        assert result == mock_questions
        # Verify LLMGenerator was called
        mock_external_service_dependencies[
            "llm_generator"
        ].generate_suggested_questions_after_answer.assert_called_once()
        # Verify TraceQueueManager was called
        mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()
    def test_get_suggested_questions_after_answer_no_user(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test getting suggested questions when no user is provided.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Test getting suggested questions with no user
        from core.app.entities.app_invoke_entities import InvokeFrom
        with pytest.raises(ValueError, match="user cannot be None"):
            MessageService.get_suggested_questions_after_answer(
                app_model=app, user=None, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
            )
    def test_get_suggested_questions_after_answer_disabled(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test getting suggested questions when feature is disabled.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Mock the feature to be disabled (fixture default is enabled).
        mock_external_service_dependencies[
            "app_config_manager"
        ].get_app_config.return_value.additional_features.suggested_questions_after_answer = False
        # Test getting suggested questions when feature is disabled
        from core.app.entities.app_invoke_entities import InvokeFrom
        with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError):
            MessageService.get_suggested_questions_after_answer(
                app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
            )
    def test_get_suggested_questions_after_answer_no_workflow(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test getting suggested questions when no workflow exists.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Mock no workflow
        mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None
        # Get suggested questions (should return empty list)
        from core.app.entities.app_invoke_entities import InvokeFrom
        result = MessageService.get_suggested_questions_after_answer(
            app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
        )
        # Verify empty result
        assert result == []
    def test_get_suggested_questions_after_answer_debugger_mode(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test getting suggested questions in debugger mode.

        In debugger mode the draft workflow is used instead of the published
        one.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation and message
        conversation = self._create_test_conversation(app, account, fake)
        message = self._create_test_message(app, conversation, account, fake)
        # Mock questions
        mock_questions = ["Debug question 1", "Debug question 2"]
        mock_external_service_dependencies[
            "llm_generator"
        ].generate_suggested_questions_after_answer.return_value = mock_questions
        # Get suggested questions in debugger mode
        from core.app.entities.app_invoke_entities import InvokeFrom
        result = MessageService.get_suggested_questions_after_answer(
            app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.DEBUGGER
        )
        # Verify results
        assert result == mock_questions
        # Verify draft workflow was used instead of published workflow
        mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
            app_model=app
        )
        # Verify TraceQueueManager was called
        mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()

View File

@@ -0,0 +1,475 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from models.account import TenantAccountJoin, TenantAccountRole
from models.model import Account, Tenant
from models.provider import LoadBalancingModelConfig, Provider, ProviderModelSetting
from services.model_load_balancing_service import ModelLoadBalancingService
class TestModelLoadBalancingService:
"""Integration tests for ModelLoadBalancingService using testcontainers."""
    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies.

        Patches the provider/load-balancing collaborators of
        ModelLoadBalancingService and yields the mocks (plus a few convenience
        handles) so tests can tweak defaults or assert on calls.
        """
        with (
            patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager,
            patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager,
            patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory,
            patch("services.model_load_balancing_service.encrypter") as mock_encrypter,
        ):
            # Setup default mock returns
            mock_provider_manager_instance = mock_provider_manager.return_value
            # Mock provider configuration
            mock_provider_config = MagicMock()
            mock_provider_config.provider.provider = "openai"
            mock_provider_config.custom_configuration.provider = None
            # Mock provider model setting: load balancing disabled by default.
            mock_provider_model_setting = MagicMock()
            mock_provider_model_setting.load_balancing_enabled = False
            mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting
            # Mock provider configurations dict
            mock_provider_configs = {"openai": mock_provider_config}
            mock_provider_manager_instance.get_configurations.return_value = mock_provider_configs
            # Mock LBModelManager: (in_cooldown, ttl) = (False, 0).
            mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)
            # Mock ModelProviderFactory
            mock_model_provider_factory_instance = mock_model_provider_factory.return_value
            # Mock credential schemas
            mock_credential_schema = MagicMock()
            mock_credential_schema.credential_form_schemas = []
            # Mock provider configuration methods
            mock_provider_config.extract_secret_variables.return_value = []
            mock_provider_config.obfuscated_credentials.return_value = {}
            mock_provider_config._get_credential_schema.return_value = mock_credential_schema
            yield {
                "provider_manager": mock_provider_manager,
                "lb_model_manager": mock_lb_model_manager,
                "model_provider_factory": mock_model_provider_factory,
                "encrypter": mock_encrypter,
                "provider_config": mock_provider_config,
                "provider_model_setting": mock_provider_model_setting,
                "credential_schema": mock_credential_schema,
            }
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Helper method to create a test account and tenant for testing.
    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies
    Returns:
        tuple: (account, tenant) - Created account and tenant instances
    """
    from extensions.ext_database import db

    fake = Faker()

    # Persist an active account with randomized identity data.
    account = Account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        status="active",
    )
    db.session.add(account)
    db.session.commit()

    # Persist a tenant (workspace) for that account.
    tenant = Tenant(
        name=fake.company(),
        status="normal",
    )
    db.session.add(tenant)
    db.session.commit()

    # Link the account to the tenant as its current owner.
    membership = TenantAccountJoin(
        tenant_id=tenant.id,
        account_id=account.id,
        role=TenantAccountRole.OWNER,
        current=True,
    )
    db.session.add(membership)
    db.session.commit()

    account.current_tenant = tenant
    return account, tenant
def _create_test_provider_and_setting(
    self, db_session_with_containers, tenant_id, mock_external_service_dependencies
):
    """
    Helper method to create a test provider and provider model setting.
    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        tenant_id: Tenant ID for the provider
        mock_external_service_dependencies: Mock dependencies
    Returns:
        tuple: (provider, provider_model_setting) - Created provider and setting instances
    """
    from extensions.ext_database import db

    # Create provider record for the tenant.
    provider = Provider(
        tenant_id=tenant_id,
        provider_name="openai",
        provider_type="custom",
        is_valid=True,
    )
    db.session.add(provider)
    db.session.commit()

    # Create provider model setting with load balancing initially disabled.
    provider_model_setting = ProviderModelSetting(
        tenant_id=tenant_id,
        provider_name="openai",
        model_name="gpt-3.5-turbo",
        model_type="text-generation",  # Use the origin model type that matches the query
        enabled=True,
        load_balancing_enabled=False,
    )
    db.session.add(provider_model_setting)
    db.session.commit()
    return provider, provider_model_setting
def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful model load balancing enablement.
    This test verifies:
    - Proper provider configuration retrieval
    - Successful enablement of model load balancing
    - Correct method calls to provider configuration
    """
    # Arrange: Create test data
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    provider, provider_model_setting = self._create_test_provider_and_setting(
        db_session_with_containers, tenant.id, mock_external_service_dependencies
    )
    # Setup mocks for enable method
    mock_provider_config = mock_external_service_dependencies["provider_config"]
    mock_provider_config.enable_model_load_balancing = MagicMock()

    # Act: Execute the method under test
    service = ModelLoadBalancingService()
    service.enable_model_load_balancing(
        tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
    )

    # Assert: enable was delegated to the provider configuration with the
    # expected model identifier and model type.
    mock_provider_config.enable_model_load_balancing.assert_called_once()
    call_args = mock_provider_config.enable_model_load_balancing.call_args
    assert call_args.kwargs["model"] == "gpt-3.5-turbo"
    assert call_args.kwargs["model_type"].value == "llm"  # ModelType enum value

    # Verify the arranged rows are still present and valid.
    from extensions.ext_database import db

    db.session.refresh(provider)
    db.session.refresh(provider_model_setting)
    assert provider.id is not None
    assert provider_model_setting.id is not None
def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful model load balancing disablement.
    This test verifies:
    - Proper provider configuration retrieval
    - Successful disablement of model load balancing
    - Correct method calls to provider configuration
    """
    # Arrange: Create test data
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    provider, provider_model_setting = self._create_test_provider_and_setting(
        db_session_with_containers, tenant.id, mock_external_service_dependencies
    )
    # Setup mocks for disable method
    mock_provider_config = mock_external_service_dependencies["provider_config"]
    mock_provider_config.disable_model_load_balancing = MagicMock()

    # Act: Execute the method under test
    service = ModelLoadBalancingService()
    service.disable_model_load_balancing(
        tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
    )

    # Assert: disable was delegated to the provider configuration with the
    # expected model identifier and model type.
    mock_provider_config.disable_model_load_balancing.assert_called_once()
    call_args = mock_provider_config.disable_model_load_balancing.call_args
    assert call_args.kwargs["model"] == "gpt-3.5-turbo"
    assert call_args.kwargs["model_type"].value == "llm"  # ModelType enum value

    # Verify the arranged rows are still present and valid.
    from extensions.ext_database import db

    db.session.refresh(provider)
    db.session.refresh(provider_model_setting)
    assert provider.id is not None
    assert provider_model_setting.id is not None
def test_enable_model_load_balancing_provider_not_found(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test error handling when provider does not exist.
    This test verifies:
    - Proper error handling for non-existent provider
    - Correct exception type and message
    """
    # Arrange: Create test data
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    # Setup mocks to return empty provider configurations
    mock_provider_manager = mock_external_service_dependencies["provider_manager"]
    mock_provider_manager_instance = mock_provider_manager.return_value
    mock_provider_manager_instance.get_configurations.return_value = {}

    # Act & Assert: Verify proper error handling
    service = ModelLoadBalancingService()
    with pytest.raises(ValueError) as exc_info:
        service.enable_model_load_balancing(
            tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm"
        )
    # Verify correct error message
    assert "Provider nonexistent_provider does not exist." in str(exc_info.value)

    # Leave the session in a clean state for subsequent tests.
    from extensions.ext_database import db

    db.session.rollback()
def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful retrieval of load balancing configurations.
    This test verifies:
    - Proper provider configuration retrieval
    - Successful database query for load balancing configs
    - Correct return format and data structure
    """
    # Arrange: Create test data
    # NOTE(review): `fake` is unused in this test.
    fake = Faker()
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    provider, provider_model_setting = self._create_test_provider_and_setting(
        db_session_with_containers, tenant.id, mock_external_service_dependencies
    )
    # Create a single enabled load balancing config row that the service
    # should find when querying by tenant/provider/model.
    from extensions.ext_database import db

    load_balancing_config = LoadBalancingModelConfig(
        tenant_id=tenant.id,
        provider_name="openai",
        model_name="gpt-3.5-turbo",
        model_type="text-generation",  # Use the origin model type that matches the query
        name="config1",
        encrypted_config='{"api_key": "test_key"}',
        enabled=True,
    )
    db.session.add(load_balancing_config)
    db.session.commit()
    # Verify the config was created
    db.session.refresh(load_balancing_config)
    assert load_balancing_config.id is not None
    # Setup mocks for get_load_balancing_configs method:
    # load balancing is reported as enabled for the model setting.
    mock_provider_config = mock_external_service_dependencies["provider_config"]
    mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"]
    mock_provider_model_setting.load_balancing_enabled = True
    # Credential schema with no form fields (nothing to decrypt/obfuscate).
    mock_credential_schema = mock_external_service_dependencies["credential_schema"]
    mock_credential_schema.credential_form_schemas = []
    # Encrypter returns a dummy key/cipher pair for credential decoding.
    mock_encrypter = mock_external_service_dependencies["encrypter"]
    mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher")
    # Mock _get_credential_schema method
    mock_provider_config._get_credential_schema.return_value = mock_credential_schema
    # Mock extract_secret_variables method
    mock_provider_config.extract_secret_variables.return_value = []
    # Mock obfuscated_credentials method
    mock_provider_config.obfuscated_credentials.return_value = {}
    # No config is in cooldown, so in_cooldown=False and ttl=0 are expected.
    mock_lb_model_manager = mock_external_service_dependencies["lb_model_manager"]
    mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)

    # Act: Execute the method under test
    service = ModelLoadBalancingService()
    is_enabled, configs = service.get_load_balancing_configs(
        tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
    )

    # Assert: exactly the one seeded config comes back, flagged enabled and
    # not in cooldown.
    assert is_enabled is True
    assert len(configs) == 1
    assert configs[0]["id"] == load_balancing_config.id
    assert configs[0]["name"] == "config1"
    assert configs[0]["enabled"] is True
    assert configs[0]["in_cooldown"] is False
    assert configs[0]["ttl"] == 0
    # Verify database state
    db.session.refresh(load_balancing_config)
    assert load_balancing_config.id is not None
def test_get_load_balancing_configs_provider_not_found(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test error handling when provider does not exist in get_load_balancing_configs.
    This test verifies:
    - Proper error handling for non-existent provider
    - Correct exception type and message
    """
    # Arrange: Create test data
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    # Setup mocks to return empty provider configurations
    mock_provider_manager = mock_external_service_dependencies["provider_manager"]
    mock_provider_manager_instance = mock_provider_manager.return_value
    mock_provider_manager_instance.get_configurations.return_value = {}

    # Act & Assert: Verify proper error handling
    service = ModelLoadBalancingService()
    with pytest.raises(ValueError) as exc_info:
        service.get_load_balancing_configs(
            tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm"
        )
    # Verify correct error message
    assert "Provider nonexistent_provider does not exist." in str(exc_info.value)

    # Leave the session in a clean state for subsequent tests.
    from extensions.ext_database import db

    db.session.rollback()
def test_get_load_balancing_configs_with_inherit_config(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test load balancing configs retrieval with inherit configuration.
    This test verifies:
    - Proper handling of inherit configuration
    - Correct ordering of configurations
    - Inherit config initialization when needed
    """
    # Arrange: Create test data
    # NOTE(review): `fake` is unused in this test.
    fake = Faker()
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    provider, provider_model_setting = self._create_test_provider_and_setting(
        db_session_with_containers, tenant.id, mock_external_service_dependencies
    )
    # Seed one regular load balancing config; the "__inherit__" entry is
    # expected to be created by the service itself.
    from extensions.ext_database import db

    load_balancing_config = LoadBalancingModelConfig(
        tenant_id=tenant.id,
        provider_name="openai",
        model_name="gpt-3.5-turbo",
        model_type="text-generation",  # Use the origin model type that matches the query
        name="config1",
        encrypted_config='{"api_key": "test_key"}',
        enabled=True,
    )
    db.session.add(load_balancing_config)
    db.session.commit()
    # Setup mocks for inherit config scenario: a non-None custom
    # configuration provider is what triggers inherit-config handling.
    mock_provider_config = mock_external_service_dependencies["provider_config"]
    mock_provider_config.custom_configuration.provider = MagicMock()  # Enable custom config
    mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"]
    mock_provider_model_setting.load_balancing_enabled = True
    # Credential schema with no form fields (nothing to decrypt/obfuscate).
    mock_credential_schema = mock_external_service_dependencies["credential_schema"]
    mock_credential_schema.credential_form_schemas = []
    # Encrypter returns a dummy key/cipher pair for credential decoding.
    mock_encrypter = mock_external_service_dependencies["encrypter"]
    mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher")

    # Act: Execute the method under test
    service = ModelLoadBalancingService()
    is_enabled, configs = service.get_load_balancing_configs(
        tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
    )

    # Assert: the inherit config is prepended ahead of the seeded config.
    assert is_enabled is True
    assert len(configs) == 2  # inherit config + existing config
    # First config should be inherit config
    assert configs[0]["name"] == "__inherit__"
    assert configs[0]["enabled"] is True
    # Second config should be the existing config
    assert configs[1]["id"] == load_balancing_config.id
    assert configs[1]["name"] == "config1"
    # Verify database state
    db.session.refresh(load_balancing_config)
    assert load_balancing_config.id is not None
    # Verify inherit config was created in database
    inherit_configs = db.session.scalars(
        select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
    ).all()
    assert len(inherit_configs) == 1

View File

@@ -0,0 +1,620 @@
from unittest.mock import patch
import pytest
from faker import Faker
from models.model import EndUser, Message
from models.web import SavedMessage
from services.app_service import AppService
from services.saved_message_service import SavedMessageService
class TestSavedMessageService:
    """Integration tests for SavedMessageService using testcontainers.

    NOTE(review): the original file defined test_pagination_by_last_id_error_no_user,
    test_save_error_no_user and test_delete_success_existing_message TWICE each.
    Python binds only the later definition of a duplicated method name, so the
    earlier copies silently never ran. The duplicates have been removed, keeping
    the later (effective) versions so the collected test set is unchanged.
    """

    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies."""
        with (
            patch("services.account_service.FeatureService") as mock_account_feature_service,
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.saved_message_service.MessageService") as mock_message_service,
        ):
            # Allow account registration during fixture-driven setup.
            mock_account_feature_service.get_system_features.return_value.is_allow_register = True
            # ModelManager: no default model needed for app creation.
            mock_model_instance = mock_model_manager.return_value
            mock_model_instance.get_default_model_instance.return_value = None
            mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
            # MessageService defaults; individual tests override return values.
            mock_message_service.get_message.return_value = None
            mock_message_service.pagination_by_last_id.return_value = None
            yield {
                "account_feature_service": mock_account_feature_service,
                "model_manager": mock_model_manager,
                "message_service": mock_message_service,
            }

    def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test app and account for testing.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies
        Returns:
            tuple: (app, account) - Created app and account instances
        """
        fake = Faker()
        # Setup mocks for account creation
        mock_external_service_dependencies[
            "account_feature_service"
        ].get_system_features.return_value.is_allow_register = True
        # Create account and tenant first
        from services.account_service import AccountService, TenantService

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant
        # Create a chat app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
        return app, account

    def _create_test_end_user(self, db_session_with_containers, app):
        """
        Helper method to create a test end user for testing.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            app: App instance to associate the end user with
        Returns:
            EndUser: Created end user instance
        """
        fake = Faker()
        end_user = EndUser(
            tenant_id=app.tenant_id,
            app_id=app.id,
            external_user_id=fake.uuid4(),
            name=fake.name(),
            type="normal",
            session_id=fake.uuid4(),
            is_anonymous=False,
        )
        from extensions.ext_database import db

        db.session.add(end_user)
        db.session.commit()
        return end_user

    def _create_test_message(self, db_session_with_containers, app, user):
        """
        Helper method to create a test message (and its conversation) for testing.

        The user may be an Account or an EndUser; the distinction is made via
        hasattr(user, "current_tenant"), which only Account instances have.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            app: App instance to associate the message with
            user: User instance (Account or EndUser) to associate the message with
        Returns:
            Message: Created message instance
        """
        fake = Faker()
        # A message must belong to a conversation, so create one first.
        from models.model import Conversation

        conversation = Conversation(
            app_id=app.id,
            from_source="account" if hasattr(user, "current_tenant") else "end_user",
            from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
            from_account_id=user.id if hasattr(user, "current_tenant") else None,
            name=fake.sentence(nb_words=3),
            inputs={},
            status="normal",
            mode="chat",
        )
        from extensions.ext_database import db

        db.session.add(conversation)
        db.session.commit()
        # Create the message with plausible token/pricing values.
        message = Message(
            app_id=app.id,
            conversation_id=conversation.id,
            from_source="account" if hasattr(user, "current_tenant") else "end_user",
            from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
            from_account_id=user.id if hasattr(user, "current_tenant") else None,
            inputs={},
            query=fake.sentence(nb_words=5),
            message=fake.text(max_nb_chars=100),
            answer=fake.text(max_nb_chars=200),
            message_tokens=50,
            answer_tokens=100,
            message_unit_price=0.001,
            answer_unit_price=0.002,
            total_price=0.003,
            currency="USD",
            status="success",
        )
        db.session.add(message)
        db.session.commit()
        return message

    def test_pagination_by_last_id_success_with_account_user(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful pagination by last ID with account user.
        This test verifies:
        - Proper pagination with account user
        - Correct filtering by app_id and user
        - Proper role identification for account users
        - MessageService integration
        """
        # Arrange: Create test data
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create test messages
        message1 = self._create_test_message(db_session_with_containers, app, account)
        message2 = self._create_test_message(db_session_with_containers, app, account)
        # Create saved messages
        saved_message1 = SavedMessage(
            app_id=app.id,
            message_id=message1.id,
            created_by_role="account",
            created_by=account.id,
        )
        saved_message2 = SavedMessage(
            app_id=app.id,
            message_id=message2.id,
            created_by_role="account",
            created_by=account.id,
        )
        from extensions.ext_database import db

        db.session.add_all([saved_message1, saved_message2])
        db.session.commit()
        # Mock MessageService.pagination_by_last_id return value
        from libs.infinite_scroll_pagination import InfiniteScrollPagination

        mock_pagination = InfiniteScrollPagination(data=[message1, message2], limit=10, has_more=False)
        mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination

        # Act: Execute the method under test
        result = SavedMessageService.pagination_by_last_id(app_model=app, user=account, last_id=None, limit=10)

        # Assert: the mocked pagination result is passed through unchanged.
        assert result is not None
        assert result.data == [message1, message2]
        assert result.limit == 10
        assert result.has_more is False
        # Verify MessageService was called with correct parameters.
        # Sort the IDs to handle database query order variations.
        expected_include_ids = sorted([message1.id, message2.id])
        actual_call = mock_external_service_dependencies["message_service"].pagination_by_last_id.call_args
        actual_include_ids = sorted(actual_call.kwargs.get("include_ids", []))
        assert actual_call.kwargs["app_model"] == app
        assert actual_call.kwargs["user"] == account
        assert actual_call.kwargs["last_id"] is None
        assert actual_call.kwargs["limit"] == 10
        assert actual_include_ids == expected_include_ids
        # Verify database state
        db.session.refresh(saved_message1)
        db.session.refresh(saved_message2)
        assert saved_message1.id is not None
        assert saved_message2.id is not None
        assert saved_message1.created_by_role == "account"
        assert saved_message2.created_by_role == "account"

    def test_pagination_by_last_id_success_with_end_user(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful pagination by last ID with end user.
        This test verifies:
        - Proper pagination with end user
        - Correct filtering by app_id and user
        - Proper role identification for end users
        - MessageService integration
        """
        # Arrange: Create test data
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        end_user = self._create_test_end_user(db_session_with_containers, app)
        # Create test messages
        message1 = self._create_test_message(db_session_with_containers, app, end_user)
        message2 = self._create_test_message(db_session_with_containers, app, end_user)
        # Create saved messages attributed to the end user role
        saved_message1 = SavedMessage(
            app_id=app.id,
            message_id=message1.id,
            created_by_role="end_user",
            created_by=end_user.id,
        )
        saved_message2 = SavedMessage(
            app_id=app.id,
            message_id=message2.id,
            created_by_role="end_user",
            created_by=end_user.id,
        )
        from extensions.ext_database import db

        db.session.add_all([saved_message1, saved_message2])
        db.session.commit()
        # Mock MessageService.pagination_by_last_id return value
        from libs.infinite_scroll_pagination import InfiniteScrollPagination

        mock_pagination = InfiniteScrollPagination(data=[message1, message2], limit=5, has_more=True)
        mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination

        # Act: Execute the method under test
        result = SavedMessageService.pagination_by_last_id(
            app_model=app, user=end_user, last_id="test_last_id", limit=5
        )

        # Assert: the mocked pagination result is passed through unchanged.
        assert result is not None
        assert result.data == [message1, message2]
        assert result.limit == 5
        assert result.has_more is True
        # Verify MessageService was called with correct parameters.
        # Sort the IDs to handle database query order variations.
        expected_include_ids = sorted([message1.id, message2.id])
        actual_call = mock_external_service_dependencies["message_service"].pagination_by_last_id.call_args
        actual_include_ids = sorted(actual_call.kwargs.get("include_ids", []))
        assert actual_call.kwargs["app_model"] == app
        assert actual_call.kwargs["user"] == end_user
        assert actual_call.kwargs["last_id"] == "test_last_id"
        assert actual_call.kwargs["limit"] == 5
        assert actual_include_ids == expected_include_ids
        # Verify database state
        db.session.refresh(saved_message1)
        db.session.refresh(saved_message2)
        assert saved_message1.id is not None
        assert saved_message2.id is not None
        assert saved_message1.created_by_role == "end_user"
        assert saved_message2.created_by_role == "end_user"

    def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful save of a new message.
        This test verifies:
        - Proper creation of new saved message
        - Correct database state after save
        - Proper relationship establishment
        - MessageService integration for message retrieval
        """
        # Arrange: Create test data
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        message = self._create_test_message(db_session_with_containers, app, account)
        # Mock MessageService.get_message return value
        mock_external_service_dependencies["message_service"].get_message.return_value = message

        # Act: Execute the method under test
        SavedMessageService.save(app_model=app, user=account, message_id=message.id)

        # Assert: a SavedMessage row was created for this app/message/user.
        from extensions.ext_database import db

        saved_message = (
            db.session.query(SavedMessage)
            .where(
                SavedMessage.app_id == app.id,
                SavedMessage.message_id == message.id,
                SavedMessage.created_by_role == "account",
                SavedMessage.created_by == account.id,
            )
            .first()
        )
        assert saved_message is not None
        assert saved_message.app_id == app.id
        assert saved_message.message_id == message.id
        assert saved_message.created_by_role == "account"
        assert saved_message.created_by == account.id
        assert saved_message.created_at is not None
        # Verify MessageService.get_message was called
        mock_external_service_dependencies["message_service"].get_message.assert_called_once_with(
            app_model=app, user=account, message_id=message.id
        )
        # Verify database state
        db.session.refresh(saved_message)
        assert saved_message.id is not None

    def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test error handling when no user is provided.
        This test verifies:
        - Proper error handling for missing user
        - ValueError is raised when user is None
        """
        # Arrange: Create test data
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

        # Act & Assert: Verify proper error handling
        with pytest.raises(ValueError) as exc_info:
            SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
        assert "User is required" in str(exc_info.value)
        # We intentionally do not assert on the total SavedMessage count here:
        # other tests in the session may have created rows. Raising the error
        # before any query is the behavior under test.

    def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test error handling when saving message with no user.
        This test verifies:
        - Method returns early when user is None
        - No database operations are performed
        - No exceptions are raised
        """
        # Arrange: Create test data
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        message = self._create_test_message(db_session_with_containers, app, account)

        # Act: Execute the method under test with None user
        result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)

        # Assert: early return yields None and no SavedMessage row is created.
        assert result is None
        from extensions.ext_database import db

        saved_message = (
            db.session.query(SavedMessage)
            .where(
                SavedMessage.app_id == app.id,
                SavedMessage.message_id == message.id,
            )
            .first()
        )
        assert saved_message is None

    def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful deletion of an existing saved message.
        This test verifies:
        - Proper deletion of existing saved message
        - Correct database state after deletion
        - No errors during deletion process
        """
        # Arrange: Create test data
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        message = self._create_test_message(db_session_with_containers, app, account)
        # Create a saved message first
        saved_message = SavedMessage(
            app_id=app.id,
            message_id=message.id,
            created_by_role="account",
            created_by=account.id,
        )
        from extensions.ext_database import db

        db.session.add(saved_message)
        db.session.commit()
        # Sanity check: the saved message exists before deletion.
        assert (
            db.session.query(SavedMessage)
            .where(
                SavedMessage.app_id == app.id,
                SavedMessage.message_id == message.id,
                SavedMessage.created_by_role == "account",
                SavedMessage.created_by == account.id,
            )
            .first()
            is not None
        )

        # Act: Execute the method under test
        SavedMessageService.delete(app_model=app, user=account, message_id=message.id)

        # Assert: the SavedMessage row is gone...
        deleted_saved_message = (
            db.session.query(SavedMessage)
            .where(
                SavedMessage.app_id == app.id,
                SavedMessage.message_id == message.id,
                SavedMessage.created_by_role == "account",
                SavedMessage.created_by == account.id,
            )
            .first()
        )
        assert deleted_saved_message is None
        db.session.commit()
        # ...but the underlying Message itself still exists.
        assert db.session.query(Message).where(Message.id == message.id).first() is not None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,573 @@
from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
from models.model import Conversation, EndUser
from models.web import PinnedConversation
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.web_conversation_service import WebConversationService
class TestWebConversationService:
    """Integration tests for WebConversationService using testcontainers.

    Covers pagination (with/without pinned filters), pin/unpin behavior for
    both Account and EndUser callers, and error handling when the user is
    missing. Pagination receives the testcontainers session explicitly, while
    pin/unpin go through the global ``db.session``.
    """
    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies.

        Patches FeatureService/EnterpriseService/ModelManager used during app
        creation and FeatureService used during account creation, so no real
        enterprise or model-provider code is exercised.
        """
        with (
            patch("services.app_service.FeatureService") as mock_feature_service,
            patch("services.app_service.EnterpriseService") as mock_enterprise_service,
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.account_service.FeatureService") as mock_account_feature_service,
        ):
            # Setup default mock returns for app service
            mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
            mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
            mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
            # Setup default mock returns for account service (allow registration)
            mock_account_feature_service.get_system_features.return_value.is_allow_register = True
            # Mock ModelManager for model configuration
            mock_model_instance = mock_model_manager.return_value
            mock_model_instance.get_default_model_instance.return_value = None
            mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
            # Hand the individual mocks to the test so it can tweak/assert on them.
            yield {
                "feature_service": mock_feature_service,
                "enterprise_service": mock_enterprise_service,
                "model_manager": mock_model_manager,
                "account_feature_service": mock_account_feature_service,
            }
    def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test app and account for testing.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies
        Returns:
            tuple: (app, account) - Created app and account instances
        """
        fake = Faker()
        # Setup mocks for account creation
        mock_external_service_dependencies[
            "account_feature_service"
        ].get_system_features.return_value.is_allow_register = True
        # Create account and tenant
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant
        # Create app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
        return app, account
    def _create_test_end_user(self, db_session_with_containers, app):
        """
        Helper method to create a test end user for testing.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            app: App instance
        Returns:
            EndUser: Created end user instance
        """
        fake = Faker()
        end_user = EndUser(
            session_id=fake.uuid4(),
            app_id=app.id,
            type="normal",
            is_anonymous=False,
            tenant_id=app.tenant_id,
        )
        from extensions.ext_database import db
        db.session.add(end_user)
        db.session.commit()
        return end_user
    def _create_test_conversation(self, db_session_with_containers, app, user, fake):
        """
        Helper method to create a test conversation for testing.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            app: App instance
            user: User instance (Account or EndUser)
            fake: Faker instance
        Returns:
            Conversation: Created conversation instance
        """
        conversation = Conversation(
            app_id=app.id,
            app_model_config_id=app.app_model_config_id,
            model_provider="openai",
            model_id="gpt-3.5-turbo",
            mode="chat",
            name=fake.sentence(nb_words=3),
            summary=fake.text(max_nb_chars=100),
            inputs={},
            introduction=fake.text(max_nb_chars=200),
            system_instruction=fake.text(max_nb_chars=300),
            system_instruction_tokens=50,
            status="normal",
            invoke_from=InvokeFrom.WEB_APP,
            # Account-originated conversations come from the console; EndUser
            # conversations come from the API — only one of the two FK columns
            # below is populated accordingly.
            from_source="console" if isinstance(user, Account) else "api",
            from_end_user_id=user.id if isinstance(user, EndUser) else None,
            from_account_id=user.id if isinstance(user, Account) else None,
            dialogue_count=0,
            is_deleted=False,
        )
        from extensions.ext_database import db
        db.session.add(conversation)
        db.session.commit()
        return conversation
    def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful pagination by last ID with basic parameters.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create multiple conversations (5 total so limit=3 leaves more pages)
        conversations = []
        for i in range(5):
            conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
            conversations.append(conversation)
        # Test pagination without pinned filter
        result = WebConversationService.pagination_by_last_id(
            session=db_session_with_containers,
            app_model=app,
            user=account,
            last_id=None,
            limit=3,
            invoke_from=InvokeFrom.WEB_APP,
            pinned=None,
            sort_by="-updated_at",
        )
        # Verify results: first page of 3, with more remaining
        assert result.limit == 3
        assert len(result.data) == 3
        assert result.has_more is True
        # Verify conversations are in descending order by updated_at
        assert result.data[0].updated_at >= result.data[1].updated_at
        assert result.data[1].updated_at >= result.data[2].updated_at
    def test_pagination_by_last_id_with_pinned_filter(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test pagination by last ID with pinned conversation filter.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create conversations
        conversations = []
        for i in range(5):
            conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
            conversations.append(conversation)
        # Pin some conversations (indices 0 and 2) directly via the model
        pinned_conversation1 = PinnedConversation(
            app_id=app.id,
            conversation_id=conversations[0].id,
            created_by_role="account",
            created_by=account.id,
        )
        pinned_conversation2 = PinnedConversation(
            app_id=app.id,
            conversation_id=conversations[2].id,
            created_by_role="account",
            created_by=account.id,
        )
        from extensions.ext_database import db
        db.session.add(pinned_conversation1)
        db.session.add(pinned_conversation2)
        db.session.commit()
        # Test pagination with pinned filter
        result = WebConversationService.pagination_by_last_id(
            session=db_session_with_containers,
            app_model=app,
            user=account,
            last_id=None,
            limit=10,
            invoke_from=InvokeFrom.WEB_APP,
            pinned=True,
            sort_by="-updated_at",
        )
        # Verify only pinned conversations are returned
        assert result.limit == 10
        assert len(result.data) == 2
        assert result.has_more is False
        # Verify the returned conversations are the pinned ones (order-agnostic)
        returned_ids = [conv.id for conv in result.data]
        expected_ids = [conversations[0].id, conversations[2].id]
        assert set(returned_ids) == set(expected_ids)
    def test_pagination_by_last_id_with_unpinned_filter(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test pagination by last ID with unpinned conversation filter.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create conversations
        conversations = []
        for i in range(5):
            conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
            conversations.append(conversation)
        # Pin one conversation; the other four remain unpinned
        pinned_conversation = PinnedConversation(
            app_id=app.id,
            conversation_id=conversations[0].id,
            created_by_role="account",
            created_by=account.id,
        )
        from extensions.ext_database import db
        db.session.add(pinned_conversation)
        db.session.commit()
        # Test pagination with unpinned filter
        result = WebConversationService.pagination_by_last_id(
            session=db_session_with_containers,
            app_model=app,
            user=account,
            last_id=None,
            limit=10,
            invoke_from=InvokeFrom.WEB_APP,
            pinned=False,
            sort_by="-updated_at",
        )
        # Verify unpinned conversations are returned (should be 4 out of 5)
        assert result.limit == 10
        assert len(result.data) == 4
        assert result.has_more is False
        # Verify the pinned conversation is not in the results
        returned_ids = [conv.id for conv in result.data]
        assert conversations[0].id not in returned_ids
        # Verify all other conversations are in the results
        expected_unpinned_ids = [conv.id for conv in conversations[1:]]
        assert set(returned_ids) == set(expected_unpinned_ids)
    def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful pinning of a conversation.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation
        conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
        # Pin the conversation
        WebConversationService.pin(app, conversation.id, account)
        # Verify the conversation was pinned: a PinnedConversation row keyed by
        # (app, conversation, role, user) must now exist.
        from extensions.ext_database import db
        pinned_conversation = (
            db.session.query(PinnedConversation)
            .where(
                PinnedConversation.app_id == app.id,
                PinnedConversation.conversation_id == conversation.id,
                PinnedConversation.created_by_role == "account",
                PinnedConversation.created_by == account.id,
            )
            .first()
        )
        assert pinned_conversation is not None
        assert pinned_conversation.app_id == app.id
        assert pinned_conversation.conversation_id == conversation.id
        assert pinned_conversation.created_by_role == "account"
        assert pinned_conversation.created_by == account.id
    def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test pinning a conversation that is already pinned (should not create duplicate).
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation
        conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
        # Pin the conversation first time
        WebConversationService.pin(app, conversation.id, account)
        # Pin the conversation again — expected to be idempotent
        WebConversationService.pin(app, conversation.id, account)
        # Verify only one pinned conversation record exists
        from extensions.ext_database import db
        pinned_conversations = db.session.scalars(
            select(PinnedConversation).where(
                PinnedConversation.app_id == app.id,
                PinnedConversation.conversation_id == conversation.id,
                PinnedConversation.created_by_role == "account",
                PinnedConversation.created_by == account.id,
            )
        ).all()
        assert len(pinned_conversations) == 1
    def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test pinning a conversation with an end user.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create an end user
        end_user = self._create_test_end_user(db_session_with_containers, app)
        # Create a conversation for the end user
        conversation = self._create_test_conversation(db_session_with_containers, app, end_user, fake)
        # Pin the conversation
        WebConversationService.pin(app, conversation.id, end_user)
        # Verify the conversation was pinned with role "end_user" (not "account")
        from extensions.ext_database import db
        pinned_conversation = (
            db.session.query(PinnedConversation)
            .where(
                PinnedConversation.app_id == app.id,
                PinnedConversation.conversation_id == conversation.id,
                PinnedConversation.created_by_role == "end_user",
                PinnedConversation.created_by == end_user.id,
            )
            .first()
        )
        assert pinned_conversation is not None
        assert pinned_conversation.app_id == app.id
        assert pinned_conversation.conversation_id == conversation.id
        assert pinned_conversation.created_by_role == "end_user"
        assert pinned_conversation.created_by == end_user.id
    def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful unpinning of a conversation.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation
        conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
        # Pin the conversation first
        WebConversationService.pin(app, conversation.id, account)
        # Verify it was pinned (precondition for a meaningful unpin test)
        from extensions.ext_database import db
        pinned_conversation = (
            db.session.query(PinnedConversation)
            .where(
                PinnedConversation.app_id == app.id,
                PinnedConversation.conversation_id == conversation.id,
                PinnedConversation.created_by_role == "account",
                PinnedConversation.created_by == account.id,
            )
            .first()
        )
        assert pinned_conversation is not None
        # Unpin the conversation
        WebConversationService.unpin(app, conversation.id, account)
        # Verify it was unpinned
        pinned_conversation = (
            db.session.query(PinnedConversation)
            .where(
                PinnedConversation.app_id == app.id,
                PinnedConversation.conversation_id == conversation.id,
                PinnedConversation.created_by_role == "account",
                PinnedConversation.created_by == account.id,
            )
            .first()
        )
        assert pinned_conversation is None
    def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test unpinning a conversation that is not pinned (should not cause error).
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation
        conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
        # Try to unpin a conversation that was never pinned — must be a no-op
        WebConversationService.unpin(app, conversation.id, account)
        # Verify no pinned conversation record exists
        from extensions.ext_database import db
        pinned_conversation = (
            db.session.query(PinnedConversation)
            .where(
                PinnedConversation.app_id == app.id,
                PinnedConversation.conversation_id == conversation.id,
                PinnedConversation.created_by_role == "account",
                PinnedConversation.created_by == account.id,
            )
            .first()
        )
        assert pinned_conversation is None
    def test_pagination_by_last_id_user_required_error(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test that pagination_by_last_id raises ValueError when user is None.
        """
        fake = Faker()  # unused here; kept for parity with sibling tests
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Test with None user
        with pytest.raises(ValueError, match="User is required"):
            WebConversationService.pagination_by_last_id(
                session=db_session_with_containers,
                app_model=app,
                user=None,
                last_id=None,
                limit=10,
                invoke_from=InvokeFrom.WEB_APP,
                pinned=None,
                sort_by="-updated_at",
            )
    def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test that pin method returns early when user is None.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation
        conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
        # Try to pin with None user — expected to silently do nothing
        WebConversationService.pin(app, conversation.id, None)
        # Verify no pinned conversation was created (any role/user)
        from extensions.ext_database import db
        pinned_conversation = (
            db.session.query(PinnedConversation)
            .where(
                PinnedConversation.app_id == app.id,
                PinnedConversation.conversation_id == conversation.id,
            )
            .first()
        )
        assert pinned_conversation is None
    def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test that unpin method returns early when user is None.
        """
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create a conversation
        conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
        # Pin the conversation first
        WebConversationService.pin(app, conversation.id, account)
        # Verify it was pinned
        from extensions.ext_database import db
        pinned_conversation = (
            db.session.query(PinnedConversation)
            .where(
                PinnedConversation.app_id == app.id,
                PinnedConversation.conversation_id == conversation.id,
                PinnedConversation.created_by_role == "account",
                PinnedConversation.created_by == account.id,
            )
            .first()
        )
        assert pinned_conversation is not None
        # Try to unpin with None user — expected to silently do nothing
        WebConversationService.unpin(app, conversation.id, None)
        # Verify the conversation is still pinned
        pinned_conversation = (
            db.session.query(PinnedConversation)
            .where(
                PinnedConversation.app_id == app.id,
                PinnedConversation.conversation_id == conversation.id,
                PinnedConversation.created_by_role == "account",
                PinnedConversation.created_by == account.id,
            )
            .first()
        )
        assert pinned_conversation is not None

View File

@@ -0,0 +1,895 @@
import time
import uuid
from unittest.mock import patch
import pytest
from faker import Faker
from werkzeug.exceptions import NotFound, Unauthorized
from libs.password import hash_password
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import App, Site
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
class TestWebAppAuthService:
"""Integration tests for WebAppAuthService using testcontainers."""
    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies.

        Patches the collaborators WebAppAuthService looks up in its own
        module (passport/token/mail-task/app/enterprise) so no real token
        issuing, token storage, mail dispatch, or enterprise calls happen.
        """
        with (
            patch("services.webapp_auth_service.PassportService") as mock_passport_service,
            patch("services.webapp_auth_service.TokenManager") as mock_token_manager,
            patch("services.webapp_auth_service.send_email_code_login_mail_task") as mock_mail_task,
            patch("services.webapp_auth_service.AppService") as mock_app_service,
            patch("services.webapp_auth_service.EnterpriseService") as mock_enterprise_service,
        ):
            # Setup default mock returns
            mock_passport_service.return_value.issue.return_value = "mock_jwt_token"
            mock_token_manager.generate_token.return_value = "mock_token"
            mock_token_manager.get_token_data.return_value = {"code": "123456"}
            mock_mail_task.delay.return_value = None
            mock_app_service.get_app_id_by_code.return_value = "mock_app_id"
            # Ad-hoc anonymous class instance standing in for the enterprise
            # access-mode response object (only .access_mode is read).
            mock_enterprise_service.WebAppAuth.get_app_access_mode_by_id.return_value = type(
                "MockWebAppAuth", (), {"access_mode": "private"}
            )()
            # Note: get_app_access_mode_by_code method was removed in refactoring
            yield {
                "passport_service": mock_passport_service,
                "token_manager": mock_token_manager,
                "mail_task": mock_mail_task,
                "app_service": mock_app_service,
                "enterprise_service": mock_enterprise_service,
            }
    def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test account and tenant for testing.

        No password is set on the account; use
        _create_test_account_with_password when credentials are needed.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies
        Returns:
            tuple: (account, tenant) - Created account and tenant instances
        """
        fake = Faker()
        import uuid
        # Create account with unique email to avoid collisions
        unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
        account = Account(
            email=unique_email,
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        from extensions.ext_database import db
        db.session.add(account)
        db.session.commit()
        # Create tenant for the account
        tenant = Tenant(
            name=fake.company(),
            status="normal",
        )
        db.session.add(tenant)
        db.session.commit()
        # Create tenant-account join; OWNER role grants full tenant permissions
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db.session.add(join)
        db.session.commit()
        # Set current tenant for account (in-memory only; not persisted here)
        account.current_tenant = tenant
        return account, tenant
    def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test account with password for testing.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies
        Returns:
            tuple: (account, tenant, password) - Created account, tenant and password
        """
        fake = Faker()
        password = fake.password(length=12)
        # Create account with password
        import uuid
        unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
        account = Account(
            email=unique_email,
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        # Hash password
        # NOTE(review): the literal below is actually 18 bytes despite the
        # name; harmless for tests but the name is misleading.
        salt = b"test_salt_16_bytes"
        password_hash = hash_password(password, salt)
        # Convert to base64 for storage (matches how the model stores both fields)
        import base64
        account.password = base64.b64encode(password_hash).decode()
        account.password_salt = base64.b64encode(salt).decode()
        from extensions.ext_database import db
        db.session.add(account)
        db.session.commit()
        # Create tenant for the account
        tenant = Tenant(
            name=fake.company(),
            status="normal",
        )
        db.session.add(tenant)
        db.session.commit()
        # Create tenant-account join
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db.session.add(join)
        db.session.commit()
        # Set current tenant for account (in-memory only)
        account.current_tenant = tenant
        return account, tenant, password
    def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant):
        """
        Helper method to create a test app and site for testing.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies
            tenant: Tenant instance to associate with
        Returns:
            tuple: (app, site) - Created app and site instances
        """
        fake = Faker()
        # Create app (constructed directly rather than via AppService)
        app = App(
            tenant_id=tenant.id,
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
            enable_site=True,
            enable_api=True,
        )
        from extensions.ext_database import db
        db.session.add(app)
        db.session.commit()
        # Create site; `code` is a random unique 6-letter web-app access code
        site = Site(
            app_id=app.id,
            title=fake.company(),
            code=fake.unique.lexify(text="??????"),
            description=fake.text(max_nb_chars=100),
            default_language="en-US",
            status="normal",
            customize_token_strategy="not_allow",
        )
        db.session.add(site)
        db.session.commit()
        return app, site
    def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful authentication with valid email and password.
        This test verifies:
        - Proper authentication with valid credentials
        - Correct account return
        - Database state consistency
        """
        # Arrange: Create test data
        account, tenant, password = self._create_test_account_with_password(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Act: Execute authentication
        result = WebAppAuthService.authenticate(account.email, password)
        # Assert: Verify successful authentication
        assert result is not None
        assert result.id == account.id
        assert result.email == account.email
        assert result.name == account.name
        # NOTE(review): assumes the stored status value compares equal to the
        # AccountStatus enum member — confirm against the Account model.
        assert result.status == AccountStatus.ACTIVE
        # Verify database state (re-read from DB to check persisted fields)
        from extensions.ext_database import db
        db.session.refresh(result)
        assert result.id is not None
        assert result.password is not None
        assert result.password_salt is not None
    def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test authentication with non-existent email.
        This test verifies:
        - Proper error handling for non-existent accounts
        - Correct exception type and message
        """
        # Arrange: Generate a guaranteed non-existent email
        # Use UUID and timestamp to ensure uniqueness
        unique_id = str(uuid.uuid4()).replace("-", "")
        timestamp = str(int(time.time() * 1000000))  # microseconds
        non_existent_email = f"nonexistent_{unique_id}_{timestamp}@test-domain-that-never-exists.invalid"
        # Double-check this email doesn't exist in the database
        existing_account = db_session_with_containers.query(Account).filter_by(email=non_existent_email).first()
        assert existing_account is None, f"Test email {non_existent_email} already exists in database"
        # Act & Assert: Verify proper error handling
        with pytest.raises(AccountNotFoundError):
            WebAppAuthService.authenticate(non_existent_email, "any_password")
def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test authentication with banned account.
This test verifies:
- Proper error handling for banned accounts
- Correct exception type and message
"""
# Arrange: Create banned account
fake = Faker()
password = fake.password(length=12)
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status=AccountStatus.BANNED,
)
# Hash password
salt = b"test_salt_16_bytes"
password_hash = hash_password(password, salt)
# Convert to base64 for storage
import base64
account.password = base64.b64encode(password_hash).decode()
account.password_salt = base64.b64encode(salt).decode()
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
# Act & Assert: Verify proper error handling
with pytest.raises(AccountLoginError) as exc_info:
WebAppAuthService.authenticate(account.email, password)
assert "Account is banned." in str(exc_info.value)
    def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test authentication with invalid password.
        This test verifies:
        - Proper error handling for invalid passwords
        - Correct exception type and message
        """
        # Arrange: Create account with a known-good password
        account, tenant, correct_password = self._create_test_account_with_password(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Act & Assert: Verify proper error handling with wrong password
        with pytest.raises(AccountPasswordError) as exc_info:
            WebAppAuthService.authenticate(account.email, "wrong_password")
        assert "Invalid email or password." in str(exc_info.value)
    def test_authenticate_account_without_password(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test authentication for account without password.
        This test verifies:
        - Proper error handling for accounts without password
        - Correct exception type and message
        """
        # Arrange: Create account without password (password fields left unset)
        fake = Faker()
        import uuid
        unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
        account = Account(
            email=unique_email,
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        from extensions.ext_database import db
        db.session.add(account)
        db.session.commit()
        # Act & Assert: same error as a wrong password — the service must not
        # reveal whether a password is set.
        with pytest.raises(AccountPasswordError) as exc_info:
            WebAppAuthService.authenticate(account.email, "any_password")
        assert "Invalid email or password." in str(exc_info.value)
    def test_login_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful login and JWT token generation.
        This test verifies:
        - Proper JWT token generation
        - Correct token format and content
        - Mock service integration
        """
        # Arrange: Create test account
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Act: Execute login
        result = WebAppAuthService.login(account)
        # Assert: token is whatever the mocked PassportService issued
        assert result is not None
        assert result == "mock_jwt_token"
        # Verify mock service was called correctly
        mock_external_service_dependencies["passport_service"].return_value.issue.assert_called_once()
        # Inspect the JWT payload dict passed to PassportService.issue()
        call_args = mock_external_service_dependencies["passport_service"].return_value.issue.call_args[0][0]
        assert call_args["sub"] == "Web API Passport"
        assert call_args["user_id"] == account.id
        # session_id is the account email, not a random session token
        assert call_args["session_id"] == account.email
        assert call_args["token_source"] == "webapp_login_token"
        assert call_args["auth_type"] == "internal"
        assert "exp" in call_args
    def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful user retrieval through email.
        This test verifies:
        - Proper user retrieval by email
        - Correct account return
        - Database state consistency
        """
        # Arrange: Create test data
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Act: Execute user retrieval
        result = WebAppAuthService.get_user_through_email(account.email)
        # Assert: Verify successful retrieval
        assert result is not None
        assert result.id == account.id
        assert result.email == account.email
        assert result.name == account.name
        assert result.status == AccountStatus.ACTIVE
        # Verify database state (refresh confirms the row is persisted)
        from extensions.ext_database import db
        db.session.refresh(result)
        assert result.id is not None
def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test user retrieval with non-existent email.
This test verifies:
- Proper handling for non-existent users
- Correct return value (None)
"""
# Arrange: Use non-existent email
fake = Faker()
non_existent_email = fake.email()
# Act: Execute user retrieval
result = WebAppAuthService.get_user_through_email(non_existent_email)
# Assert: Verify proper handling
assert result is None
    def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test user retrieval with banned account.
        This test verifies:
        - Proper error handling for banned accounts
        - Correct exception type and message
        """
        # Arrange: Create banned account (no password needed for this path)
        fake = Faker()
        import uuid
        unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
        account = Account(
            email=unique_email,
            name=fake.name(),
            interface_language="en-US",
            status=AccountStatus.BANNED,
        )
        from extensions.ext_database import db
        db.session.add(account)
        db.session.commit()
        # Act & Assert: banned lookups raise HTTP 401 Unauthorized
        with pytest.raises(Unauthorized) as exc_info:
            WebAppAuthService.get_user_through_email(account.email)
        assert "Account is banned." in str(exc_info.value)
    def test_send_email_code_login_email_with_account(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test sending email code login email with account.
        This test verifies:
        - Proper email code generation
        - Token generation with correct data
        - Mail task scheduling
        - Mock service integration
        """
        # Arrange: Create test account
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Act: Execute email code login email sending
        result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US")
        # Assert: returned token is the mocked TokenManager value
        assert result is not None
        assert result == "mock_token"
        # Verify mock services were called correctly
        mock_external_service_dependencies["token_manager"].generate_token.assert_called_once()
        mock_external_service_dependencies["mail_task"].delay.assert_called_once()
        # Verify token generation parameters (all passed as keywords)
        token_call_args = mock_external_service_dependencies["token_manager"].generate_token.call_args
        assert token_call_args[1]["account"] == account
        assert token_call_args[1]["email"] == account.email
        assert token_call_args[1]["token_type"] == "email_code_login"
        assert "code" in token_call_args[1]["additional_data"]
        # Verify mail task parameters (the generated code rides along in kwargs)
        mail_call_args = mock_external_service_dependencies["mail_task"].delay.call_args
        assert mail_call_args[1]["language"] == "en-US"
        assert mail_call_args[1]["to"] == account.email
        assert "code" in mail_call_args[1]
def test_send_email_code_login_email_with_email_only(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test sending email code login email with email only.
This test verifies:
- Proper email code generation without account
- Token generation with email only
- Mail task scheduling
- Mock service integration
"""
# Arrange: Use test email
fake = Faker()
test_email = fake.email()
# Act: Execute email code login email sending
result = WebAppAuthService.send_email_code_login_email(email=test_email, language="zh-Hans")
# Assert: Verify successful email sending
assert result is not None
assert result == "mock_token"
# Verify mock services were called correctly
mock_external_service_dependencies["token_manager"].generate_token.assert_called_once()
mock_external_service_dependencies["mail_task"].delay.assert_called_once()
# Verify token generation parameters
token_call_args = mock_external_service_dependencies["token_manager"].generate_token.call_args
assert token_call_args[1]["account"] is None
assert token_call_args[1]["email"] == test_email
assert token_call_args[1]["token_type"] == "email_code_login"
assert "code" in token_call_args[1]["additional_data"]
# Verify mail task parameters
mail_call_args = mock_external_service_dependencies["mail_task"].delay.call_args
assert mail_call_args[1]["language"] == "zh-Hans"
assert mail_call_args[1]["to"] == test_email
assert "code" in mail_call_args[1]
def test_send_email_code_login_email_no_email_provided(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test sending email code login email without providing email.
This test verifies:
- Proper error handling when no email is provided
- Correct exception type and message
"""
# Arrange: No email provided
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
WebAppAuthService.send_email_code_login_email()
assert "Email must be provided." in str(exc_info.value)
def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful retrieval of email code login data.
This test verifies:
- Proper token data retrieval
- Correct data format
- Mock service integration
"""
# Arrange: Setup mock return
expected_data = {"code": "123456", "email": "test@example.com"}
mock_external_service_dependencies["token_manager"].get_token_data.return_value = expected_data
# Act: Execute data retrieval
result = WebAppAuthService.get_email_code_login_data("mock_token")
# Assert: Verify successful retrieval
assert result is not None
assert result == expected_data
assert result["code"] == "123456"
assert result["email"] == "test@example.com"
# Verify mock service was called correctly
mock_external_service_dependencies["token_manager"].get_token_data.assert_called_once_with(
"mock_token", "email_code_login"
)
def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test email code login data retrieval when no data exists.
This test verifies:
- Proper handling when no token data exists
- Correct return value (None)
- Mock service integration
"""
# Arrange: Setup mock return for no data
mock_external_service_dependencies["token_manager"].get_token_data.return_value = None
# Act: Execute data retrieval
result = WebAppAuthService.get_email_code_login_data("invalid_token")
# Assert: Verify proper handling
assert result is None
# Verify mock service was called correctly
mock_external_service_dependencies["token_manager"].get_token_data.assert_called_once_with(
"invalid_token", "email_code_login"
)
def test_revoke_email_code_login_token_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful revocation of email code login token.
This test verifies:
- Proper token revocation
- Mock service integration
"""
# Arrange: Setup mock
# Act: Execute token revocation
WebAppAuthService.revoke_email_code_login_token("mock_token")
# Assert: Verify mock service was called correctly
mock_external_service_dependencies["token_manager"].revoke_token.assert_called_once_with(
"mock_token", "email_code_login"
)
def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful end user creation.
This test verifies:
- Proper end user creation with valid app code
- Correct database state after creation
- Proper relationship establishment
- Mock service integration
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
app, site = self._create_test_app_and_site(
db_session_with_containers, mock_external_service_dependencies, tenant
)
# Act: Execute end user creation
result = WebAppAuthService.create_end_user(site.code, "test@example.com")
# Assert: Verify successful creation
assert result is not None
assert result.tenant_id == app.tenant_id
assert result.app_id == app.id
assert result.type == "browser"
assert result.is_anonymous is False
assert result.session_id == "test@example.com"
assert result.name == "enterpriseuser"
assert result.external_user_id == "enterpriseuser"
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
assert result.id is not None
assert result.created_at is not None
assert result.updated_at is not None
def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test end user creation with non-existent site code.
This test verifies:
- Proper error handling for non-existent sites
- Correct exception type and message
"""
# Arrange: Use non-existent site code
fake = Faker()
non_existent_code = fake.unique.lexify(text="??????")
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
WebAppAuthService.create_end_user(non_existent_code, "test@example.com")
assert "Site not found." in str(exc_info.value)
def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test end user creation when app is not found.
This test verifies:
- Proper error handling when app is missing
- Correct exception type and message
"""
# Arrange: Create site without app
fake = Faker()
tenant = Tenant(
name=fake.company(),
status="normal",
)
from extensions.ext_database import db
db.session.add(tenant)
db.session.commit()
site = Site(
app_id="00000000-0000-0000-0000-000000000000",
title=fake.company(),
code=fake.unique.lexify(text="??????"),
description=fake.text(max_nb_chars=100),
default_language="en-US",
status="normal",
customize_token_strategy="not_allow",
)
db.session.add(site)
db.session.commit()
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
WebAppAuthService.create_end_user(site.code, "test@example.com")
assert "App not found." in str(exc_info.value)
def test_is_app_require_permission_check_with_access_mode_private(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test permission check requirement for private access mode.
This test verifies:
- Proper permission check requirement for private mode
- Correct return value
- Mock service integration
"""
# Arrange: Setup test with private access mode
# Act: Execute permission check requirement test
result = WebAppAuthService.is_app_require_permission_check(access_mode="private")
# Assert: Verify correct result
assert result is True
def test_is_app_require_permission_check_with_access_mode_public(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test permission check requirement for public access mode.
This test verifies:
- Proper permission check requirement for public mode
- Correct return value
- Mock service integration
"""
# Arrange: Setup test with public access mode
# Act: Execute permission check requirement test
result = WebAppAuthService.is_app_require_permission_check(access_mode="public")
# Assert: Verify correct result
assert result is False
def test_is_app_require_permission_check_with_app_code(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test permission check requirement using app code.
This test verifies:
- Proper permission check requirement using app code
- Correct return value
- Mock service integration
"""
# Arrange: Setup mock for app service
mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id"
# Act: Execute permission check requirement test
result = WebAppAuthService.is_app_require_permission_check(app_code="mock_app_code")
# Assert: Verify correct result
assert result is True
# Verify mock service was called correctly
mock_external_service_dependencies["app_service"].get_app_id_by_code.assert_called_once_with("mock_app_code")
mock_external_service_dependencies[
"enterprise_service"
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id")
def test_is_app_require_permission_check_no_parameters(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test permission check requirement with no parameters.
This test verifies:
- Proper error handling when no parameters provided
- Correct exception type and message
"""
# Arrange: No parameters provided
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
WebAppAuthService.is_app_require_permission_check()
assert "Either app_code or app_id must be provided." in str(exc_info.value)
def test_get_app_auth_type_with_access_mode_public(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test app authentication type for public access mode.
This test verifies:
- Proper authentication type determination for public mode
- Correct return value
- Mock service integration
"""
# Arrange: Setup test with public access mode
# Act: Execute authentication type determination
result = WebAppAuthService.get_app_auth_type(access_mode="public")
# Assert: Verify correct result
assert result == WebAppAuthType.PUBLIC
def test_get_app_auth_type_with_access_mode_private(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test app authentication type for private access mode.
This test verifies:
- Proper authentication type determination for private mode
- Correct return value
- Mock service integration
"""
# Arrange: Setup test with private access mode
# Act: Execute authentication type determination
result = WebAppAuthService.get_app_auth_type(access_mode="private")
# Assert: Verify correct result
assert result == WebAppAuthType.INTERNAL
def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app authentication type using app code.
This test verifies:
- Proper authentication type determination using app code
- Correct return value
- Mock service integration
"""
# Arrange: Setup mock for enterprise service
mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id"
setting = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})()
mock_external_service_dependencies[
"enterprise_service"
].WebAppAuth.get_app_access_mode_by_id.return_value = setting
# Act: Execute authentication type determination
result: WebAppAuthType = WebAppAuthService.get_app_auth_type(app_code="mock_app_code")
# Assert: Verify correct result
assert result == WebAppAuthType.EXTERNAL
# Verify mock service was called correctly
mock_external_service_dependencies[
"enterprise_service"
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id")
def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test app authentication type with no parameters.
This test verifies:
- Proper error handling when no parameters provided
- Correct exception type and message
"""
# Arrange: No parameters provided
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
WebAppAuthService.get_app_auth_type()
assert "Either app_code or access_mode must be provided." in str(exc_info.value)

View File

@@ -0,0 +1,571 @@
import json
from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from flask import Flask
from werkzeug.datastructures import FileStorage
from models.enums import AppTriggerStatus, AppTriggerType
from models.model import App
from models.trigger import AppTrigger, WorkflowWebhookTrigger
from models.workflow import Workflow
from services.account_service import AccountService, TenantService
from services.trigger.webhook_service import WebhookService
class TestWebhookService:
"""Integration tests for WebhookService using testcontainers."""
    @pytest.fixture
    def mock_external_dependencies(self):
        """Mock external service dependencies.

        Patches the async workflow service, tool file manager, file factory
        and feature service used by WebhookService, and yields the mocks so
        tests can assert on how they were called.
        """
        with (
            patch("services.trigger.webhook_service.AsyncWorkflowService") as mock_async_service,
            patch("services.trigger.webhook_service.ToolFileManager") as mock_tool_file_manager,
            patch("services.trigger.webhook_service.file_factory") as mock_file_factory,
            patch("services.account_service.FeatureService") as mock_feature_service,
        ):
            # Mock ToolFileManager: the class returns a mock instance
            mock_tool_file_instance = MagicMock()
            mock_tool_file_manager.return_value = mock_tool_file_instance
            # Mock file creation: raw uploads resolve to a fixed file id
            mock_tool_file = MagicMock()
            mock_tool_file.id = "test_file_id"
            mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
            # Mock file factory
            mock_file_obj = MagicMock()
            mock_file_factory.build_from_mapping.return_value = mock_file_obj
            # Mock feature service (needed by account creation in fixtures)
            mock_feature_service.get_system_features.return_value.is_allow_register = True
            mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
            yield {
                "async_service": mock_async_service,
                "tool_file_manager": mock_tool_file_manager,
                "file_factory": mock_file_factory,
                "tool_file": mock_tool_file,
                "file_obj": mock_file_obj,
                "feature_service": mock_feature_service,
            }
    @pytest.fixture
    def test_data(self, db_session_with_containers, mock_external_dependencies):
        """Create test data for webhook service tests.

        Builds the full chain a webhook invocation needs: an account with its
        owner tenant, a workflow-mode app, a draft workflow containing one
        webhook node, the webhook trigger row, and an enabled AppTrigger
        (required for non-debug dispatch). Returns them all in a dict.
        """
        fake = Faker()
        # Create account and tenant
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant
        assert tenant is not None
        # Create app
        app = App(
            tenant_id=tenant.id,
            name=fake.company(),
            description=fake.text(),
            mode="workflow",
            icon="",
            icon_background="",
            enable_site=True,
            enable_api=True,
        )
        db_session_with_containers.add(app)
        db_session_with_containers.flush()
        # Create workflow: a single webhook node exercising every config
        # facet — required/optional headers, query params, typed body fields
        # (string/number/file) and a custom response definition.
        workflow_data = {
            "nodes": [
                {
                    "id": "webhook_node",
                    "type": "webhook",
                    "data": {
                        "title": "Test Webhook",
                        "method": "post",
                        "content_type": "application/json",
                        "headers": [
                            {"name": "Authorization", "required": True},
                            {"name": "Content-Type", "required": False},
                        ],
                        "params": [{"name": "version", "required": True}, {"name": "format", "required": False}],
                        "body": [
                            {"name": "message", "type": "string", "required": True},
                            {"name": "count", "type": "number", "required": False},
                            {"name": "upload", "type": "file", "required": False},
                        ],
                        "status_code": 200,
                        "response_body": '{"status": "success"}',
                        "timeout": 30,
                    },
                }
            ],
            "edges": [],
        }
        workflow = Workflow(
            tenant_id=tenant.id,
            app_id=app.id,
            type="workflow",
            graph=json.dumps(workflow_data),
            features=json.dumps({}),
            created_by=account.id,
            environment_variables=[],
            conversation_variables=[],
            version="1.0",
        )
        db_session_with_containers.add(workflow)
        db_session_with_containers.flush()
        # Create webhook trigger; a short 16-char id is enough for lookups
        webhook_id = fake.uuid4()[:16]
        webhook_trigger = WorkflowWebhookTrigger(
            app_id=app.id,
            node_id="webhook_node",
            tenant_id=tenant.id,
            webhook_id=str(webhook_id),
            created_by=account.id,
        )
        db_session_with_containers.add(webhook_trigger)
        db_session_with_containers.flush()
        # Create app trigger (required for non-debug mode)
        app_trigger = AppTrigger(
            tenant_id=tenant.id,
            app_id=app.id,
            node_id="webhook_node",
            trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
            provider_name="webhook",
            title="Test Webhook",
            status=AppTriggerStatus.ENABLED,
        )
        db_session_with_containers.add(app_trigger)
        db_session_with_containers.commit()
        return {
            "tenant": tenant,
            "account": account,
            "app": app,
            "workflow": workflow,
            "webhook_trigger": webhook_trigger,
            "webhook_id": webhook_id,
            "app_trigger": app_trigger,
        }
def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers):
"""Test successful retrieval of webhook trigger and workflow."""
webhook_id = test_data["webhook_id"]
with flask_app_with_containers.app_context():
webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)
assert webhook_trigger is not None
assert webhook_trigger.webhook_id == webhook_id
assert workflow is not None
assert workflow.app_id == test_data["app"].id
assert node_config is not None
assert node_config["id"] == "webhook_node"
assert node_config["data"]["title"] == "Test Webhook"
def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers):
"""Test webhook trigger not found scenario."""
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Webhook not found"):
WebhookService.get_webhook_trigger_and_workflow("nonexistent_webhook")
def test_extract_webhook_data_json(self):
"""Test webhook data extraction from JSON request."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/json", "Authorization": "Bearer token"},
query_string="version=1&format=json",
json={"message": "hello", "count": 42},
):
webhook_trigger = MagicMock()
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
assert webhook_data["method"] == "POST"
assert webhook_data["headers"]["Authorization"] == "Bearer token"
assert webhook_data["query_params"]["version"] == "1"
assert webhook_data["query_params"]["format"] == "json"
assert webhook_data["body"]["message"] == "hello"
assert webhook_data["body"]["count"] == 42
assert webhook_data["files"] == {}
def test_extract_webhook_data_form_urlencoded(self):
"""Test webhook data extraction from form URL encoded request."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={"username": "test", "password": "secret"},
):
webhook_trigger = MagicMock()
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["username"] == "test"
assert webhook_data["body"]["password"] == "secret"
def test_extract_webhook_data_multipart_with_files(self, mock_external_dependencies):
"""Test webhook data extraction from multipart form with files."""
app = Flask(__name__)
# Create a mock file
file_content = b"test file content"
file_storage = FileStorage(stream=BytesIO(file_content), filename="test.txt", content_type="text/plain")
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "multipart/form-data"},
data={"message": "test", "upload": file_storage},
):
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["message"] == "test"
assert "upload" in webhook_data["files"]
# Verify file processing was called
mock_external_dependencies["tool_file_manager"].assert_called_once()
mock_external_dependencies["file_factory"].build_from_mapping.assert_called_once()
def test_extract_webhook_data_raw_text(self):
"""Test webhook data extraction from raw text request."""
app = Flask(__name__)
with app.test_request_context(
"/webhook", method="POST", headers={"Content-Type": "text/plain"}, data="raw text content"
):
webhook_trigger = MagicMock()
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["raw"] == "raw text content"
def test_extract_and_validate_webhook_request_success(self):
"""Test successful webhook request validation and type conversion."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/json", "Authorization": "Bearer token"},
query_string="version=1",
json={"message": "hello"},
):
webhook_trigger = MagicMock()
node_config = {
"data": {
"method": "post",
"content_type": "application/json",
"headers": [
{"name": "Authorization", "required": True},
{"name": "Content-Type", "required": False},
],
"params": [{"name": "version", "required": True}],
"body": [{"name": "message", "type": "string", "required": True}],
}
}
result = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
assert result["headers"]["Authorization"] == "Bearer token"
assert result["query_params"]["version"] == "1"
assert result["body"]["message"] == "hello"
def test_extract_and_validate_webhook_request_method_mismatch(self):
"""Test webhook validation with HTTP method mismatch."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="GET",
headers={"Content-Type": "application/json"},
):
webhook_trigger = MagicMock()
node_config = {"data": {"method": "post", "content_type": "application/json"}}
with pytest.raises(ValueError, match="HTTP method mismatch"):
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
def test_extract_and_validate_webhook_request_missing_required_header(self):
"""Test webhook validation with missing required header."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/json"},
):
webhook_trigger = MagicMock()
node_config = {
"data": {
"method": "post",
"content_type": "application/json",
"headers": [{"name": "Authorization", "required": True}],
}
}
with pytest.raises(ValueError, match="Required header missing: Authorization"):
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
def test_extract_and_validate_webhook_request_case_insensitive_headers(self):
"""Test webhook validation with case-insensitive header matching."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/json", "authorization": "Bearer token"},
json={"message": "hello"},
):
webhook_trigger = MagicMock()
node_config = {
"data": {
"method": "post",
"content_type": "application/json",
"headers": [{"name": "Authorization", "required": True}],
"body": [{"name": "message", "type": "string", "required": True}],
}
}
result = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
assert result["headers"].get("Authorization") == "Bearer token"
def test_extract_and_validate_webhook_request_missing_required_param(self):
"""Test webhook validation with missing required query parameter."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/json"},
json={"message": "hello"},
):
webhook_trigger = MagicMock()
node_config = {
"data": {
"method": "post",
"content_type": "application/json",
"params": [{"name": "version", "required": True}],
"body": [{"name": "message", "type": "string", "required": True}],
}
}
with pytest.raises(ValueError, match="Required parameter missing: version"):
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
def test_extract_and_validate_webhook_request_missing_required_body_param(self):
"""Test webhook validation with missing required body parameter."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/json"},
json={},
):
webhook_trigger = MagicMock()
node_config = {
"data": {
"method": "post",
"content_type": "application/json",
"body": [{"name": "message", "type": "string", "required": True}],
}
}
with pytest.raises(ValueError, match="Required body parameter missing: message"):
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
def test_extract_and_validate_webhook_request_missing_required_file(self):
"""Test webhook validation when required file is missing from multipart request."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
data={"note": "test"},
content_type="multipart/form-data",
):
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "tenant"
webhook_trigger.created_by = "user"
node_config = {
"data": {
"method": "post",
"content_type": "multipart/form-data",
"body": [{"name": "upload", "type": "file", "required": True}],
}
}
result = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
assert result["files"] == {}
    def test_trigger_workflow_execution_success(self, test_data, mock_external_dependencies, flask_app_with_containers):
        """Test successful workflow execution trigger.

        With the tenant-owner SELECT and the DB Session patched out, triggering
        a workflow from extracted webhook data must complete without raising
        and hand the run off to AsyncWorkflowService exactly once.
        """
        # Representative payload in the shape extract_webhook_data produces
        webhook_data = {
            "method": "POST",
            "headers": {"Authorization": "Bearer token"},
            "query_params": {"version": "1"},
            "body": {"message": "hello"},
            "files": {},
        }
        with flask_app_with_containers.app_context():
            # Mock tenant owner lookup to return the test account
            with patch("services.trigger.webhook_service.select") as mock_select:
                mock_query = MagicMock()
                # Mirrors the select(...).join(...).where(...) chain in the service
                mock_select.return_value.join.return_value.where.return_value = mock_query
                # Mock the session to return our test account
                with patch("services.trigger.webhook_service.Session") as mock_session:
                    mock_session_instance = MagicMock()
                    mock_session.return_value.__enter__.return_value = mock_session_instance
                    mock_session_instance.scalar.return_value = test_data["account"]
                    # Should not raise any exceptions
                    WebhookService.trigger_workflow_execution(
                        test_data["webhook_trigger"], webhook_data, test_data["workflow"]
                    )
                    # Verify AsyncWorkflowService was called
                    mock_external_dependencies["async_service"].trigger_workflow_async.assert_called_once()
def test_trigger_workflow_execution_end_user_service_failure(
self, test_data, mock_external_dependencies, flask_app_with_containers
):
"""Test workflow execution trigger when EndUserService fails."""
webhook_data = {"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}}
with flask_app_with_containers.app_context():
# Mock EndUserService to raise an exception
with patch(
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type"
) as mock_end_user:
mock_end_user.side_effect = ValueError("Failed to create end user")
with pytest.raises(ValueError, match="Failed to create end user"):
WebhookService.trigger_workflow_execution(
test_data["webhook_trigger"], webhook_data, test_data["workflow"]
)
def test_generate_webhook_response_default(self):
"""Test webhook response generation with default values."""
node_config = {"data": {}}
response_data, status_code = WebhookService.generate_webhook_response(node_config)
assert status_code == 200
assert response_data["status"] == "success"
assert "Webhook processed successfully" in response_data["message"]
def test_generate_webhook_response_custom_json(self):
"""Test webhook response generation with custom JSON response."""
node_config = {"data": {"status_code": 201, "response_body": '{"result": "created", "id": 123}'}}
response_data, status_code = WebhookService.generate_webhook_response(node_config)
assert status_code == 201
assert response_data["result"] == "created"
assert response_data["id"] == 123
def test_generate_webhook_response_custom_text(self):
"""Test webhook response generation with custom text response."""
node_config = {"data": {"status_code": 202, "response_body": "Request accepted for processing"}}
response_data, status_code = WebhookService.generate_webhook_response(node_config)
assert status_code == 202
assert response_data["message"] == "Request accepted for processing"
def test_generate_webhook_response_invalid_json(self):
"""Test webhook response generation with invalid JSON response."""
node_config = {"data": {"status_code": 400, "response_body": '{"invalid": json}'}}
response_data, status_code = WebhookService.generate_webhook_response(node_config)
assert status_code == 400
assert response_data["message"] == '{"invalid": json}'
def test_process_file_uploads_success(self, mock_external_dependencies):
"""Test successful file upload processing."""
# Create mock files
files = {
"file1": MagicMock(filename="test1.txt", content_type="text/plain"),
"file2": MagicMock(filename="test2.jpg", content_type="image/jpeg"),
}
# Mock file reads
files["file1"].read.return_value = b"content1"
files["file2"].read.return_value = b"content2"
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
result = WebhookService._process_file_uploads(files, webhook_trigger)
assert len(result) == 2
assert "file1" in result
assert "file2" in result
# Verify file processing was called for each file
assert mock_external_dependencies["tool_file_manager"].call_count == 2
assert mock_external_dependencies["file_factory"].build_from_mapping.call_count == 2
def test_process_file_uploads_with_errors(self, mock_external_dependencies):
"""Test file upload processing with errors."""
# Create mock files, one will fail
files = {
"good_file": MagicMock(filename="test.txt", content_type="text/plain"),
"bad_file": MagicMock(filename="test.bad", content_type="text/plain"),
}
files["good_file"].read.return_value = b"content"
files["bad_file"].read.side_effect = Exception("Read error")
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
result = WebhookService._process_file_uploads(files, webhook_trigger)
# Should process the good file and skip the bad one
assert len(result) == 1
assert "good_file" in result
assert "bad_file" not in result
def test_process_file_uploads_empty_filename(self, mock_external_dependencies):
"""Test file upload processing with empty filename."""
files = {
"no_filename": MagicMock(filename="", content_type="text/plain"),
"none_filename": MagicMock(filename=None, content_type="text/plain"),
}
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
result = WebhookService._process_file_uploads(files, webhook_trigger)
# Should skip files without filenames
assert len(result) == 0
mock_external_dependencies["tool_file_manager"].assert_not_called()

View File

@@ -0,0 +1,841 @@
import pytest
from faker import Faker
from core.variables.segments import StringSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from models import App, Workflow
from models.enums import DraftVariableType
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import (
UpdateNotSupportedError,
WorkflowDraftVariableService,
)
def _get_random_variable_name(fake: Faker) -> str:
    """Return a random 10-letter variable name generated from *fake*."""
    return "".join(fake.random_letters(length=10))
class TestWorkflowDraftVariableService:
"""
Comprehensive integration tests for WorkflowDraftVariableService using testcontainers.
This test class covers all major functionality of the WorkflowDraftVariableService:
- CRUD operations for workflow draft variables (Create, Read, Update, Delete)
- Variable listing and filtering by type (conversation, system, node)
- Variable updates and resets with proper validation
- Variable deletion operations at different scopes
- Special functionality like prefill and conversation ID retrieval
- Error handling for various edge cases and invalid operations
All tests use the testcontainers infrastructure to ensure proper database isolation
and realistic testing environment with actual database interactions.
"""
    @pytest.fixture
    def mock_external_service_dependencies(self) -> dict:
        """
        Mock setup for external service dependencies.

        WorkflowDraftVariableService doesn't have external dependencies that need mocking,
        so this fixture returns an empty dictionary to maintain consistency with other test classes.
        This ensures the test structure remains consistent across different service test files.
        """
        # WorkflowDraftVariableService doesn't have external dependencies that need mocking
        return {}
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None):
    """
    Create and persist a workflow-mode App populated with realistic fake data.

    Args:
        db_session_with_containers: Database session from the testcontainers infrastructure.
        mock_external_service_dependencies: Mock dependencies (unused by this service).
        fake: Optional Faker instance; a fresh one is created when omitted.

    Returns:
        App: The committed test app.
    """
    fake = fake or Faker()
    app = App()
    app.id = fake.uuid4()
    app.tenant_id = fake.uuid4()
    app.name = fake.company()
    app.description = fake.text()
    # Workflow mode is required so draft variables can be attached to this app.
    app.mode = "workflow"
    app.icon_type = "emoji"
    app.icon = "🤖"
    app.icon_background = "#FFEAD5"
    app.enable_site = True
    app.enable_api = True
    app.created_by = fake.uuid4()
    app.updated_by = app.created_by
    from extensions.ext_database import db

    db.session.add(app)
    db.session.commit()
    return app
def _create_test_workflow(self, db_session_with_containers, app, fake=None):
    """
    Create and persist a draft Workflow bound to *app*.

    Uses the Workflow.new() factory so required defaults are populated, with
    an empty graph and no declared variables of any kind.

    Args:
        db_session_with_containers: Database session from the testcontainers infrastructure.
        app: The app to associate the workflow with.
        fake: Optional Faker instance; a fresh one is created when omitted.

    Returns:
        Workflow: The committed draft workflow.
    """
    fake = fake or Faker()
    workflow = Workflow.new(
        tenant_id=app.tenant_id,
        app_id=app.id,
        type="workflow",
        version="draft",
        graph='{"nodes": [], "edges": []}',
        features="{}",
        created_by=app.created_by,
        environment_variables=[],
        conversation_variables=[],
        rag_pipeline_variables=[],
    )
    from extensions.ext_database import db

    db.session.add(workflow)
    db.session.commit()
    return workflow
def _create_test_variable(
    self,
    db_session_with_containers,
    app_id,
    node_id,
    name,
    value,
    variable_type: DraftVariableType = DraftVariableType.CONVERSATION,
    fake=None,
):
    """
    Create and persist a WorkflowDraftVariable of the requested type.

    Dispatches to the matching WorkflowDraftVariable factory based on
    *variable_type*. Comparison is against the DraftVariableType members
    themselves: the previous string literal "system" does not match
    DraftVariableType.SYS, so enum callers silently fell through to the
    node-variable factory.

    Args:
        db_session_with_containers: Database session from the testcontainers infrastructure.
        app_id: ID of the app to associate the variable with.
        node_id: Node ID (or CONVERSATION_VARIABLE_NODE_ID / SYSTEM_VARIABLE_NODE_ID).
        name: Variable name used for identification.
        value: StringSegment holding the variable content.
        variable_type: DraftVariableType selecting which factory to use.
        fake: Optional Faker instance; a fresh one is created when omitted.

    Returns:
        WorkflowDraftVariable: The committed variable.
    """
    fake = fake or Faker()
    if variable_type == DraftVariableType.CONVERSATION:
        # Conversation variables are user-editable by construction.
        variable = WorkflowDraftVariable.new_conversation_variable(
            app_id=app_id,
            name=name,
            value=value,
            description=fake.text(max_nb_chars=20),
        )
    elif variable_type == DraftVariableType.SYS:
        # System variables carry an execution context and an explicit editable flag.
        variable = WorkflowDraftVariable.new_sys_variable(
            app_id=app_id,
            name=name,
            value=value,
            node_execution_id=fake.uuid4(),
            editable=True,
        )
    else:  # node variable
        # Node variables additionally track visibility.
        variable = WorkflowDraftVariable.new_node_variable(
            app_id=app_id,
            node_id=node_id,
            name=name,
            value=value,
            node_execution_id=fake.uuid4(),
            visible=True,
            editable=True,
        )
    from extensions.ext_database import db

    db.session.add(variable)
    db.session.commit()
    return variable
def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
    """get_variable() returns the stored record by ID with metadata and value intact."""
    fake = Faker()
    # Arrange: one conversation variable persisted for a fresh app.
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    test_value = StringSegment(value=fake.word())
    variable = self._create_test_variable(
        db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
    )
    # Act
    service = WorkflowDraftVariableService(db_session_with_containers)
    retrieved_variable = service.get_variable(variable.id)
    # Assert: identity, metadata and value all round-trip.
    assert retrieved_variable is not None
    assert retrieved_variable.id == variable.id
    assert retrieved_variable.name == "test_var"
    assert retrieved_variable.app_id == app.id
    assert retrieved_variable.get_value().value == test_value.value
def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies):
    """get_variable() returns None for an ID that does not exist."""
    fake = Faker()
    non_existent_id = fake.uuid4()
    service = WorkflowDraftVariableService(db_session_with_containers)
    retrieved_variable = service.get_variable(non_existent_id)
    assert retrieved_variable is None
def test_get_draft_variables_by_selectors_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    get_draft_variables_by_selectors() resolves one variable per
    (node_id, variable_name) selector pair, across variable types.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    # Arrange: two conversation variables and one node variable.
    var1_value = StringSegment(value=fake.word())
    var2_value = StringSegment(value=fake.word())
    var3_value = StringSegment(value=fake.word())
    var1 = self._create_test_variable(
        db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var1", var1_value, fake=fake
    )
    var2 = self._create_test_variable(
        db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var2", var2_value, fake=fake
    )
    var3 = self._create_test_variable(
        db_session_with_containers,
        app.id,
        "test_node_1",
        "var3",
        var3_value,
        variable_type=DraftVariableType.NODE,
        fake=fake,
    )
    selectors = [
        [CONVERSATION_VARIABLE_NODE_ID, "var1"],
        [CONVERSATION_VARIABLE_NODE_ID, "var2"],
        ["test_node_1", "var3"],
    ]
    # Act
    service = WorkflowDraftVariableService(db_session_with_containers)
    retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors)
    # Assert: all three selectors resolve, each to the matching value.
    assert len(retrieved_variables) == 3
    var_names = [var.name for var in retrieved_variables]
    assert "var1" in var_names
    assert "var2" in var_names
    assert "var3" in var_names
    for var in retrieved_variables:
        if var.name == "var1":
            assert var.get_value().value == var1_value.value
        elif var.name == "var2":
            assert var.get_value().value == var2_value.value
        elif var.name == "var3":
            assert var.get_value().value == var3_value.value
def test_list_variables_without_values_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Paginated listing returns variable metadata (values omitted for
    performance) in descending created_at order.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    # Arrange: five conversation variables with random names and values.
    for i in range(5):
        test_value = StringSegment(value=fake.numerify("value######"))
        self._create_test_variable(
            db_session_with_containers,
            app.id,
            CONVERSATION_VARIABLE_NODE_ID,
            _get_random_variable_name(fake),
            test_value,
            fake=fake,
        )
    # Act: request the first page of three.
    service = WorkflowDraftVariableService(db_session_with_containers)
    result = service.list_variables_without_values(app.id, page=1, limit=3)
    # Assert: total counts all rows, page is capped at the limit,
    # and results are ordered newest-first.
    assert result.total == 5
    assert len(result.variables) == 3
    assert result.variables[0].created_at >= result.variables[1].created_at
    assert result.variables[1].created_at >= result.variables[2].created_at
    for var in result.variables:
        assert var.name is not None
        assert var.app_id == app.id
def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    list_node_variables() returns only variables scoped to the given node.

    Two variables are created under the target node and one under a
    different node; only the first two must be returned.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    node_id = fake.word()
    var1_value = StringSegment(value=fake.word())
    var2_value = StringSegment(value=fake.word())
    var3_value = StringSegment(value=fake.word())
    # Two variables on the target node. Each named variable now receives its
    # matching value (the original fixture swapped var2_value/var3_value,
    # which was harmless but misleading since values are never asserted).
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        node_id,
        "var1",
        var1_value,
        variable_type=DraftVariableType.NODE,
        fake=fake,
    )
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        node_id,
        "var2",
        var2_value,
        variable_type=DraftVariableType.NODE,
        fake=fake,
    )
    # One variable on another node that must be filtered out.
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        "other_node",
        "var3",
        var3_value,
        variable_type=DraftVariableType.NODE,
        fake=fake,
    )
    service = WorkflowDraftVariableService(db_session_with_containers)
    result = service.list_node_variables(app.id, node_id)
    # Only the two target-node variables come back.
    assert len(result.variables) == 2
    for var in result.variables:
        assert var.node_id == node_id
        assert var.app_id == app.id
    var_names = [var.name for var in result.variables]
    assert "var1" in var_names
    assert "var2" in var_names
    assert "var3" not in var_names
def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    list_conversation_variables() returns only conversation-scoped
    variables, excluding a system variable created for the same app.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    # Two conversation variables. The helper's return values are not needed,
    # so the previously unused local bindings have been dropped.
    conv_var1_value = StringSegment(value=fake.word())
    conv_var2_value = StringSegment(value=fake.word())
    self._create_test_variable(
        db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var1", conv_var1_value, fake=fake
    )
    self._create_test_variable(
        db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var2", conv_var2_value, fake=fake
    )
    # One system variable that must be filtered out of the listing.
    sys_var_value = StringSegment(value=fake.word())
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        SYSTEM_VARIABLE_NODE_ID,
        "sys_var",
        sys_var_value,
        variable_type=DraftVariableType.SYS,
        fake=fake,
    )
    service = WorkflowDraftVariableService(db_session_with_containers)
    result = service.list_conversation_variables(app.id)
    assert len(result.variables) == 2
    for var in result.variables:
        assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
        assert var.app_id == app.id
        assert var.get_variable_type() == DraftVariableType.CONVERSATION
    var_names = [var.name for var in result.variables]
    assert "conv_var1" in var_names
    assert "conv_var2" in var_names
    assert "sys_var" not in var_names
def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    update_variable() renames and re-values an editable variable, stamps
    last_edited_at, and persists the change.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    original_value = StringSegment(value=fake.word())
    new_value = StringSegment(value=fake.word())
    variable = self._create_test_variable(
        db_session_with_containers,
        app.id,
        CONVERSATION_VARIABLE_NODE_ID,
        "original_name",
        original_value,
        fake=fake,
    )
    service = WorkflowDraftVariableService(db_session_with_containers)
    updated_variable = service.update_variable(variable, name="new_name", value=new_value)
    # The returned object reflects the update...
    assert updated_variable.name == "new_name"
    assert updated_variable.get_value().value == new_value.value
    assert updated_variable.last_edited_at is not None
    from extensions.ext_database import db

    # ...and so does a fresh read of the row from the database.
    db.session.refresh(variable)
    assert variable.name == "new_name"
    assert variable.get_value().value == new_value.value
    assert variable.last_edited_at is not None
def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies):
    """Updating a variable flagged editable=False raises UpdateNotSupportedError."""
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    original_value = StringSegment(value=fake.word())
    new_value = StringSegment(value=fake.word())
    # Build the system variable directly (not via the helper) so the
    # editable flag can be forced to False.
    variable = WorkflowDraftVariable.new_sys_variable(
        app_id=app.id,
        name=fake.word(),  # This is typically not editable
        value=original_value,
        node_execution_id=fake.uuid4(),
        editable=False,  # Set as non-editable
    )
    from extensions.ext_database import db

    db.session.add(variable)
    db.session.commit()
    service = WorkflowDraftVariableService(db_session_with_containers)
    with pytest.raises(UpdateNotSupportedError) as exc_info:
        service.update_variable(variable, name="new_name", value=new_value)
    # The error message identifies the offending variable by ID.
    assert "variable not support updating" in str(exc_info.value)
    assert variable.id in str(exc_info.value)
def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    reset_variable() restores a conversation variable to the default value
    declared on the workflow and clears last_edited_at.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake)
    from core.variables.variables import StringVariable

    # Declare the variable's default on the workflow itself.
    conv_var = StringVariable(
        id=fake.uuid4(),
        name="test_conv_var",
        value="default_value",
        selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"],
    )
    workflow.conversation_variables = [conv_var]
    from extensions.ext_database import db

    db.session.commit()
    # Create a draft variable holding a user-modified value with an edit timestamp.
    modified_value = StringSegment(value=fake.word())
    variable = self._create_test_variable(
        db_session_with_containers,
        app.id,
        CONVERSATION_VARIABLE_NODE_ID,
        "test_conv_var",
        modified_value,
        fake=fake,
    )
    variable.last_edited_at = fake.date_time()
    db.session.commit()
    # Act: reset, then verify both the returned object and the persisted row.
    service = WorkflowDraftVariableService(db_session_with_containers)
    reset_variable = service.reset_variable(workflow, variable)
    assert reset_variable is not None
    assert reset_variable.get_value().value == "default_value"
    assert reset_variable.last_edited_at is None
    db.session.refresh(variable)
    assert variable.get_value().value == "default_value"
    assert variable.last_edited_at is None
def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
    """delete_variable() removes exactly the given row from the database."""
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    test_value = StringSegment(value=fake.word())
    variable = self._create_test_variable(
        db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
    )
    from extensions.ext_database import db

    # Sanity check: the row exists before deletion.
    assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
    service = WorkflowDraftVariableService(db_session_with_containers)
    service.delete_variable(variable)
    # The row is gone afterwards.
    assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    delete_workflow_variables() removes every variable of the target app
    while leaving another app's variables untouched.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    # Three variables for the target app.
    for i in range(3):
        test_value = StringSegment(value=fake.numerify("value######"))
        self._create_test_variable(
            db_session_with_containers,
            app.id,
            CONVERSATION_VARIABLE_NODE_ID,
            _get_random_variable_name(fake),
            test_value,
            fake=fake,
        )
    # One variable for a second app, used as a control.
    other_app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    other_value = StringSegment(value=fake.word())
    self._create_test_variable(
        db_session_with_containers,
        other_app.id,
        CONVERSATION_VARIABLE_NODE_ID,
        _get_random_variable_name(fake),
        other_value,
        fake=fake,
    )
    from extensions.ext_database import db

    app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
    other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
    assert len(app_variables) == 3
    assert len(other_app_variables) == 1
    service = WorkflowDraftVariableService(db_session_with_containers)
    service.delete_workflow_variables(app.id)
    # Only the target app's variables are gone; the control survives.
    app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
    other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
    assert len(app_variables_after) == 0
    assert len(other_app_variables_after) == 1
def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    delete_node_variables() removes only the target node's variables,
    preserving other nodes' variables and conversation variables.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    node_id = fake.word()
    # Two variables on the node to be cleared.
    for i in range(2):
        test_value = StringSegment(value=fake.numerify("node_value######"))
        self._create_test_variable(
            db_session_with_containers,
            app.id,
            node_id,
            _get_random_variable_name(fake),
            test_value,
            variable_type=DraftVariableType.NODE,
            fake=fake,
        )
    # One variable on a different node (must survive).
    other_node_value = StringSegment(value=fake.word())
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        "other_node",
        _get_random_variable_name(fake),
        other_node_value,
        variable_type=DraftVariableType.NODE,
        fake=fake,
    )
    # One conversation variable (must also survive).
    conv_value = StringSegment(value=fake.word())
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        CONVERSATION_VARIABLE_NODE_ID,
        _get_random_variable_name(fake),
        conv_value,
        fake=fake,
    )
    from extensions.ext_database import db

    # Sanity check the starting population per scope.
    target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
    other_node_variables = (
        db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
    )
    conv_variables = (
        db.session.query(WorkflowDraftVariable)
        .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
        .all()
    )
    assert len(target_node_variables) == 2
    assert len(other_node_variables) == 1
    assert len(conv_variables) == 1
    service = WorkflowDraftVariableService(db_session_with_containers)
    service.delete_node_variables(app.id, node_id)
    # Only the target node's variables were deleted.
    target_node_variables_after = (
        db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
    )
    other_node_variables_after = (
        db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
    )
    conv_variables_after = (
        db.session.query(WorkflowDraftVariable)
        .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
        .all()
    )
    assert len(target_node_variables_after) == 0
    assert len(other_node_variables_after) == 1
    assert len(conv_variables_after) == 1
def test_prefill_conversation_variable_default_values_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    prefill_conversation_variable_default_values() materializes an editable
    draft variable for every conversation variable declared on the workflow.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake)
    from core.variables.variables import StringVariable

    # Declare two conversation variables with defaults on the workflow itself.
    conv_var1 = StringVariable(
        id=fake.uuid4(),
        name="conv_var1",
        value="default_value1",
        selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var1"],
    )
    conv_var2 = StringVariable(
        id=fake.uuid4(),
        name="conv_var2",
        value="default_value2",
        selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"],
    )
    workflow.conversation_variables = [conv_var1, conv_var2]
    from extensions.ext_database import db

    db.session.commit()
    # Act: prefill, then read the drafts back from the database.
    service = WorkflowDraftVariableService(db_session_with_containers)
    service.prefill_conversation_variable_default_values(workflow)
    draft_variables = (
        db.session.query(WorkflowDraftVariable)
        .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
        .all()
    )
    # Both declared variables now exist as editable conversation drafts.
    assert len(draft_variables) == 2
    var_names = [var.name for var in draft_variables]
    assert "conv_var1" in var_names
    assert "conv_var2" in var_names
    for var in draft_variables:
        assert var.app_id == app.id
        assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
        assert var.editable is True
        assert var.get_variable_type() == DraftVariableType.CONVERSATION
def test_get_conversation_id_from_draft_variable_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    _get_conversation_id_from_draft_variable() extracts the conversation ID
    stored in the system variable named "conversation_id".
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    conversation_id = fake.uuid4()
    conv_id_value = StringSegment(value=conversation_id)
    # Arrange: a system variable holding the conversation ID.
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        SYSTEM_VARIABLE_NODE_ID,
        "conversation_id",
        conv_id_value,
        variable_type=DraftVariableType.SYS,
        fake=fake,
    )
    service = WorkflowDraftVariableService(db_session_with_containers)
    retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
    assert retrieved_conv_id == conversation_id
def test_get_conversation_id_from_draft_variable_not_found(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    _get_conversation_id_from_draft_variable() returns None when no
    "conversation_id" system variable exists for the app.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    service = WorkflowDraftVariableService(db_session_with_containers)
    retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
    assert retrieved_conv_id is None
def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    list_system_variables() returns only system-scoped variables,
    excluding a conversation variable created for the same app.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    # Two system variables. The helper's return values are not needed,
    # so the previously unused local bindings have been dropped.
    sys_var1_value = StringSegment(value=fake.word())
    sys_var2_value = StringSegment(value=fake.word())
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        SYSTEM_VARIABLE_NODE_ID,
        "sys_var1",
        sys_var1_value,
        variable_type=DraftVariableType.SYS,
        fake=fake,
    )
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        SYSTEM_VARIABLE_NODE_ID,
        "sys_var2",
        sys_var2_value,
        variable_type=DraftVariableType.SYS,
        fake=fake,
    )
    # One conversation variable that must be filtered out of the listing.
    conv_var_value = StringSegment(value=fake.word())
    self._create_test_variable(
        db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake
    )
    service = WorkflowDraftVariableService(db_session_with_containers)
    result = service.list_system_variables(app.id)
    assert len(result.variables) == 2
    for var in result.variables:
        assert var.node_id == SYSTEM_VARIABLE_NODE_ID
        assert var.app_id == app.id
        assert var.get_variable_type() == DraftVariableType.SYS
    var_names = [var.name for var in result.variables]
    assert "sys_var1" in var_names
    assert "sys_var2" in var_names
    assert "conv_var" not in var_names
def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    The name-based getters (get_conversation_variable, get_system_variable,
    get_node_variable) each resolve a variable of the matching scope.
    """
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    test_value = StringSegment(value=fake.word())
    # One variable per scope. The helper's return values are not needed,
    # so the previously unused local bindings have been dropped.
    self._create_test_variable(
        db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_conv_var", test_value, fake=fake
    )
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        SYSTEM_VARIABLE_NODE_ID,
        "test_sys_var",
        test_value,
        variable_type=DraftVariableType.SYS,
        fake=fake,
    )
    self._create_test_variable(
        db_session_with_containers,
        app.id,
        "test_node",
        "test_node_var",
        test_value,
        variable_type=DraftVariableType.NODE,
        fake=fake,
    )
    service = WorkflowDraftVariableService(db_session_with_containers)
    # Conversation-scope lookup.
    retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var")
    assert retrieved_conv_var is not None
    assert retrieved_conv_var.name == "test_conv_var"
    assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID
    # System-scope lookup.
    retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var")
    assert retrieved_sys_var is not None
    assert retrieved_sys_var.name == "test_sys_var"
    assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID
    # Node-scope lookup.
    retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var")
    assert retrieved_node_var is not None
    assert retrieved_node_var.name == "test_node_var"
    assert retrieved_node_var.node_id == "test_node"
def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies):
    """Each name-based getter returns None when no variable with that name exists."""
    fake = Faker()
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
    service = WorkflowDraftVariableService(db_session_with_containers)
    retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var")
    assert retrieved_conv_var is None
    retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var")
    assert retrieved_sys_var is None
    retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var")
    assert retrieved_node_var is None

View File

@@ -0,0 +1,713 @@
import json
import uuid
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import pytest
from faker import Faker
from models.enums import CreatorUserRole
from models.model import (
Message,
)
from models.workflow import WorkflowRun
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.workflow_run_service import WorkflowRunService
class TestWorkflowRunService:
"""Integration tests for WorkflowRunService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
    """Patch the feature/enterprise/model services touched during app and account creation."""
    with (
        patch("services.app_service.FeatureService") as mock_feature_service,
        patch("services.app_service.EnterpriseService") as mock_enterprise_service,
        patch("services.app_service.ModelManager") as mock_model_manager,
        patch("services.account_service.FeatureService") as mock_account_feature_service,
    ):
        # App-service defaults: webapp auth disabled, enterprise hooks are no-ops.
        mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
        mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
        mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
        # Allow self-registration so AccountService.create_account succeeds.
        mock_account_feature_service.get_system_features.return_value.is_allow_register = True
        # Model manager yields a canned provider/model pair without real providers.
        mock_model_instance = mock_model_manager.return_value
        mock_model_instance.get_default_model_instance.return_value = None
        mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
        yield {
            "feature_service": mock_feature_service,
            "enterprise_service": mock_enterprise_service,
            "model_manager": mock_model_manager,
            "account_feature_service": mock_account_feature_service,
        }
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Create an account with an owner tenant, then a chat-mode app in that tenant.

    Args:
        db_session_with_containers: Database session from the testcontainers infrastructure.
        mock_external_service_dependencies: Patched services from the fixture.

    Returns:
        tuple: (app, account) - the created App and Account instances.
    """
    fake = Faker()
    # Make the registration allowance explicit (matches the fixture default).
    mock_external_service_dependencies[
        "account_feature_service"
    ].get_system_features.return_value.is_allow_register = True
    # Create account and tenant.
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant
    # Create a chat app with realistic metadata and rate limits.
    app_args = {
        "name": fake.company(),
        "description": fake.text(max_nb_chars=100),
        "mode": "chat",
        "icon_type": "emoji",
        "icon": "🤖",
        "icon_background": "#FF6B6B",
        "api_rph": 100,
        "api_rpm": 10,
    }
    app_service = AppService()
    app = app_service.create_app(tenant.id, app_args, account)
    return app, account
def _create_test_workflow_run(
    self, db_session_with_containers, app, account, triggered_from="debugging", offset_minutes=0
):
    """
    Create and persist a succeeded WorkflowRun for *app*.

    Args:
        db_session_with_containers: Database session from the testcontainers infrastructure.
        app: App the run belongs to.
        account: Account recorded as the run's creator.
        triggered_from: Trigger source stored on the run (default "debugging").
        offset_minutes: Minutes to shift created_at/finished_at into the past,
            useful for building a deterministic chronological order across runs.

    Returns:
        WorkflowRun: The committed workflow run.
    """
    from extensions.ext_database import db

    # Timestamp the run `offset_minutes` in the past. (The previously unused
    # local Faker instance has been removed.)
    base_time = datetime.now(UTC)
    created_time = base_time - timedelta(minutes=offset_minutes)
    workflow_run = WorkflowRun(
        tenant_id=app.tenant_id,
        app_id=app.id,
        workflow_id=str(uuid.uuid4()),
        type="chat",
        triggered_from=triggered_from,
        version="1.0.0",
        graph=json.dumps({"nodes": [], "edges": []}),
        inputs=json.dumps({"input": "test"}),
        status="succeeded",
        outputs=json.dumps({"output": "test result"}),
        elapsed_time=1.5,
        total_tokens=100,
        total_steps=3,
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=account.id,
        created_at=created_time,
        finished_at=created_time,
    )
    db.session.add(workflow_run)
    db.session.commit()
    return workflow_run
    def _create_test_message(self, db_session_with_containers, app, account, workflow_run):
        """
        Helper method to create a test message for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            app: App instance
            account: Account instance
            workflow_run: WorkflowRun instance

        Returns:
            Message: Created message instance
        """
        fake = Faker()
        from extensions.ext_database import db

        # Create conversation first (required for message)
        from models.model import Conversation

        conversation = Conversation(
            app_id=app.id,
            name=fake.sentence(),
            inputs={},
            status="normal",
            mode="chat",
            from_source=CreatorUserRole.ACCOUNT,
            from_account_id=account.id,
        )
        db.session.add(conversation)
        db.session.commit()
        # Create message
        # The message is linked to both the conversation and the workflow run so
        # advanced-chat pagination can enrich runs with message/conversation ids.
        message = Message()
        message.app_id = app.id
        message.conversation_id = conversation.id
        message.query = fake.text(max_nb_chars=100)
        message.message = {"type": "text", "content": fake.text(max_nb_chars=100)}
        message.answer = fake.text(max_nb_chars=200)
        message.message_tokens = 50
        message.answer_tokens = 100
        message.message_unit_price = 0.001
        message.answer_unit_price = 0.002
        message.message_price_unit = 0.001
        message.answer_price_unit = 0.001
        message.currency = "USD"
        message.status = "normal"
        message.from_source = CreatorUserRole.ACCOUNT
        message.from_account_id = account.id
        message.workflow_run_id = workflow_run.id
        message.inputs = {"input": "test input"}
        db.session.add(message)
        db.session.commit()
        return message
def test_get_paginate_workflow_runs_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful pagination of workflow runs with debugging trigger.
This test verifies:
- Proper pagination of workflow runs
- Correct filtering by triggered_from
- Proper limit and last_id handling
- Repository method calls
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create multiple workflow runs
workflow_runs = []
for i in range(5):
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
workflow_runs.append(workflow_run)
# Act: Execute the method under test
workflow_run_service = WorkflowRunService()
args = {"limit": 3, "last_id": None}
result = workflow_run_service.get_paginate_workflow_runs(app, args)
# Assert: Verify the expected outcomes
assert result is not None
assert hasattr(result, "data")
assert len(result.data) == 3 # Should return 3 items due to limit
# Verify pagination properties
assert hasattr(result, "has_more")
assert hasattr(result, "limit")
# Verify all returned items are debugging runs
for workflow_run in result.data:
assert workflow_run.triggered_from == "debugging"
assert workflow_run.app_id == app.id
assert workflow_run.tenant_id == app.tenant_id
def test_get_paginate_workflow_runs_with_last_id(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test pagination of workflow runs with last_id parameter.
This test verifies:
- Proper pagination with last_id parameter
- Correct handling of pagination state
- Repository method calls with proper parameters
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create multiple workflow runs with different timestamps
workflow_runs = []
for i in range(5):
workflow_run = self._create_test_workflow_run(
db_session_with_containers, app, account, "debugging", offset_minutes=i
)
workflow_runs.append(workflow_run)
# Act: Execute the method under test with last_id
workflow_run_service = WorkflowRunService()
args = {"limit": 2, "last_id": workflow_runs[1].id}
result = workflow_run_service.get_paginate_workflow_runs(app, args)
# Assert: Verify the expected outcomes
assert result is not None
assert hasattr(result, "data")
assert len(result.data) == 2 # Should return 2 items due to limit
# Verify pagination properties
assert hasattr(result, "has_more")
assert hasattr(result, "limit")
# Verify all returned items are debugging runs
for workflow_run in result.data:
assert workflow_run.triggered_from == "debugging"
assert workflow_run.app_id == app.id
assert workflow_run.tenant_id == app.tenant_id
def test_get_paginate_workflow_runs_default_limit(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test pagination of workflow runs with default limit.
This test verifies:
- Default limit of 20 when not specified
- Proper handling of missing limit parameter
- Repository method calls with default values
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create workflow runs
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
# Act: Execute the method under test without limit
workflow_run_service = WorkflowRunService()
args = {} # No limit specified
result = workflow_run_service.get_paginate_workflow_runs(app, args)
# Assert: Verify the expected outcomes
assert result is not None
assert hasattr(result, "data")
# Verify pagination properties
assert hasattr(result, "has_more")
assert hasattr(result, "limit")
# Verify the returned workflow run
if result.data:
workflow_run_result = result.data[0]
assert workflow_run_result.triggered_from == "debugging"
assert workflow_run_result.app_id == app.id
assert workflow_run_result.tenant_id == app.tenant_id
def test_get_paginate_advanced_chat_workflow_runs_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful pagination of advanced chat workflow runs with message information.
This test verifies:
- Proper pagination of advanced chat workflow runs
- Correct filtering by triggered_from
- Message information enrichment
- WorkflowWithMessage wrapper functionality
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create workflow runs with messages
workflow_runs = []
for i in range(3):
workflow_run = self._create_test_workflow_run(
db_session_with_containers, app, account, "debugging", offset_minutes=i
)
message = self._create_test_message(db_session_with_containers, app, account, workflow_run)
workflow_runs.append(workflow_run)
# Act: Execute the method under test
workflow_run_service = WorkflowRunService()
args = {"limit": 2, "last_id": None}
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app, args)
# Assert: Verify the expected outcomes
assert result is not None
assert hasattr(result, "data")
assert len(result.data) == 2 # Should return 2 items due to limit
# Verify pagination properties
assert hasattr(result, "has_more")
assert hasattr(result, "limit")
# Verify all returned items have message information
for workflow_run in result.data:
assert hasattr(workflow_run, "message_id")
assert hasattr(workflow_run, "conversation_id")
assert workflow_run.app_id == app.id
assert workflow_run.tenant_id == app.tenant_id
def test_get_workflow_run_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful retrieval of workflow run by ID.
This test verifies:
- Proper workflow run retrieval by ID
- Correct tenant and app isolation
- Repository method calls with proper parameters
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create workflow run
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
# Act: Execute the method under test
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_run(app, workflow_run.id)
# Assert: Verify the expected outcomes
assert result is not None
assert result.id == workflow_run.id
assert result.tenant_id == app.tenant_id
assert result.app_id == app.id
assert result.triggered_from == "debugging"
assert result.status == "succeeded"
assert result.type == "chat"
assert result.version == "1.0.0"
def test_get_workflow_run_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test workflow run retrieval when run ID does not exist.
This test verifies:
- Proper handling of non-existent workflow run IDs
- Repository method calls with proper parameters
- Return value for missing records
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Use a non-existent UUID
non_existent_id = str(uuid.uuid4())
# Act: Execute the method under test
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_run(app, non_existent_id)
# Assert: Verify the expected outcomes
assert result is None
def test_get_workflow_run_node_executions_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful retrieval of workflow run node executions.
This test verifies:
- Proper node execution retrieval for workflow run
- Correct tenant and app isolation
- Repository method calls with proper parameters
- Context setup for plugin tool providers
"""
# Arrange: Create test data
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create workflow run
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
# Create node executions
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionModel
node_executions = []
for i in range(3):
node_execution = WorkflowNodeExecutionModel(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow_run.workflow_id,
triggered_from="workflow-run",
workflow_run_id=workflow_run.id,
index=i,
node_id=f"node_{i}",
node_type="llm" if i == 0 else "tool",
title=f"Node {i}",
inputs=json.dumps({"input": f"test_input_{i}"}),
process_data=json.dumps({"process": f"test_process_{i}"}),
status="succeeded",
elapsed_time=0.5,
execution_metadata=json.dumps({"tokens": 50}),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(node_execution)
node_executions.append(node_execution)
db.session.commit()
# Act: Execute the method under test
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_run_node_executions(app, workflow_run.id, account)
# Assert: Verify the expected outcomes
assert result is not None
assert len(result) == 3
# Verify node execution properties
for node_execution in result:
assert node_execution.tenant_id == app.tenant_id
assert node_execution.app_id == app.id
assert node_execution.workflow_run_id == workflow_run.id
assert node_execution.index in [0, 1, 2] # Check that index is one of the expected values
assert node_execution.node_id.startswith("node_") # Check that node_id starts with "node_"
assert node_execution.status == "succeeded"
def test_get_workflow_run_node_executions_empty(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test getting node executions for a workflow run with no executions.
This test verifies:
- Empty result when no node executions exist
- Proper handling of empty data
- No errors when querying non-existent executions
"""
# Arrange: Setup test data
account_service = AccountService()
tenant_service = TenantService()
app_service = AppService()
workflow_run_service = WorkflowRunService()
# Create account and tenant
account = account_service.create_account(
email="test@example.com",
name="Test User",
password="password123",
interface_language="en-US",
)
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
tenant = account.current_tenant
# Create app
app_args = {
"name": "Test App",
"mode": "chat",
"icon_type": "emoji",
"icon": "🚀",
"icon_background": "#4ECDC4",
}
app = app_service.create_app(tenant.id, app_args, account)
# Create workflow run without node executions
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
# Act: Get node executions
result = workflow_run_service.get_workflow_run_node_executions(
app_model=app,
run_id=workflow_run.id,
user=account,
)
# Assert: Verify empty result
assert result is not None
assert len(result) == 0
def test_get_workflow_run_node_executions_invalid_workflow_run_id(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test getting node executions with invalid workflow run ID.
This test verifies:
- Proper handling of invalid workflow run ID
- Empty result when workflow run doesn't exist
- No errors when querying with invalid ID
"""
# Arrange: Setup test data
account_service = AccountService()
tenant_service = TenantService()
app_service = AppService()
workflow_run_service = WorkflowRunService()
# Create account and tenant
account = account_service.create_account(
email="test@example.com",
name="Test User",
password="password123",
interface_language="en-US",
)
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
tenant = account.current_tenant
# Create app
app_args = {
"name": "Test App",
"mode": "chat",
"icon_type": "emoji",
"icon": "🚀",
"icon_background": "#4ECDC4",
}
app = app_service.create_app(tenant.id, app_args, account)
# Use invalid workflow run ID
invalid_workflow_run_id = str(uuid.uuid4())
# Act: Get node executions with invalid ID
result = workflow_run_service.get_workflow_run_node_executions(
app_model=app,
run_id=invalid_workflow_run_id,
user=account,
)
# Assert: Verify empty result
assert result is not None
assert len(result) == 0
def test_get_workflow_run_node_executions_database_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test getting node executions when database encounters an error.
This test verifies:
- Proper error handling when database operations fail
- Graceful degradation in error scenarios
- Error propagation to calling code
"""
# Arrange: Setup test data
account_service = AccountService()
tenant_service = TenantService()
app_service = AppService()
workflow_run_service = WorkflowRunService()
# Create account and tenant
account = account_service.create_account(
email="test@example.com",
name="Test User",
password="password123",
interface_language="en-US",
)
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
tenant = account.current_tenant
# Create app
app_args = {
"name": "Test App",
"mode": "chat",
"icon_type": "emoji",
"icon": "🚀",
"icon_background": "#4ECDC4",
}
app = app_service.create_app(tenant.id, app_args, account)
# Create workflow run
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
# Mock database error by closing the session
db_session_with_containers.close()
# Act & Assert: Verify error handling
with pytest.raises((Exception, RuntimeError)):
workflow_run_service.get_workflow_run_node_executions(
app_model=app,
run_id=workflow_run.id,
user=account,
)
    def test_get_workflow_run_node_executions_end_user(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test node execution retrieval for end user.

        This test verifies:
        - Proper handling of end user vs account user
        - Correct tenant ID extraction for end users
        - Repository method calls with proper parameters
        """
        # Arrange: Create test data
        fake = Faker()
        app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
        # Create workflow run
        workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
        # Create end user
        # End users are tenant/app-scoped web-app visitors, distinct from console accounts.
        from extensions.ext_database import db
        from models.model import EndUser

        end_user = EndUser(
            tenant_id=app.tenant_id,
            app_id=app.id,
            type="web_app",
            is_anonymous=False,
            session_id=str(uuid.uuid4()),
            external_user_id=str(uuid.uuid4()),
            name=fake.name(),
        )
        db.session.add(end_user)
        db.session.commit()
        # Create node execution recorded as authored by the end user
        from models.workflow import WorkflowNodeExecutionModel

        node_execution = WorkflowNodeExecutionModel(
            tenant_id=app.tenant_id,
            app_id=app.id,
            workflow_id=workflow_run.workflow_id,
            triggered_from="workflow-run",
            workflow_run_id=workflow_run.id,
            index=0,
            node_id="node_0",
            node_type="llm",
            title="Node 0",
            inputs=json.dumps({"input": "test_input"}),
            process_data=json.dumps({"process": "test_process"}),
            status="succeeded",
            elapsed_time=0.5,
            execution_metadata=json.dumps({"tokens": 50}),
            created_by_role=CreatorUserRole.END_USER,
            created_by=end_user.id,
            created_at=datetime.now(UTC),
        )
        db.session.add(node_execution)
        db.session.commit()
        # Act: Execute the method under test, passing the end user (not the account)
        workflow_run_service = WorkflowRunService()
        result = workflow_run_service.get_workflow_run_node_executions(app, workflow_run.id, end_user)
        # Assert: Verify the expected outcomes
        assert result is not None
        assert len(result) == 1
        # Verify node execution properties, including end-user attribution
        node_exec = result[0]
        assert node_exec.tenant_id == app.tenant_id
        assert node_exec.app_id == app.id
        assert node_exec.workflow_run_id == workflow_run.id
        assert node_exec.created_by == end_user.id
        assert node_exec.created_by_role == CreatorUserRole.END_USER

View File

@@ -0,0 +1,529 @@
from unittest.mock import patch
import pytest
from faker import Faker
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from services.workspace_service import WorkspaceService
class TestWorkspaceService:
"""Integration tests for WorkspaceService using testcontainers."""
    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies.

        Yields a dict of the three patched collaborators so individual tests
        can override the defaults below (which model the fully-privileged case).
        """
        with (
            patch("services.workspace_service.FeatureService") as mock_feature_service,
            patch("services.workspace_service.TenantService") as mock_tenant_service,
            patch("services.workspace_service.dify_config") as mock_dify_config,
        ):
            # Setup default mock returns
            mock_feature_service.get_features.return_value.can_replace_logo = True
            mock_tenant_service.has_roles.return_value = True
            mock_dify_config.FILES_URL = "https://example.com/files"
            yield {
                "feature_service": mock_feature_service,
                "tenant_service": mock_tenant_service,
                "dify_config": mock_dify_config,
            }
    def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test account and tenant for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies

        Returns:
            tuple: (account, tenant) - Created account and tenant instances
        """
        fake = Faker()
        # Create account
        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        from extensions.ext_database import db

        db.session.add(account)
        db.session.commit()
        # Create tenant
        # custom_config is persisted as a JSON string on the tenant row.
        tenant = Tenant(
            name=fake.company(),
            status="normal",
            plan="basic",
            custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}',
        )
        db.session.add(tenant)
        db.session.commit()
        # Create tenant-account join with owner role
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db.session.add(join)
        db.session.commit()
        # Set current tenant for account
        account.current_tenant = tenant
        return account, tenant
def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful retrieval of tenant information with all features enabled.
This test verifies:
- Proper tenant info retrieval with all required fields
- Correct role assignment from TenantAccountJoin
- Custom config handling when features are enabled
- Logo replacement functionality for privileged users
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Setup mocks for feature service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["id"] == tenant.id
assert result["name"] == tenant.name
assert result["plan"] == tenant.plan
assert result["status"] == tenant.status
assert result["role"] == TenantAccountRole.OWNER
assert result["created_at"] == tenant.created_at
assert result["trial_end_reason"] is None
# Verify custom config is included for privileged users
assert "custom_config" in result
assert result["custom_config"]["remove_webapp_brand"] is False
assert "replace_webapp_logo" in result["custom_config"]
# Verify database state
from extensions.ext_database import db
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_without_custom_config(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval when custom config features are disabled.
This test verifies:
- Tenant info retrieval without custom config when features are disabled
- Proper handling of disabled logo replacement functionality
- Role assignment still works correctly
- Basic tenant information is complete
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Setup mocks to disable custom config features
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["id"] == tenant.id
assert result["name"] == tenant.name
assert result["plan"] == tenant.plan
assert result["status"] == tenant.status
assert result["role"] == TenantAccountRole.OWNER
assert result["created_at"] == tenant.created_at
assert result["trial_end_reason"] is None
# Verify custom config is not included when features are disabled
assert "custom_config" not in result
# Verify database state
from extensions.ext_database import db
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_normal_user_role(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval for normal user role without privileged features.
This test verifies:
- Tenant info retrieval for non-privileged users
- Role assignment for normal users
- Custom config is not accessible for normal users
- Proper handling of different user roles
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Update the join to have normal role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.NORMAL
db.session.commit()
# Setup mocks for feature service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["id"] == tenant.id
assert result["name"] == tenant.name
assert result["plan"] == tenant.plan
assert result["status"] == tenant.status
assert result["role"] == TenantAccountRole.NORMAL
assert result["created_at"] == tenant.created_at
assert result["trial_end_reason"] is None
# Verify custom config is not included for normal users
assert "custom_config" not in result
# Verify database state
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_admin_role_and_logo_replacement(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval for admin role with logo replacement enabled.
This test verifies:
- Admin role can access custom config features
- Logo replacement functionality works for admin users
- Proper URL construction for logo replacement
- Custom config handling for admin role
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Update the join to have admin role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.ADMIN
db.session.commit()
# Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["role"] == TenantAccountRole.ADMIN
# Verify custom config is included for admin users
assert "custom_config" in result
assert result["custom_config"]["remove_webapp_brand"] is False
assert "replace_webapp_logo" in result["custom_config"]
# Verify database state
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test tenant info retrieval when tenant parameter is None.
This test verifies:
- Proper handling of None tenant parameter
- Method returns None for invalid input
- No exceptions are raised for None input
- Graceful degradation for invalid data
"""
# Arrange: No test data needed for this test
# Act: Execute the method under test with None tenant
result = WorkspaceService.get_tenant_info(None)
# Assert: Verify the expected outcomes
assert result is None
def test_get_tenant_info_with_custom_config_variations(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval with various custom config configurations.
This test verifies:
- Different custom config combinations work correctly
- Logo replacement URL construction with various configs
- Brand removal functionality
- Edge cases in custom config handling
"""
# Arrange: Create test data with different custom configs
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Test different custom config combinations
test_configs = [
# Case 1: Both logo and brand removal enabled
{"replace_webapp_logo": True, "remove_webapp_brand": True},
# Case 2: Only logo replacement enabled
{"replace_webapp_logo": True, "remove_webapp_brand": False},
# Case 3: Only brand removal enabled
{"replace_webapp_logo": False, "remove_webapp_brand": True},
# Case 4: Neither enabled
{"replace_webapp_logo": False, "remove_webapp_brand": False},
]
for config in test_configs:
# Update tenant custom config
import json
from extensions.ext_database import db
tenant.custom_config = json.dumps(config)
db.session.commit()
# Setup mocks
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
mock_external_service_dependencies["dify_config"].FILES_URL = "https://files.example.com"
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert "custom_config" in result
if config["replace_webapp_logo"]:
assert "replace_webapp_logo" in result["custom_config"]
if config["replace_webapp_logo"]:
expected_url = f"https://files.example.com/files/workspaces/{tenant.id}/webapp-logo"
assert result["custom_config"]["replace_webapp_logo"] == expected_url
else:
assert result["custom_config"]["replace_webapp_logo"] is None
assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
# Verify database state
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_editor_role_and_limited_permissions(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant info retrieval for editor role with limited permissions.
This test verifies:
- Editor role has limited access to custom config features
- Proper role-based permission checking
- Custom config handling for different role levels
- Role hierarchy and permission boundaries
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Update the join to have editor role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join.role = TenantAccountRole.EDITOR
db.session.commit()
# Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
# Editor role should not have admin/owner permissions
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
# Mock current_user for flask_login
with patch("services.workspace_service.current_user", account):
# Act: Execute the method under test
result = WorkspaceService.get_tenant_info(tenant)
# Assert: Verify the expected outcomes
assert result is not None
assert result["role"] == TenantAccountRole.EDITOR
# Verify custom config is not included for editor users without admin privileges
assert "custom_config" not in result
# Verify database state
db.session.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_dataset_operator_role(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Tenant info retrieval for an account holding the dataset-operator role.

    Verifies that the specialised role is surfaced as-is and that a dataset
    operator without admin/owner privileges is denied the tenant's
    custom_config block.
    """
    from extensions.ext_database import db

    # Arrange: account/tenant pair, then downgrade the membership.
    fake = Faker()  # kept for parity with sibling tests; unused here
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    membership = (
        db.session.query(TenantAccountJoin)
        .filter_by(tenant_id=tenant.id, account_id=account.id)
        .first()
    )
    membership.role = TenantAccountRole.DATASET_OPERATOR
    db.session.commit()

    # Feature flag on, but the role check denies admin/owner privileges.
    mocks = mock_external_service_dependencies
    mocks["feature_service"].get_features.return_value.can_replace_logo = True
    mocks["tenant_service"].has_roles.return_value = False
    mocks["dify_config"].FILES_URL = "https://cdn.example.com"

    # Act: run with the created account as the logged-in user.
    with patch("services.workspace_service.current_user", account):
        tenant_info = WorkspaceService.get_tenant_info(tenant)

    # Assert: role reported, custom config withheld for non-admins.
    assert tenant_info is not None
    assert tenant_info["role"] == TenantAccountRole.DATASET_OPERATOR
    assert "custom_config" not in tenant_info

    # Sanity check that the tenant row is still intact.
    db.session.refresh(tenant)
    assert tenant.id is not None
def test_get_tenant_info_with_complex_custom_config_scenarios(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test tenant info retrieval with complex custom config scenarios.

    This test verifies:
    - Complex custom config combinations
    - Edge cases in custom config handling
    - URL construction with various configs
    - Error handling for malformed configs
    """
    # Arrange: Create test data
    fake = Faker()  # NOTE(review): unused in this test; kept for parity with siblings
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    # Test complex custom config scenarios
    test_configs = [
        # Case 1: Empty custom config
        {},
        # Case 2: Custom config with only logo replacement
        {"replace_webapp_logo": True},
        # Case 3: Custom config with only brand removal
        {"remove_webapp_brand": True},
        # Case 4: Custom config with additional fields
        {
            "replace_webapp_logo": True,
            "remove_webapp_brand": False,
            "custom_field": "custom_value",
            "nested_config": {"key": "value"},
        },
        # Case 5: Custom config with null values
        {"replace_webapp_logo": None, "remove_webapp_brand": None},
    ]
    for config in test_configs:
        # Update tenant custom config (stored as a JSON string on the row)
        import json

        from extensions.ext_database import db

        tenant.custom_config = json.dumps(config)
        db.session.commit()
        # Setup mocks: logo feature enabled and admin-level role granted,
        # so custom_config is always included in the result.
        mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
        mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
        mock_external_service_dependencies["dify_config"].FILES_URL = "https://files.example.com"
        # Mock current_user for flask_login
        with patch("services.workspace_service.current_user", account):
            # Act: Execute the method under test
            result = WorkspaceService.get_tenant_info(tenant)
        # Assert: Verify the expected outcomes
        assert result is not None
        assert "custom_config" in result
        # Verify logo replacement handling: truthy flag expands to a
        # workspace-scoped logo URL, otherwise the field is None.
        if config.get("replace_webapp_logo"):
            assert "replace_webapp_logo" in result["custom_config"]
            expected_url = f"https://files.example.com/files/workspaces/{tenant.id}/webapp-logo"
            assert result["custom_config"]["replace_webapp_logo"] == expected_url
        else:
            assert result["custom_config"]["replace_webapp_logo"] is None
        # Verify brand removal handling: explicit values pass through
        # (including None); a missing key defaults to False.
        if "remove_webapp_brand" in config:
            assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
        else:
            assert result["custom_config"]["remove_webapp_brand"] is False
    # Verify database state
    db.session.refresh(tenant)
    assert tenant.id is not None

View File

@@ -0,0 +1,550 @@
from unittest.mock import patch
import pytest
from faker import Faker
from models import Account, Tenant
from models.tools import ApiToolProvider
from services.tools.api_tools_manage_service import ApiToolManageService
class TestApiToolManageService:
    """Integration tests for ApiToolManageService using testcontainers.

    Covers schema parsing, schema-to-tool-bundle conversion, and API tool
    provider creation (success and failure paths). External collaborators
    (label manager, credential encrypter, provider controller) are mocked.
    """

    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies."""
        with (
            patch("services.tools.api_tools_manage_service.ToolLabelManager") as mock_tool_label_manager,
            patch("services.tools.api_tools_manage_service.create_tool_provider_encrypter") as mock_encrypter,
            patch("services.tools.api_tools_manage_service.ApiToolProviderController") as mock_provider_controller,
        ):
            # Setup default mock returns
            mock_tool_label_manager.update_tool_labels.return_value = None
            # NOTE(review): the factory is expected to return a 2-tuple —
            # presumably (encrypter, cache); confirm against the real signature.
            mock_encrypter.return_value = (mock_encrypter, None)
            mock_encrypter.encrypt.return_value = {"encrypted": "credentials"}
            mock_provider_controller.from_db.return_value = mock_provider_controller
            mock_provider_controller.load_bundled_tools.return_value = None
            yield {
                "tool_label_manager": mock_tool_label_manager,
                "encrypter": mock_encrypter,
                "provider_controller": mock_provider_controller,
            }

    def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test account and tenant for testing.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies

        Returns:
            tuple: (account, tenant) - Created account and tenant instances
        """
        fake = Faker()
        # Create account
        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        from extensions.ext_database import db

        db.session.add(account)
        db.session.commit()
        # Create tenant for the account
        tenant = Tenant(
            name=fake.company(),
            status="normal",
        )
        db.session.add(tenant)
        db.session.commit()
        # Create tenant-account join (account owns the tenant)
        from models.account import TenantAccountJoin, TenantAccountRole

        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db.session.add(join)
        db.session.commit()
        # Set current tenant for account
        account.current_tenant = tenant
        return account, tenant

    def _create_test_openapi_schema(self) -> str:
        """Helper method to create a test OpenAPI schema (minimal 3.0 doc with one GET op)."""
        return """
        {
            "openapi": "3.0.0",
            "info": {
                "title": "Test API",
                "version": "1.0.0",
                "description": "Test API for testing purposes"
            },
            "servers": [
                {
                    "url": "https://api.example.com",
                    "description": "Production server"
                }
            ],
            "paths": {
                "/test": {
                    "get": {
                        "operationId": "testOperation",
                        "summary": "Test operation",
                        "responses": {
                            "200": {
                                "description": "Success"
                            }
                        }
                    }
                }
            }
        }
        """

    def test_parser_api_schema_success(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful parsing of API schema.

        This test verifies:
        - Proper schema parsing with valid OpenAPI schema
        - Correct credentials schema generation
        - Proper warning handling
        - Return value structure
        """
        # Arrange: Create test schema
        schema = self._create_test_openapi_schema()
        # Act: Parse the schema
        result = ApiToolManageService.parser_api_schema(schema)
        # Assert: Verify the result structure
        assert result is not None
        assert "schema_type" in result
        assert "parameters_schema" in result
        assert "credentials_schema" in result
        assert "warning" in result
        # Verify credentials schema structure: exactly the three auth fields
        credentials_schema = result["credentials_schema"]
        assert len(credentials_schema) == 3
        # Check auth_type field (required, defaults to no auth)
        auth_type_field = next(field for field in credentials_schema if field["name"] == "auth_type")
        assert auth_type_field["required"] is True
        assert auth_type_field["default"] == "none"
        assert len(auth_type_field["options"]) == 2
        # Check api_key_header field
        api_key_header_field = next(field for field in credentials_schema if field["name"] == "api_key_header")
        assert api_key_header_field["required"] is False
        assert api_key_header_field["default"] == "api_key"
        # Check api_key_value field
        api_key_value_field = next(field for field in credentials_schema if field["name"] == "api_key_value")
        assert api_key_value_field["required"] is False
        assert api_key_value_field["default"] == ""

    def test_parser_api_schema_invalid_schema(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test parsing of invalid API schema.

        This test verifies:
        - Proper error handling for invalid schemas
        - Correct exception type and message
        - Error propagation from underlying parser
        """
        # Arrange: Create invalid schema (not JSON at all)
        invalid_schema = "invalid json schema"
        # Act & Assert: Verify proper error handling
        with pytest.raises(ValueError) as exc_info:
            ApiToolManageService.parser_api_schema(invalid_schema)
        assert "invalid schema" in str(exc_info.value)

    def test_parser_api_schema_malformed_json(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test parsing of malformed JSON schema.

        This test verifies:
        - Proper error handling for malformed JSON
        - Correct exception type and message
        - Error propagation from JSON parsing
        """
        # Arrange: Create malformed JSON schema
        # NOTE(review): this string is syntactically valid JSON; the expected
        # ValueError presumably comes from schema-level validation (empty
        # "paths"), not from JSON parsing — confirm and consider renaming.
        malformed_schema = '{"openapi": "3.0.0", "info": {"title": "Test", "version": "1.0.0"}, "paths": {}}'
        # Act & Assert: Verify proper error handling
        with pytest.raises(ValueError) as exc_info:
            ApiToolManageService.parser_api_schema(malformed_schema)
        assert "invalid schema" in str(exc_info.value)

    def test_convert_schema_to_tool_bundles_success(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful conversion of schema to tool bundles.

        This test verifies:
        - Proper schema conversion with valid OpenAPI schema
        - Correct tool bundles generation
        - Proper schema type detection
        - Return value structure
        """
        # Arrange: Create test schema
        schema = self._create_test_openapi_schema()
        # Act: Convert schema to tool bundles
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema)
        # Assert: Verify the result structure
        assert tool_bundles is not None
        assert isinstance(tool_bundles, list)
        assert len(tool_bundles) > 0
        assert schema_type is not None
        assert isinstance(schema_type, str)
        # Verify tool bundle structure: operation id comes from the schema
        tool_bundle = tool_bundles[0]
        assert hasattr(tool_bundle, "operation_id")
        assert tool_bundle.operation_id == "testOperation"

    def test_convert_schema_to_tool_bundles_with_extra_info(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful conversion of schema to tool bundles with extra info.

        This test verifies:
        - Proper schema conversion with extra info parameter
        - Correct tool bundles generation
        - Extra info handling
        - Return value structure
        """
        # Arrange: Create test schema and extra info
        schema = self._create_test_openapi_schema()
        extra_info = {"description": "Custom description", "version": "2.0.0"}
        # Act: Convert schema to tool bundles with extra info
        tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
        # Assert: Verify the result structure
        assert tool_bundles is not None
        assert isinstance(tool_bundles, list)
        assert len(tool_bundles) > 0
        assert schema_type is not None
        assert isinstance(schema_type, str)

    def test_convert_schema_to_tool_bundles_invalid_schema(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test conversion of invalid schema to tool bundles.

        This test verifies:
        - Proper error handling for invalid schemas
        - Correct exception type and message
        - Error propagation from underlying parser
        """
        # Arrange: Create invalid schema
        invalid_schema = "invalid schema content"
        # Act & Assert: Verify proper error handling
        with pytest.raises(ValueError) as exc_info:
            ApiToolManageService.convert_schema_to_tool_bundles(invalid_schema)
        assert "invalid schema" in str(exc_info.value)

    def test_create_api_tool_provider_success(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful creation of API tool provider.

        This test verifies:
        - Proper provider creation with valid parameters
        - Correct database state after creation
        - Proper relationship establishment
        - External service integration
        - Return value correctness
        """
        # Arrange: Create test data
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔧"}
        credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
        schema_type = "openapi"
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"
        labels = ["test", "api"]
        # Act: Create API tool provider
        result = ApiToolManageService.create_api_tool_provider(
            user_id=account.id,
            tenant_id=tenant.id,
            provider_name=provider_name,
            icon=icon,
            credentials=credentials,
            schema_type=schema_type,
            schema=schema,
            privacy_policy=privacy_policy,
            custom_disclaimer=custom_disclaimer,
            labels=labels,
        )
        # Assert: Verify the result
        assert result == {"result": "success"}
        # Verify database state: a provider row exists for this tenant/name
        from extensions.ext_database import db

        provider = (
            db.session.query(ApiToolProvider)
            .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
            .first()
        )
        assert provider is not None
        assert provider.name == provider_name
        assert provider.tenant_id == tenant.id
        assert provider.user_id == account.id
        assert provider.schema_type_str == schema_type
        assert provider.privacy_policy == privacy_policy
        assert provider.custom_disclaimer == custom_disclaimer
        # Verify mock interactions: labels updated, credentials encrypted,
        # and the controller loaded the bundled tools exactly once each.
        mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
        mock_external_service_dependencies["encrypter"].assert_called_once()
        mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
        mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once()

    def test_create_api_tool_provider_duplicate_name(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test creation of API tool provider with duplicate name.

        This test verifies:
        - Proper error handling for duplicate provider names
        - Correct exception type and message
        - Database constraint enforcement
        """
        # Arrange: Create test data and existing provider
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔧"}
        credentials = {"auth_type": "none"}
        schema_type = "openapi"
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"
        labels = ["test"]
        # Create first provider
        ApiToolManageService.create_api_tool_provider(
            user_id=account.id,
            tenant_id=tenant.id,
            provider_name=provider_name,
            icon=icon,
            credentials=credentials,
            schema_type=schema_type,
            schema=schema,
            privacy_policy=privacy_policy,
            custom_disclaimer=custom_disclaimer,
            labels=labels,
        )
        # Act & Assert: Try to create duplicate provider (same tenant + name)
        with pytest.raises(ValueError) as exc_info:
            ApiToolManageService.create_api_tool_provider(
                user_id=account.id,
                tenant_id=tenant.id,
                provider_name=provider_name,
                icon=icon,
                credentials=credentials,
                schema_type=schema_type,
                schema=schema,
                privacy_policy=privacy_policy,
                custom_disclaimer=custom_disclaimer,
                labels=labels,
            )
        assert f"provider {provider_name} already exists" in str(exc_info.value)

    def test_create_api_tool_provider_invalid_schema_type(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test creation of API tool provider with invalid schema type.

        This test verifies:
        - Proper error handling for invalid schema types
        - Correct exception type and message
        - Schema type validation
        """
        # Arrange: Create test data with invalid schema type
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔧"}
        credentials = {"auth_type": "none"}
        schema_type = "invalid_type"
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"
        labels = ["test"]
        # Act & Assert: Try to create provider with invalid schema type
        with pytest.raises(ValueError) as exc_info:
            ApiToolManageService.create_api_tool_provider(
                user_id=account.id,
                tenant_id=tenant.id,
                provider_name=provider_name,
                icon=icon,
                credentials=credentials,
                schema_type=schema_type,
                schema=schema,
                privacy_policy=privacy_policy,
                custom_disclaimer=custom_disclaimer,
                labels=labels,
            )
        assert "invalid schema type" in str(exc_info.value)

    def test_create_api_tool_provider_missing_auth_type(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test creation of API tool provider with missing auth type.

        This test verifies:
        - Proper error handling for missing auth type
        - Correct exception type and message
        - Credentials validation
        """
        # Arrange: Create test data with missing auth type
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔧"}
        credentials = {}  # Missing auth_type
        schema_type = "openapi"
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"
        labels = ["test"]
        # Act & Assert: Try to create provider with missing auth type
        with pytest.raises(ValueError) as exc_info:
            ApiToolManageService.create_api_tool_provider(
                user_id=account.id,
                tenant_id=tenant.id,
                provider_name=provider_name,
                icon=icon,
                credentials=credentials,
                schema_type=schema_type,
                schema=schema,
                privacy_policy=privacy_policy,
                custom_disclaimer=custom_disclaimer,
                labels=labels,
            )
        assert "auth_type is required" in str(exc_info.value)

    def test_create_api_tool_provider_with_api_key_auth(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful creation of API tool provider with API key authentication.

        This test verifies:
        - Proper provider creation with API key auth
        - Correct credentials handling
        - Proper authentication type processing
        """
        # Arrange: Create test data with API key auth
        fake = Faker()
        account, tenant = self._create_test_account_and_tenant(
            db_session_with_containers, mock_external_service_dependencies
        )
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔑"}
        credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
        schema_type = "openapi"
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"
        labels = ["api_key", "secure"]
        # Act: Create API tool provider
        result = ApiToolManageService.create_api_tool_provider(
            user_id=account.id,
            tenant_id=tenant.id,
            provider_name=provider_name,
            icon=icon,
            credentials=credentials,
            schema_type=schema_type,
            schema=schema,
            privacy_policy=privacy_policy,
            custom_disclaimer=custom_disclaimer,
            labels=labels,
        )
        # Assert: Verify the result
        assert result == {"result": "success"}
        # Verify database state
        from extensions.ext_database import db

        provider = (
            db.session.query(ApiToolProvider)
            .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
            .first()
        )
        assert provider is not None
        assert provider.name == provider_name
        assert provider.tenant_id == tenant.id
        assert provider.user_id == account.id
        assert provider.schema_type_str == schema_type
        # Verify mock interactions: credentials must pass through the encrypter
        mock_external_service_dependencies["encrypter"].assert_called_once()
        mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
View File

@@ -0,0 +1,788 @@
from unittest.mock import Mock, patch
import pytest
from faker import Faker
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
class TestToolTransformService:
"""Integration tests for ToolTransformService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
    """Mock setup for external service dependencies."""
    # Patch dify_config in BOTH modules with the SAME mock object so URL
    # settings stay consistent between ToolTransformService and PluginService.
    with patch("services.tools.tools_transform_service.dify_config") as mock_dify_config:
        with patch("services.plugin.plugin_service.dify_config", new=mock_dify_config):
            # Setup default mock returns
            mock_dify_config.CONSOLE_API_URL = "https://console.example.com"
            yield {
                "dify_config": mock_dify_config,
            }
def _create_test_tool_provider(
    self, db_session_with_containers, mock_external_service_dependencies, provider_type="api"
):
    """
    Helper method to create a test tool provider for testing.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies
        provider_type: Type of provider to create ("api", "builtin", "workflow", or "mcp")

    Returns:
        Tool provider instance, already committed to the database.

    Raises:
        ValueError: If provider_type is not one of the supported values.
    """
    fake = Faker()
    if provider_type == "api":
        # API providers store their icon as a JSON string
        provider = ApiToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            icon_dark='{"background": "#252525", "content": "🔧"}',
            tenant_id="test_tenant_id",
            user_id="test_user_id",
            credentials={"auth_type": "api_key_header", "api_key": "test_key"},
            provider_type="api",
        )
    elif provider_type == "builtin":
        # Builtin providers use a plain emoji icon
        provider = BuiltinToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon="🔧",
            icon_dark="🔧",
            tenant_id="test_tenant_id",
            provider="test_provider",
            credential_type="api_key",
            credentials={"api_key": "test_key"},
        )
    elif provider_type == "workflow":
        provider = WorkflowToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            icon_dark='{"background": "#252525", "content": "🔧"}',
            tenant_id="test_tenant_id",
            user_id="test_user_id",
            workflow_id="test_workflow_id",
        )
    elif provider_type == "mcp":
        # MCP providers carry server connection info and a serialized tool list
        provider = MCPToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            provider_icon='{"background": "#FF6B6B", "content": "🔧"}',
            tenant_id="test_tenant_id",
            user_id="test_user_id",
            server_url="https://mcp.example.com",
            server_identifier="test_server",
            tools='[{"name": "test_tool", "description": "Test tool"}]',
            authed=True,
        )
    else:
        raise ValueError(f"Unknown provider type: {provider_type}")
    from extensions.ext_database import db

    db.session.add(provider)
    db.session.commit()
    return provider
def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful plugin icon URL generation.

    This test verifies:
    - Proper URL construction for plugin icons
    - Correct tenant_id and filename handling
    - URL format compliance
    """
    # Arrange: Setup test data
    fake = Faker()
    tenant_id = fake.uuid4()
    filename = "test_icon.png"
    # Act: Execute the method under test
    result = PluginService.get_plugin_icon_url(str(tenant_id), filename)
    # Assert: Verify the expected outcomes
    assert result is not None
    assert isinstance(result, str)
    assert "console/api/workspaces/current/plugin/icon" in result
    assert str(tenant_id) in result
    assert filename in result
    assert result.startswith("https://console.example.com")
    # Verify full URL structure. The expected value must embed the actual
    # filename; the previous literal "(unknown)" placeholder contradicted
    # the `filename in result` assertion above.
    expected_url = (
        f"https://console.example.com/console/api/workspaces/current/plugin/icon"
        f"?tenant_id={tenant_id}&filename={filename}"
    )
    assert result == expected_url
def test_get_plugin_icon_url_with_empty_console_url(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test plugin icon URL generation when CONSOLE_API_URL is empty.

    This test verifies:
    - Fallback to relative URL when CONSOLE_API_URL is None
    - Proper URL construction with relative path
    """
    # Arrange: Setup mock with empty console URL
    mock_external_service_dependencies["dify_config"].CONSOLE_API_URL = None
    fake = Faker()
    tenant_id = fake.uuid4()
    filename = "test_icon.png"
    # Act: Execute the method under test
    result = PluginService.get_plugin_icon_url(str(tenant_id), filename)
    # Assert: Verify the expected outcomes
    assert result is not None
    assert isinstance(result, str)
    assert result.startswith("/console/api/workspaces/current/plugin/icon")
    assert str(tenant_id) in result
    assert filename in result
    # Verify full URL structure. The expected value must embed the actual
    # filename; the previous literal "(unknown)" placeholder contradicted
    # the `filename in result` assertion above.
    expected_url = f"/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}"
    assert result == expected_url
def test_get_tool_provider_icon_url_builtin_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Builtin providers resolve to a console API icon endpoint.

    Checks the host, path shape, and that the (possibly percent-encoded)
    provider name is embedded in the resulting URL.
    """
    fake = Faker()
    provider_name = fake.company()

    # Act: resolve the icon URL for a builtin provider
    url = ToolTransformService.get_tool_provider_icon_url(ToolProviderType.BUILT_IN, provider_name, "🔧")

    # Assert: basic shape of the generated URL
    assert url is not None
    assert isinstance(url, str)
    assert url.startswith("https://console.example.com")
    assert url.endswith("/icon")
    assert "console/api/workspaces/current/tool-provider/builtin" in url
    # Faker company names may contain spaces, which end up URL-encoded.
    assert provider_name.replace(" ", "%20") in url or provider_name in url

    # Exact comparison, normalising spaces the same way the service does.
    raw = f"https://console.example.com/console/api/workspaces/current/tool-provider/builtin/{provider_name}/icon"
    assert url == raw.replace(" ", "%20")
def test_get_tool_provider_icon_url_api_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    API providers persist their icon as a JSON string; the service is
    expected to decode it back into a dict rather than build a URL.
    """
    raw_icon = '{"background": "#FF6B6B", "content": "🔧"}'

    # Act: decode the stored icon payload
    decoded = ToolTransformService.get_tool_provider_icon_url(ToolProviderType.API, Faker().company(), raw_icon)

    # Assert: parsed back into the original mapping
    assert decoded is not None
    assert isinstance(decoded, dict)
    assert decoded["background"] == "#FF6B6B"
    assert decoded["content"] == "🔧"
def test_get_tool_provider_icon_url_api_invalid_json(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test tool provider icon URL generation for API providers with invalid JSON.

    This test verifies:
    - Proper fallback when JSON parsing fails
    - Default icon structure when exception occurs
    """
    # Arrange: Setup test data with invalid JSON
    fake = Faker()
    provider_type = ToolProviderType.API
    provider_name = fake.company()
    icon = '{"invalid": json}'
    # Act: Execute the method under test
    result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
    # Assert: Verify the expected outcomes (default fallback icon)
    assert result is not None
    assert isinstance(result, dict)
    assert result["background"] == "#252525"
    # "😁" is U+1F601. The previous alternative comparison against the
    # surrogate-pair literal "\ud83d\ude01" was dead code: in Python 3 a
    # string of lone surrogates never compares equal to the astral code
    # point, so the disjunct has been removed.
    assert result["content"] == "😁"
def test_get_tool_provider_icon_url_workflow_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Workflow providers keep their icon dict untouched.

    The service must hand back the mapping it was given instead of
    converting it into a URL.
    """
    source_icon = {"background": "#FF6B6B", "content": "🔧"}

    # Act
    returned = ToolTransformService.get_tool_provider_icon_url(
        ToolProviderType.WORKFLOW, Faker().company(), source_icon
    )

    # Assert: the dict is returned as-is
    assert returned is not None
    assert isinstance(returned, dict)
    assert returned["background"] == "#FF6B6B"
    assert returned["content"] == "🔧"
def test_get_tool_provider_icon_url_mcp_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    MCP providers also bypass URL generation.

    Their icon dict must come back unchanged from the service.
    """
    source_icon = {"background": "#FF6B6B", "content": "🔧"}

    # Act
    returned = ToolTransformService.get_tool_provider_icon_url(
        ToolProviderType.MCP, Faker().company(), source_icon
    )

    # Assert: no transformation applied
    assert returned is not None
    assert isinstance(returned, dict)
    assert returned["background"] == "#FF6B6B"
    assert returned["content"] == "🔧"
def test_get_tool_provider_icon_url_unknown_type(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Unsupported provider types yield an empty string instead of a URL
    or an icon dict.
    """
    # Act: pass a type the service does not recognise
    outcome = ToolTransformService.get_tool_provider_icon_url("unknown_type", Faker().company(), "🔧")

    # Assert: empty-string sentinel for unknown types
    assert outcome == ""
def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Repacking a plain-dict builtin provider rewrites its icon in place
    into a console API icon URL that embeds the provider name.
    """
    fake = Faker()
    workspace_id = fake.uuid4()
    provider_dict = {"type": ToolProviderType.BUILT_IN, "name": fake.company(), "icon": "🔧"}

    # Act: mutate the dict in place
    ToolTransformService.repack_provider(str(workspace_id), provider_dict)

    # Assert: icon replaced with a URL pointing at the builtin icon endpoint
    assert "icon" in provider_dict
    rewritten = provider_dict["icon"]
    assert isinstance(rewritten, str)
    assert "console/api/workspaces/current/tool-provider/builtin" in rewritten
    # Provider names with spaces get percent-encoded in the URL.
    display_name = provider_dict["name"]
    assert display_name.replace(" ", "%20") in rewritten or display_name in rewritten
def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful provider repacking with ToolProviderApiEntity input.

    This test verifies:
    - Proper icon URL generation for entity providers
    - Plugin icon handling when plugin_id is present
    - Regular icon handling when plugin_id is not present
    """
    # Arrange: Setup test data
    fake = Faker()
    tenant_id = fake.uuid4()
    # Create provider entity with plugin_id, so both icons should be
    # rewritten via the plugin icon endpoint.
    provider = ToolProviderApiEntity(
        id=str(fake.uuid4()),
        author=fake.name(),
        name=fake.company(),
        description=I18nObject(en_US=fake.text(max_nb_chars=100)),
        icon="test_icon.png",
        icon_dark="test_icon_dark.png",
        label=I18nObject(en_US=fake.company()),
        type=ToolProviderType.API,
        masked_credentials={},
        is_team_authorization=True,
        plugin_id="test_plugin_id",
        tools=[],
        labels=[],
    )
    # Act: Execute the method under test
    ToolTransformService.repack_provider(tenant_id, provider)
    # Assert: Verify the expected outcomes — light icon rewritten to a
    # plugin icon URL carrying the tenant id and original filename.
    assert provider.icon is not None
    assert isinstance(provider.icon, str)
    assert "console/api/workspaces/current/plugin/icon" in provider.icon
    assert str(tenant_id) in provider.icon
    assert "test_icon.png" in provider.icon
    # Verify dark icon handling (same endpoint, dark filename)
    assert provider.icon_dark is not None
    assert isinstance(provider.icon_dark, str)
    assert "console/api/workspaces/current/plugin/icon" in provider.icon_dark
    assert str(tenant_id) in provider.icon_dark
    assert "test_icon_dark.png" in provider.icon_dark
def test_repack_provider_entity_no_plugin_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful provider repacking with ToolProviderApiEntity input without plugin_id.
This test verifies:
- Proper icon URL generation for non-plugin providers
- Regular tool provider icon handling
- Dark icon handling when present
"""
# Arrange: Setup test data
fake = Faker()
tenant_id = fake.uuid4()
# Create provider entity without plugin_id
provider = ToolProviderApiEntity(
id=fake.uuid4(),
author=fake.name(),
name=fake.company(),
description=I18nObject(en_US=fake.text(max_nb_chars=100)),
icon='{"background": "#FF6B6B", "content": "🔧"}',
icon_dark='{"background": "#252525", "content": "🔧"}',
label=I18nObject(en_US=fake.company()),
type=ToolProviderType.API,
masked_credentials={},
is_team_authorization=True,
plugin_id=None,
tools=[],
labels=[],
)
# Act: Execute the method under test
ToolTransformService.repack_provider(str(tenant_id), provider)
# Assert: Verify the expected outcomes
assert provider.icon is not None
assert isinstance(provider.icon, dict)
assert provider.icon["background"] == "#FF6B6B"
assert provider.icon["content"] == "🔧"
# Verify dark icon handling
assert provider.icon_dark is not None
assert isinstance(provider.icon_dark, dict)
assert provider.icon_dark["background"] == "#252525"
assert provider.icon_dark["content"] == "🔧"
def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test provider repacking with ToolProviderApiEntity input without dark icon.
This test verifies:
- Proper handling when icon_dark is None or empty
- No errors when dark icon is not present
"""
# Arrange: Setup test data
fake = Faker()
tenant_id = fake.uuid4()
# Create provider entity without dark icon
provider = ToolProviderApiEntity(
id=fake.uuid4(),
author=fake.name(),
name=fake.company(),
description=I18nObject(en_US=fake.text(max_nb_chars=100)),
icon='{"background": "#FF6B6B", "content": "🔧"}',
icon_dark="",
label=I18nObject(en_US=fake.company()),
type=ToolProviderType.API,
masked_credentials={},
is_team_authorization=True,
plugin_id=None,
tools=[],
labels=[],
)
# Act: Execute the method under test
ToolTransformService.repack_provider(tenant_id, provider)
# Assert: Verify the expected outcomes
assert provider.icon is not None
assert isinstance(provider.icon, dict)
assert provider.icon["background"] == "#FF6B6B"
assert provider.icon["content"] == "🔧"
# Verify dark icon remains empty string
assert provider.icon_dark == ""
    def test_builtin_provider_to_user_provider_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful conversion of builtin provider to user provider.
        This test verifies:
        - Proper entity creation with all required fields
        - Credentials schema handling
        - Team authorization setup
        - Plugin ID handling
        """
        # Arrange: Setup test data
        fake = Faker()
        # Create mock provider controller
        mock_controller = Mock()
        mock_controller.entity.identity.name = fake.company()
        mock_controller.entity.identity.author = fake.name()
        mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
        mock_controller.entity.identity.icon = "🔧"
        mock_controller.entity.identity.icon_dark = "🔧"
        mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
        mock_controller.plugin_id = None
        mock_controller.plugin_unique_identifier = None
        mock_controller.tool_labels = ["label1", "label2"]
        mock_controller.need_credentials = True
        # Mock credentials schema
        mock_credential = Mock()
        mock_credential.to_basic_provider_config.return_value.name = "api_key"
        mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
        # Create mock database provider
        mock_db_provider = Mock()
        mock_db_provider.credential_type = "api-key"
        mock_db_provider.tenant_id = fake.uuid4()
        mock_db_provider.credentials = {"api_key": "encrypted_key"}
        # Mock encryption
        # NOTE(review): real crypto is patched out — the test only checks that the
        # decrypt/mask results from the encrypter are passed through verbatim.
        with patch("services.tools.tools_transform_service.create_provider_encrypter") as mock_encrypter:
            mock_encrypter_instance = Mock()
            mock_encrypter_instance.decrypt.return_value = {"api_key": "decrypted_key"}
            mock_encrypter_instance.mask_plugin_credentials.return_value = {"api_key": ""}
            mock_encrypter.return_value = (mock_encrypter_instance, None)
            # Act: Execute the method under test
            result = ToolTransformService.builtin_provider_to_user_provider(
                mock_controller, mock_db_provider, decrypt_credentials=True
            )
        # Assert: Verify the expected outcomes
        assert result is not None
        assert result.id == mock_controller.entity.identity.name
        assert result.author == mock_controller.entity.identity.author
        assert result.name == mock_controller.entity.identity.name
        assert result.description == mock_controller.entity.identity.description
        assert result.icon == mock_controller.entity.identity.icon
        assert result.icon_dark == mock_controller.entity.identity.icon_dark
        assert result.label == mock_controller.entity.identity.label
        assert result.type == ToolProviderType.BUILT_IN
        assert result.is_team_authorization is True
        assert result.plugin_id is None
        assert result.tools == []
        assert result.labels == ["label1", "label2"]
        # Masked/original credentials come straight from the mocked encrypter above.
        assert result.masked_credentials == {"api_key": ""}
        assert result.original_credentials == {"api_key": "decrypted_key"}
    def test_builtin_provider_to_user_provider_plugin_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful conversion of builtin provider to user provider with plugin.
        This test verifies:
        - Plugin ID and unique identifier handling
        - Proper entity creation for plugin providers
        """
        # Arrange: Setup test data
        fake = Faker()
        # Create mock provider controller with plugin
        mock_controller = Mock()
        mock_controller.entity.identity.name = fake.company()
        mock_controller.entity.identity.author = fake.name()
        mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
        mock_controller.entity.identity.icon = "🔧"
        mock_controller.entity.identity.icon_dark = "🔧"
        mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
        mock_controller.plugin_id = "test_plugin_id"
        mock_controller.plugin_unique_identifier = "test_unique_id"
        mock_controller.tool_labels = ["label1"]
        mock_controller.need_credentials = False
        # Mock credentials schema
        mock_credential = Mock()
        mock_credential.to_basic_provider_config.return_value.name = "api_key"
        mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
        # Act: Execute the method under test
        result = ToolTransformService.builtin_provider_to_user_provider(
            mock_controller, None, decrypt_credentials=False
        )
        # Assert: Verify the expected outcomes
        assert result is not None
        # Note: The method checks isinstance(provider_controller, PluginToolProviderController)
        # Since we're using a Mock, this check will fail, so plugin_id will remain None
        # In a real test with actual PluginToolProviderController, this would work
        # NOTE(review): because of the isinstance limitation above, this test does not
        # actually exercise the plugin branch — consider using a real controller instance.
        assert result.is_team_authorization is True
        assert result.allow_delete is False
    def test_builtin_provider_to_user_provider_no_credentials(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test conversion of builtin provider to user provider without credentials.
        This test verifies:
        - Proper handling when no credentials are needed
        - Team authorization setup for no-credentials providers
        """
        # Arrange: Setup test data
        fake = Faker()
        # Create mock provider controller
        mock_controller = Mock()
        mock_controller.entity.identity.name = fake.company()
        mock_controller.entity.identity.author = fake.name()
        mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
        mock_controller.entity.identity.icon = "🔧"
        mock_controller.entity.identity.icon_dark = "🔧"
        mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
        mock_controller.plugin_id = None
        mock_controller.plugin_unique_identifier = None
        mock_controller.tool_labels = []
        # need_credentials=False means the provider is usable without any stored
        # credentials, so team authorization is granted implicitly.
        mock_controller.need_credentials = False
        # Mock credentials schema
        mock_credential = Mock()
        mock_credential.to_basic_provider_config.return_value.name = "api_key"
        mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
        # Act: Execute the method under test (no db provider, no decryption)
        result = ToolTransformService.builtin_provider_to_user_provider(
            mock_controller, None, decrypt_credentials=False
        )
        # Assert: Verify the expected outcomes
        assert result is not None
        assert result.is_team_authorization is True
        assert result.allow_delete is False
        assert result.masked_credentials == {"api_key": ""}
    def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful conversion of API provider to controller.
        This test verifies:
        - Proper controller creation from database provider
        - Auth type handling for different credential types
        - Backward compatibility for auth types
        """
        # Arrange: Setup test data
        fake = Faker()
        # Create API tool provider with api_key_header auth
        provider = ApiToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            tenant_id=fake.uuid4(),
            user_id=fake.uuid4(),
            credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
            schema="{}",
            schema_type_str="openapi",
            tools_str="[]",
        )
        from extensions.ext_database import db
        # Persist so the service reads a real row, matching production behavior.
        db.session.add(provider)
        db.session.commit()
        # Act: Execute the method under test
        result = ToolTransformService.api_provider_to_controller(provider)
        # Assert: Verify the expected outcomes
        # NOTE(review): only the controller's existence is checked here; asserting on
        # the parsed auth configuration would make this test more meaningful.
        assert result is not None
        assert hasattr(result, "from_db")
        # Additional assertions would depend on the actual controller implementation
    def test_api_provider_to_controller_api_key_query(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test conversion of API provider to controller with api_key_query auth type.
        This test verifies:
        - Proper auth type handling for query parameter authentication
        """
        # Arrange: Setup test data
        fake = Faker()
        # Create API tool provider with api_key_query auth
        provider = ApiToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            tenant_id=fake.uuid4(),
            user_id=fake.uuid4(),
            credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}',
            schema="{}",
            schema_type_str="openapi",
            tools_str="[]",
        )
        from extensions.ext_database import db
        db.session.add(provider)
        db.session.commit()
        # Act: Execute the method under test
        result = ToolTransformService.api_provider_to_controller(provider)
        # Assert: controller is built without raising for the query-based auth type.
        assert result is not None
        assert hasattr(result, "from_db")
    def test_api_provider_to_controller_backward_compatibility(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test conversion of API provider to controller with backward compatibility auth types.
        This test verifies:
        - Proper handling of legacy auth type values
        - Backward compatibility for api_key and api_key_header
        """
        # Arrange: Setup test data
        fake = Faker()
        # Create API tool provider with legacy auth type
        # ("api_key" is the pre-rename value that should map to header-based auth).
        provider = ApiToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            tenant_id=fake.uuid4(),
            user_id=fake.uuid4(),
            credentials_str='{"auth_type": "api_key", "api_key": "test_key"}',
            schema="{}",
            schema_type_str="openapi",
            tools_str="[]",
        )
        from extensions.ext_database import db
        db.session.add(provider)
        db.session.commit()
        # Act: Execute the method under test
        result = ToolTransformService.api_provider_to_controller(provider)
        # Assert: legacy auth type is accepted and a controller is produced.
        assert result is not None
        assert hasattr(result, "from_db")
    def test_workflow_provider_to_controller_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful conversion of workflow provider to controller.
        This test verifies:
        - Proper controller creation from workflow provider
        - Workflow-specific controller handling
        """
        # Arrange: Setup test data
        fake = Faker()
        # Create workflow tool provider
        provider = WorkflowToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            tenant_id=fake.uuid4(),
            user_id=fake.uuid4(),
            app_id=fake.uuid4(),
            label="Test Workflow",
            version="1.0.0",
            parameter_configuration="[]",
        )
        from extensions.ext_database import db
        db.session.add(provider)
        db.session.commit()
        # Mock the WorkflowToolProviderController.from_db method to avoid app dependency
        # (the app_id above is random, so a real from_db would fail to resolve the app).
        with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db:
            mock_controller = Mock()
            mock_from_db.return_value = mock_controller
            # Act: Execute the method under test
            result = ToolTransformService.workflow_provider_to_controller(provider)
            # Assert: the service delegates to from_db exactly once with our row.
            assert result is not None
            assert result == mock_controller
            mock_from_db.assert_called_once_with(provider)

View File

@@ -0,0 +1,716 @@
import json
from unittest.mock import patch
import pytest
from faker import Faker
from models.tools import WorkflowToolProvider
from models.workflow import Workflow as WorkflowModel
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
class TestWorkflowToolManageService:
"""Integration tests for WorkflowToolManageService using testcontainers."""
    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Mock setup for external service dependencies.

        Patches every external collaborator of app/account creation and of
        WorkflowToolManageService, and yields the mocks keyed by a short name
        so tests can assert on individual calls.
        """
        with (
            patch("services.app_service.FeatureService") as mock_feature_service,
            patch("services.app_service.EnterpriseService") as mock_enterprise_service,
            patch("services.app_service.ModelManager") as mock_model_manager,
            patch("services.account_service.FeatureService") as mock_account_feature_service,
            patch(
                "services.tools.workflow_tools_manage_service.WorkflowToolProviderController"
            ) as mock_workflow_tool_provider_controller,
            patch("services.tools.workflow_tools_manage_service.ToolLabelManager") as mock_tool_label_manager,
            patch("services.tools.workflow_tools_manage_service.ToolTransformService") as mock_tool_transform_service,
        ):
            # Setup default mock returns for app service
            mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
            mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
            mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
            # Setup default mock returns for account service
            mock_account_feature_service.get_system_features.return_value.is_allow_register = True
            # Mock ModelManager for model configuration
            mock_model_instance = mock_model_manager.return_value
            mock_model_instance.get_default_model_instance.return_value = None
            mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
            # Mock WorkflowToolProviderController
            mock_workflow_tool_provider_controller.from_db.return_value = None
            # Mock ToolLabelManager
            mock_tool_label_manager.update_tool_labels.return_value = None
            # Mock ToolTransformService
            mock_tool_transform_service.workflow_provider_to_controller.return_value = None
            # Yield (not return) so the patches stay active for the whole test.
            yield {
                "feature_service": mock_feature_service,
                "enterprise_service": mock_enterprise_service,
                "model_manager": mock_model_manager,
                "account_feature_service": mock_account_feature_service,
                "workflow_tool_provider_controller": mock_workflow_tool_provider_controller,
                "tool_label_manager": mock_tool_label_manager,
                "tool_transform_service": mock_tool_transform_service,
            }
    def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Helper method to create a test app and account for testing.
        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies
        Returns:
            tuple: (app, account, workflow) - Created app, account and workflow instances
        """
        fake = Faker()
        # Setup mocks for account creation
        mock_external_service_dependencies[
            "account_feature_service"
        ].get_system_features.return_value.is_allow_register = True
        # Create account and tenant
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant
        # Create app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "workflow",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
        # Create workflow for the app (empty graph/features are sufficient for
        # tool-management tests; the workflow is never executed here).
        workflow = WorkflowModel(
            tenant_id=tenant.id,
            app_id=app.id,
            type="workflow",
            version="1.0.0",
            graph=json.dumps({}),
            features=json.dumps({}),
            created_by=account.id,
            environment_variables=[],
            conversation_variables=[],
        )
        from extensions.ext_database import db
        db.session.add(workflow)
        db.session.commit()
        # Update app to reference the workflow; the service resolves the
        # workflow through app.workflow_id, so this link is required.
        app.workflow_id = workflow.id
        db.session.commit()
        return app, account, workflow
def _create_test_workflow_tool_parameters(self):
"""Helper method to create valid workflow tool parameters."""
return [
{
"name": "input_text",
"description": "Input text for processing",
"form": "form",
"type": "string",
"required": True,
},
{
"name": "output_format",
"description": "Output format specification",
"form": "form",
"type": "select",
"required": False,
},
]
    def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful workflow tool creation with valid parameters.
        This test verifies:
        - Proper workflow tool creation with all required fields
        - Correct database state after creation
        - Proper relationship establishment
        - External service integration
        - Return value correctness
        """
        fake = Faker()
        # Create test data
        app, account, workflow = self._create_test_app_and_account(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Setup workflow tool creation parameters
        tool_name = fake.word()
        tool_label = fake.word()
        tool_icon = {"type": "emoji", "emoji": "🔧"}
        tool_description = fake.text(max_nb_chars=200)
        tool_parameters = self._create_test_workflow_tool_parameters()
        tool_privacy_policy = fake.text(max_nb_chars=100)
        tool_labels = ["automation", "workflow"]
        # Execute the method under test
        result = WorkflowToolManageService.create_workflow_tool(
            user_id=account.id,
            tenant_id=account.current_tenant.id,
            workflow_app_id=app.id,
            name=tool_name,
            label=tool_label,
            icon=tool_icon,
            description=tool_description,
            parameters=tool_parameters,
            privacy_policy=tool_privacy_policy,
            labels=tool_labels,
        )
        # Verify the result
        assert result == {"result": "success"}
        # Verify database state
        from extensions.ext_database import db
        # Check if workflow tool provider was created
        created_tool_provider = (
            db.session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == account.current_tenant.id,
                WorkflowToolProvider.app_id == app.id,
            )
            .first()
        )
        assert created_tool_provider is not None
        assert created_tool_provider.name == tool_name
        assert created_tool_provider.label == tool_label
        # icon and parameters are stored as JSON strings, hence the json.dumps comparison.
        assert created_tool_provider.icon == json.dumps(tool_icon)
        assert created_tool_provider.description == tool_description
        assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters)
        assert created_tool_provider.privacy_policy == tool_privacy_policy
        assert created_tool_provider.version == workflow.version
        assert created_tool_provider.user_id == account.id
        assert created_tool_provider.tenant_id == account.current_tenant.id
        assert created_tool_provider.app_id == app.id
        # Verify external service calls
        mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.assert_called_once()
        mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
        mock_external_service_dependencies[
            "tool_transform_service"
        ].workflow_provider_to_controller.assert_called_once()
    def test_create_workflow_tool_duplicate_name_error(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test workflow tool creation fails when name already exists.
        This test verifies:
        - Proper error handling for duplicate tool names
        - Database constraint enforcement
        - Correct error message
        """
        fake = Faker()
        # Create test data
        app, account, workflow = self._create_test_app_and_account(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Create first workflow tool
        first_tool_name = fake.word()
        first_tool_parameters = self._create_test_workflow_tool_parameters()
        WorkflowToolManageService.create_workflow_tool(
            user_id=account.id,
            tenant_id=account.current_tenant.id,
            workflow_app_id=app.id,
            name=first_tool_name,
            label=fake.word(),
            icon={"type": "emoji", "emoji": "🔧"},
            description=fake.text(max_nb_chars=200),
            parameters=first_tool_parameters,
        )
        # Attempt to create second workflow tool with same name
        second_tool_parameters = self._create_test_workflow_tool_parameters()
        with pytest.raises(ValueError) as exc_info:
            WorkflowToolManageService.create_workflow_tool(
                user_id=account.id,
                tenant_id=account.current_tenant.id,
                workflow_app_id=app.id,
                name=first_tool_name,  # Same name
                label=fake.word(),
                icon={"type": "emoji", "emoji": "⚙️"},
                description=fake.text(max_nb_chars=200),
                parameters=second_tool_parameters,
            )
        # Verify error message
        # NOTE(review): the second call also reuses app.id, so either the name or the
        # app_id uniqueness check could trip — the combined message covers both.
        assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
        # Verify only one tool was created
        from extensions.ext_database import db
        tool_count = (
            db.session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == account.current_tenant.id,
            )
            .count()
        )
        assert tool_count == 1
    def test_create_workflow_tool_invalid_app_error(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test workflow tool creation fails when app does not exist.
        This test verifies:
        - Proper error handling for non-existent apps
        - Correct error message
        - No database changes when app is invalid
        """
        fake = Faker()
        # Create test data (a valid tenant/account is still needed for the call)
        app, account, workflow = self._create_test_app_and_account(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Generate non-existent app ID
        non_existent_app_id = fake.uuid4()
        # Attempt to create workflow tool with non-existent app
        tool_parameters = self._create_test_workflow_tool_parameters()
        with pytest.raises(ValueError) as exc_info:
            WorkflowToolManageService.create_workflow_tool(
                user_id=account.id,
                tenant_id=account.current_tenant.id,
                workflow_app_id=non_existent_app_id,  # Non-existent app ID
                name=fake.word(),
                label=fake.word(),
                icon={"type": "emoji", "emoji": "🔧"},
                description=fake.text(max_nb_chars=200),
                parameters=tool_parameters,
            )
        # Verify error message
        assert f"App {non_existent_app_id} not found" in str(exc_info.value)
        # Verify no workflow tool was created
        from extensions.ext_database import db
        tool_count = (
            db.session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == account.current_tenant.id,
            )
            .count()
        )
        assert tool_count == 0
    def test_create_workflow_tool_invalid_parameters_error(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test workflow tool creation fails when parameters are invalid.
        This test verifies:
        - Proper error handling for invalid parameter configurations
        - Parameter validation enforcement
        - Correct error message
        """
        fake = Faker()
        # Create test data
        app, account, workflow = self._create_test_app_and_account(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Setup invalid workflow tool parameters (missing required fields)
        invalid_parameters = [
            {
                "name": "input_text",
                # Missing description and form fields
                "type": "string",
                "required": True,
            }
        ]
        # Attempt to create workflow tool with invalid parameters
        with pytest.raises(ValueError) as exc_info:
            WorkflowToolManageService.create_workflow_tool(
                user_id=account.id,
                tenant_id=account.current_tenant.id,
                workflow_app_id=app.id,
                name=fake.word(),
                label=fake.word(),
                icon={"type": "emoji", "emoji": "🔧"},
                description=fake.text(max_nb_chars=200),
                parameters=invalid_parameters,
            )
        # Verify error message contains validation error
        # (pydantic validation of the parameter schema is expected to fail here)
        assert "validation error" in str(exc_info.value).lower()
        # Verify no workflow tool was created
        from extensions.ext_database import db
        tool_count = (
            db.session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == account.current_tenant.id,
            )
            .count()
        )
        assert tool_count == 0
    def test_create_workflow_tool_duplicate_app_id_error(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test workflow tool creation fails when app_id already exists.
        This test verifies:
        - Proper error handling for duplicate app_id
        - Database constraint enforcement for app_id uniqueness
        - Correct error message
        """
        fake = Faker()
        # Create test data
        app, account, workflow = self._create_test_app_and_account(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Create first workflow tool
        first_tool_name = fake.word()
        first_tool_parameters = self._create_test_workflow_tool_parameters()
        WorkflowToolManageService.create_workflow_tool(
            user_id=account.id,
            tenant_id=account.current_tenant.id,
            workflow_app_id=app.id,
            name=first_tool_name,
            label=fake.word(),
            icon={"type": "emoji", "emoji": "🔧"},
            description=fake.text(max_nb_chars=200),
            parameters=first_tool_parameters,
        )
        # Attempt to create second workflow tool with same app_id but different name
        # (isolates the app_id uniqueness check from the name check).
        second_tool_name = fake.word()
        second_tool_parameters = self._create_test_workflow_tool_parameters()
        with pytest.raises(ValueError) as exc_info:
            WorkflowToolManageService.create_workflow_tool(
                user_id=account.id,
                tenant_id=account.current_tenant.id,
                workflow_app_id=app.id,  # Same app_id
                name=second_tool_name,  # Different name
                label=fake.word(),
                icon={"type": "emoji", "emoji": "⚙️"},
                description=fake.text(max_nb_chars=200),
                parameters=second_tool_parameters,
            )
        # Verify error message
        assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
        # Verify only one tool was created
        from extensions.ext_database import db
        tool_count = (
            db.session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == account.current_tenant.id,
            )
            .count()
        )
        assert tool_count == 1
    def test_create_workflow_tool_workflow_not_found_error(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test workflow tool creation fails when app has no workflow.
        This test verifies:
        - Proper error handling for apps without workflows
        - Correct error message
        - No database changes when workflow is missing
        """
        fake = Faker()
        # Create test data but without workflow
        app, account, _ = self._create_test_app_and_account(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Remove workflow reference from app
        # (the helper always links one, so we unlink it to simulate the error case).
        from extensions.ext_database import db
        app.workflow_id = None
        db.session.commit()
        # Attempt to create workflow tool for app without workflow
        tool_parameters = self._create_test_workflow_tool_parameters()
        with pytest.raises(ValueError) as exc_info:
            WorkflowToolManageService.create_workflow_tool(
                user_id=account.id,
                tenant_id=account.current_tenant.id,
                workflow_app_id=app.id,
                name=fake.word(),
                label=fake.word(),
                icon={"type": "emoji", "emoji": "🔧"},
                description=fake.text(max_nb_chars=200),
                parameters=tool_parameters,
            )
        # Verify error message
        assert f"Workflow not found for app {app.id}" in str(exc_info.value)
        # Verify no workflow tool was created
        tool_count = (
            db.session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == account.current_tenant.id,
            )
            .count()
        )
        assert tool_count == 0
    def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful workflow tool update with valid parameters.
        This test verifies:
        - Proper workflow tool update with all required fields
        - Correct database state after update
        - Proper relationship maintenance
        - External service integration
        - Return value correctness
        """
        fake = Faker()
        # Create test data
        app, account, workflow = self._create_test_app_and_account(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Create initial workflow tool
        initial_tool_name = fake.word()
        initial_tool_parameters = self._create_test_workflow_tool_parameters()
        WorkflowToolManageService.create_workflow_tool(
            user_id=account.id,
            tenant_id=account.current_tenant.id,
            workflow_app_id=app.id,
            name=initial_tool_name,
            label=fake.word(),
            icon={"type": "emoji", "emoji": "🔧"},
            description=fake.text(max_nb_chars=200),
            parameters=initial_tool_parameters,
        )
        # Get the created tool (update_workflow_tool is keyed by the tool's row id)
        from extensions.ext_database import db
        created_tool = (
            db.session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == account.current_tenant.id,
                WorkflowToolProvider.app_id == app.id,
            )
            .first()
        )
        # Setup update parameters
        updated_tool_name = fake.word()
        updated_tool_label = fake.word()
        updated_tool_icon = {"type": "emoji", "emoji": "⚙️"}
        updated_tool_description = fake.text(max_nb_chars=200)
        updated_tool_parameters = self._create_test_workflow_tool_parameters()
        updated_tool_privacy_policy = fake.text(max_nb_chars=100)
        updated_tool_labels = ["automation", "updated"]
        # Execute the update method
        result = WorkflowToolManageService.update_workflow_tool(
            user_id=account.id,
            tenant_id=account.current_tenant.id,
            workflow_tool_id=created_tool.id,
            name=updated_tool_name,
            label=updated_tool_label,
            icon=updated_tool_icon,
            description=updated_tool_description,
            parameters=updated_tool_parameters,
            privacy_policy=updated_tool_privacy_policy,
            labels=updated_tool_labels,
        )
        # Verify the result
        assert result == {"result": "success"}
        # Verify database state was updated (refresh to pick up committed changes)
        db.session.refresh(created_tool)
        assert created_tool.name == updated_tool_name
        assert created_tool.label == updated_tool_label
        assert created_tool.icon == json.dumps(updated_tool_icon)
        assert created_tool.description == updated_tool_description
        assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters)
        assert created_tool.privacy_policy == updated_tool_privacy_policy
        assert created_tool.version == workflow.version
        assert created_tool.updated_at is not None
        # Verify external service calls (called for both create and update)
        mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.assert_called()
        mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called()
        mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called()
    def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test workflow tool update fails when tool does not exist.
        This test verifies:
        - Proper error handling for non-existent tools
        - Correct error message
        - No database changes when tool is invalid
        """
        fake = Faker()
        # Create test data
        app, account, workflow = self._create_test_app_and_account(
            db_session_with_containers, mock_external_service_dependencies
        )
        # Generate non-existent tool ID
        non_existent_tool_id = fake.uuid4()
        # Attempt to update non-existent workflow tool
        tool_parameters = self._create_test_workflow_tool_parameters()
        with pytest.raises(ValueError) as exc_info:
            WorkflowToolManageService.update_workflow_tool(
                user_id=account.id,
                tenant_id=account.current_tenant.id,
                workflow_tool_id=non_existent_tool_id,  # Non-existent tool ID
                name=fake.word(),
                label=fake.word(),
                icon={"type": "emoji", "emoji": "🔧"},
                description=fake.text(max_nb_chars=200),
                parameters=tool_parameters,
            )
        # Verify error message
        assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value)
        # Verify no workflow tool was created
        # (no tool was ever created in this test, so the table stays empty)
        from extensions.ext_database import db
        tool_count = (
            db.session.query(WorkflowToolProvider)
            .where(
                WorkflowToolProvider.tenant_id == account.current_tenant.id,
            )
            .count()
        )
        assert tool_count == 0
def test_update_workflow_tool_same_name_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test workflow tool update succeeds when keeping the same name.

    This test verifies:
    - Proper handling when updating tool with same name
    - Database state maintenance
    - Update timestamp is set
    """
    fake = Faker()

    # Arrange: create an app/account and register a workflow tool for it.
    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )

    original_name = fake.word()
    parameters = self._create_test_workflow_tool_parameters()
    WorkflowToolManageService.create_workflow_tool(
        user_id=account.id,
        tenant_id=account.current_tenant.id,
        workflow_app_id=app.id,
        name=original_name,
        label=fake.word(),
        icon={"type": "emoji", "emoji": "🔧"},
        description=fake.text(max_nb_chars=200),
        parameters=parameters,
    )

    # Fetch the row that create_workflow_tool just persisted.
    from extensions.ext_database import db

    stored_tool = (
        db.session.query(WorkflowToolProvider)
        .where(
            WorkflowToolProvider.tenant_id == account.current_tenant.id,
            WorkflowToolProvider.app_id == app.id,
        )
        .first()
    )

    # Act: update the tool while reusing its current name; this must not
    # trigger the duplicate-name check.
    result = WorkflowToolManageService.update_workflow_tool(
        user_id=account.id,
        tenant_id=account.current_tenant.id,
        workflow_tool_id=stored_tool.id,
        name=original_name,  # Same name
        label=fake.word(),
        icon={"type": "emoji", "emoji": "⚙️"},
        description=fake.text(max_nb_chars=200),
        parameters=parameters,
    )

    # Assert: success result, name unchanged, update timestamp recorded.
    assert result == {"result": "success"}
    db.session.refresh(stored_tool)
    assert stored_tool.name == original_name
    assert stored_tool.updated_at is not None

View File

@@ -0,0 +1,554 @@
import json
from unittest.mock import patch
import pytest
from faker import Faker
from core.app.app_config.entities import (
DatasetEntity,
DatasetRetrieveConfigEntity,
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.model_runtime.entities.llm_entities import LLMMode
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models import Account, Tenant
from models.api_based_extension import APIBasedExtension
from models.model import App, AppMode, AppModelConfig
from models.workflow import Workflow
from services.workflow.workflow_converter import WorkflowConverter
class TestWorkflowConverter:
"""Integration tests for WorkflowConverter using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
    """Patch WorkflowConverter's external collaborators for the duration of a test.

    Yields a dict of the active mocks keyed by a short name so individual
    tests can adjust return values or assert on calls. All patches are
    applied at the ``services.workflow.workflow_converter`` import site, so
    they only affect lookups made from within that module.
    """
    with (
        patch("services.workflow.workflow_converter.encrypter") as mock_encrypter,
        patch("services.workflow.workflow_converter.SimplePromptTransform") as mock_prompt_transform,
        patch("services.workflow.workflow_converter.AgentChatAppConfigManager") as mock_agent_chat_config_manager,
        patch("services.workflow.workflow_converter.ChatAppConfigManager") as mock_chat_config_manager,
        patch("services.workflow.workflow_converter.CompletionAppConfigManager") as mock_completion_config_manager,
    ):
        # Setup default mock returns
        # Decrypting any token yields a fixed plaintext key.
        mock_encrypter.decrypt_token.return_value = "decrypted_api_key"
        # SimplePromptTransform instances hand back a canned prompt template + rules.
        mock_prompt_transform.return_value.get_prompt_template.return_value = {
            "prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"),
            "prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"},
        }
        # All three app-config managers return the shared mock app config.
        mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()
        mock_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()
        mock_completion_config_manager.get_app_config.return_value = self._create_mock_app_config()
        yield {
            "encrypter": mock_encrypter,
            "prompt_transform": mock_prompt_transform,
            "agent_chat_config_manager": mock_agent_chat_config_manager,
            "chat_config_manager": mock_chat_config_manager,
            "completion_config_manager": mock_completion_config_manager,
        }
def _create_mock_app_config(self):
    """Return a minimal stand-in object that mimics a parsed app config."""
    # Anonymous mutable object; attributes are attached below as needed.
    config = type("obj", (object,), {})()

    # One text-input variable exposed by the app.
    config.variables = [
        VariableEntity(
            variable="text_input",
            label="Text Input",
            type=VariableEntityType.TEXT_INPUT,
        )
    ]

    # Chat-mode GPT-4 model with no extra parameters or stop sequences.
    config.model = ModelConfigEntity(
        provider="openai",
        model="gpt-4",
        mode=LLMMode.CHAT,
        parameters={},
        stop=[],
    )

    # Simple prompt template referencing the variable declared above.
    config.prompt_template = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="You are a helpful assistant {{text_input}}",
    )

    # No dataset, no external data variables, no file-upload feature.
    config.dataset = None
    config.external_data_variables = []
    config.additional_features = type("obj", (object,), {"file_upload": None})()
    config.app_model_config_dict = {}
    return config
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Create and persist an account plus an owning tenant for testing.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies

    Returns:
        tuple: (account, tenant) - Created account and tenant instances
    """
    from extensions.ext_database import db
    from models.account import TenantAccountJoin, TenantAccountRole

    fake = Faker()

    # Persist the account first.
    account = Account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        status="active",
    )
    db.session.add(account)
    db.session.commit()

    # Then the tenant that will own it.
    tenant = Tenant(
        name=fake.company(),
        status="normal",
    )
    db.session.add(tenant)
    db.session.commit()

    # Join the account to the tenant as its owner.
    membership = TenantAccountJoin(
        tenant_id=tenant.id,
        account_id=account.id,
        role=TenantAccountRole.OWNER,
        current=True,
    )
    db.session.add(membership)
    db.session.commit()

    # Mark the tenant as the account's current tenant for downstream code.
    account.current_tenant = tenant
    return account, tenant
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant, account):
    """
    Create and persist a chat-mode app with an attached model config.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
        mock_external_service_dependencies: Mock dependencies
        tenant: Tenant instance
        account: Account instance

    Returns:
        App: Created app instance
    """
    from extensions.ext_database import db

    fake = Faker()

    # Persist the app shell first.
    app = App(
        tenant_id=tenant.id,
        name=fake.company(),
        mode=AppMode.CHAT,
        icon_type="emoji",
        icon="🤖",
        icon_background="#FF6B6B",
        enable_site=True,
        enable_api=True,
        api_rpm=100,
        api_rph=10,
        is_demo=False,
        is_public=False,
        created_by=account.id,
        updated_by=account.id,
    )
    db.session.add(app)
    db.session.commit()

    # Then its model configuration row.
    model_config = AppModelConfig(
        app_id=app.id,
        provider="openai",
        model="gpt-4",
        configs={},
        created_by=account.id,
        updated_by=account.id,
    )
    db.session.add(model_config)
    db.session.commit()

    # Point the app at its model config and persist the link.
    app.app_model_config_id = model_config.id
    db.session.commit()

    return app
def test_convert_to_workflow_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful conversion of app to workflow.

    This test verifies:
    - Proper app to workflow conversion
    - Correct database state after conversion
    - Proper relationship establishment
    - Workflow creation with correct configuration
    """
    # Arrange: a tenant-scoped account and a chat-mode app with a model config.
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant, account)

    # Act: Execute the conversion
    workflow_converter = WorkflowConverter()
    new_app = workflow_converter.convert_to_workflow(
        app_model=app,
        account=account,
        name="Test Workflow App",
        icon_type="emoji",
        icon="🚀",
        icon_background="#4CAF50",
    )

    # Assert: the new app carries the requested metadata and advanced-chat mode.
    assert new_app is not None
    assert new_app.name == "Test Workflow App"
    assert new_app.mode == AppMode.ADVANCED_CHAT
    assert new_app.icon_type == "emoji"
    assert new_app.icon == "🚀"
    assert new_app.icon_background == "#4CAF50"
    assert new_app.tenant_id == app.tenant_id
    assert new_app.created_by == account.id

    # Verify the new app row was persisted.
    from extensions.ext_database import db

    db.session.refresh(new_app)
    assert new_app.id is not None

    # Verify a chat-type workflow was created for the new app.
    workflow = db.session.query(Workflow).where(Workflow.app_id == new_app.id).first()
    assert workflow is not None
    assert workflow.tenant_id == app.tenant_id
    assert workflow.type == "chat"
def test_convert_to_workflow_without_app_model_config_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test error handling when app model config is missing.

    This test verifies:
    - Proper error handling for missing app model config
    - Correct exception type and message
    - Database state remains unchanged
    """
    # Arrange: Create test data without app model config
    fake = Faker()
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    # App is persisted WITHOUT an AppModelConfig row — the condition under test.
    app = App(
        tenant_id=tenant.id,
        name=fake.company(),
        mode=AppMode.CHAT,
        icon_type="emoji",
        icon="🤖",
        icon_background="#FF6B6B",
        enable_site=True,
        enable_api=True,
        api_rpm=100,
        api_rph=10,
        is_demo=False,
        is_public=False,
        created_by=account.id,
        updated_by=account.id,
    )
    from extensions.ext_database import db
    db.session.add(app)
    db.session.commit()
    # Act & Assert: Verify proper error handling
    workflow_converter = WorkflowConverter()
    # Check initial state so we can prove nothing was persisted by the failed call.
    initial_workflow_count = db.session.query(Workflow).count()
    with pytest.raises(ValueError, match="App model config is required"):
        workflow_converter.convert_to_workflow(
            app_model=app,
            account=account,
            name="Test Workflow App",
            icon_type="emoji",
            icon="🚀",
            icon_background="#4CAF50",
        )
    # Verify database state remains unchanged
    # The workflow creation happens in convert_app_model_config_to_workflow
    # which is called before the app_model_config check, so we need to clean up
    db.session.rollback()
    final_workflow_count = db.session.query(Workflow).count()
    assert final_workflow_count == initial_workflow_count
def test_convert_app_model_config_to_workflow_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test successful conversion of app model config to workflow.

    This test verifies:
    - Proper app model config to workflow conversion
    - Correct workflow graph structure
    - Proper node creation and configuration
    - Database state management
    """
    # Arrange: a tenant-scoped account and a chat-mode app with a model config.
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant, account)

    # Act: Execute the conversion
    workflow_converter = WorkflowConverter()
    workflow = workflow_converter.convert_app_model_config_to_workflow(
        app_model=app,
        app_model_config=app.app_model_config,
        account_id=account.id,
    )

    # Assert: the draft workflow is bound to the app and tenant.
    assert workflow is not None
    assert workflow.tenant_id == app.tenant_id
    assert workflow.app_id == app.id
    assert workflow.type == "chat"
    assert workflow.version == Workflow.VERSION_DRAFT
    assert workflow.created_by == account.id

    # Verify the serialized graph has a non-empty node/edge structure.
    graph = json.loads(workflow.graph)
    assert "nodes" in graph
    assert "edges" in graph
    assert len(graph["nodes"]) > 0
    assert len(graph["edges"]) > 0

    # Verify start node exists
    start_node = next((node for node in graph["nodes"] if node["data"]["type"] == "start"), None)
    assert start_node is not None
    assert start_node["id"] == "start"

    # Verify LLM node exists
    llm_node = next((node for node in graph["nodes"] if node["data"]["type"] == "llm"), None)
    assert llm_node is not None
    assert llm_node["id"] == "llm"

    # Verify answer node exists for chat mode
    answer_node = next((node for node in graph["nodes"] if node["data"]["type"] == "answer"), None)
    assert answer_node is not None
    assert answer_node["id"] == "answer"

    # Verify the workflow row was persisted.
    from extensions.ext_database import db

    db.session.refresh(workflow)
    assert workflow.id is not None

    # Verify features were serialized as a JSON object (empty dict if unset).
    features = json.loads(workflow._features) if workflow._features else {}
    assert isinstance(features, dict)
def test_convert_to_start_node_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful conversion to start node.

    This test verifies:
    - Proper start node creation with variables
    - Correct node structure and data
    - Variable encoding and formatting
    """
    # Arrange: two typed input variables to embed in the start node.
    input_variables = [
        VariableEntity(
            variable="text_input",
            label="Text Input",
            type=VariableEntityType.TEXT_INPUT,
        ),
        VariableEntity(
            variable="number_input",
            label="Number Input",
            type=VariableEntityType.NUMBER,
        ),
    ]

    # Act: convert the variables into a start node definition.
    node = WorkflowConverter()._convert_to_start_node(variables=input_variables)

    # Assert: node shell is well-formed.
    assert node is not None
    assert node["id"] == "start"
    assert node["data"]["title"] == "START"
    assert node["data"]["type"] == "start"

    # Assert: both variables were encoded with serialized type names.
    encoded = node["data"]["variables"]
    assert len(encoded) == 2
    assert encoded[0]["variable"] == "text_input"
    assert encoded[0]["label"] == "Text Input"
    assert encoded[0]["type"] == "text-input"
    assert encoded[1]["variable"] == "number_input"
    assert encoded[1]["label"] == "Number Input"
    assert encoded[1]["type"] == "number"
def test_convert_to_http_request_node_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Test successful conversion to HTTP request node.

    This test verifies:
    - Proper HTTP request node creation
    - Correct API configuration and authorization
    - Code node creation for response parsing
    - External data variable mapping
    """
    # Arrange: a tenant-scoped account and app plus an API-based extension
    # for the external data variable to point at.
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant, account)

    api_based_extension = APIBasedExtension(
        tenant_id=tenant.id,
        name="Test API Extension",
        api_key="encrypted_api_key",
        api_endpoint="https://api.example.com/test",
    )
    from extensions.ext_database import db

    db.session.add(api_based_extension)
    db.session.commit()

    # The converter decrypts the stored api_key; pin the decrypted value.
    mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key"

    variables = [
        VariableEntity(
            variable="user_input",
            label="User Input",
            type=VariableEntityType.TEXT_INPUT,
        )
    ]
    external_data_variables = [
        ExternalDataVariableEntity(
            variable="external_data", type="api", config={"api_based_extension_id": api_based_extension.id}
        )
    ]

    # Act: Execute the conversion
    workflow_converter = WorkflowConverter()
    nodes, external_data_variable_node_mapping = workflow_converter._convert_to_http_request_node(
        app_model=app,
        variables=variables,
        external_data_variables=external_data_variables,
    )

    # Assert: one HTTP request node plus one code node, and one mapping entry.
    assert len(nodes) == 2  # HTTP request node + code node
    assert len(external_data_variable_node_mapping) == 1

    # Verify the HTTP request node targets the extension endpoint with bearer auth.
    http_request_node = nodes[0]
    assert http_request_node["data"]["type"] == "http-request"
    assert http_request_node["data"]["method"] == "post"
    assert http_request_node["data"]["url"] == api_based_extension.api_endpoint
    assert http_request_node["data"]["authorization"]["type"] == "api-key"
    assert http_request_node["data"]["authorization"]["config"]["type"] == "bearer"
    assert http_request_node["data"]["authorization"]["config"]["api_key"] == "decrypted_api_key"

    # Verify the code node parses the HTTP response JSON in python3.
    code_node = nodes[1]
    assert code_node["data"]["type"] == "code"
    assert code_node["data"]["code_language"] == "python3"
    assert "response_json" in code_node["data"]["variables"][0]["variable"]

    # Verify the external variable is mapped to the code node's output.
    assert external_data_variable_node_mapping["external_data"] == code_node["id"]
def test_convert_to_knowledge_retrieval_node_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful conversion to knowledge retrieval node.
This test verifies:
- Proper knowledge retrieval node creation
- Correct dataset configuration
- Model configuration integration
- Query variable selector setup
"""
# Arrange: Create test data
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
# Create dataset config
dataset_config = DatasetEntity(
dataset_ids=["dataset_1", "dataset_2"],
retrieve_config=DatasetRetrieveConfigEntity(
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=10,
score_threshold=0.8,
reranking_model={"provider": "cohere", "model": "rerank-v2"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(
provider="openai",
model="gpt-4",
mode=LLMMode.CHAT,
parameters={"temperature": 0.7},
stop=[],
)
# Act: Execute the conversion for advanced chat mode
workflow_converter = WorkflowConverter()
node = workflow_converter._convert_to_knowledge_retrieval_node(
new_app_mode=AppMode.ADVANCED_CHAT,
dataset_config=dataset_config,
model_config=model_config,
)
# Assert: Verify the expected outcomes
assert node is not None
assert node["data"]["type"] == "knowledge-retrieval"
assert node["data"]["title"] == "KNOWLEDGE RETRIEVAL"
assert node["data"]["dataset_ids"] == ["dataset_1", "dataset_2"]
assert node["data"]["retrieval_mode"] == "multiple"
assert node["data"]["query_variable_selector"] == ["sys", "query"]
# Verify multiple retrieval config
multiple_config = node["data"]["multiple_retrieval_config"]
assert multiple_config["top_k"] == 10
assert multiple_config["score_threshold"] == 0.8
assert multiple_config["reranking_model"]["provider"] == "cohere"
assert multiple_config["reranking_model"]["model"] == "rerank-v2"
# Verify single retrieval config is None for multiple strategy
assert node["data"]["single_retrieval_config"] is None