dify
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,885 @@
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_CONTEXT,
|
||||
CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
CONTEXT,
|
||||
)
|
||||
from models.model import AppMode
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
||||
class TestAdvancedPromptTemplateService:
|
||||
"""Integration tests for AdvancedPromptTemplateService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
# This service doesn't have external dependencies, but we keep the pattern
|
||||
# for consistency with other test files
|
||||
return {}
|
||||
|
||||
def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful prompt generation for Baichuan model.
|
||||
|
||||
This test verifies:
|
||||
- Proper prompt generation for Baichuan models
|
||||
- Correct model detection logic
|
||||
- Appropriate prompt template selection
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test data for Baichuan model
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify context is included for Baichuan model
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful prompt generation for common models.
|
||||
|
||||
This test verifies:
|
||||
- Proper prompt generation for non-Baichuan models
|
||||
- Correct model detection logic
|
||||
- Appropriate prompt template selection
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test data for common model
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify context is included for common model
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_prompt_case_insensitive_baichuan_detection(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan model detection is case insensitive.
|
||||
|
||||
This test verifies:
|
||||
- Model name detection works regardless of case
|
||||
- Proper prompt template selection for different case variations
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test different case variations
|
||||
test_cases = ["Baichuan-13B-Chat", "BAICHUAN-13B-CHAT", "baichuan-13b-chat", "BaiChuan-13B-Chat"]
|
||||
|
||||
for model_name in test_cases:
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": model_name,
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify Baichuan template is used
|
||||
assert result is not None
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
|
||||
def test_get_common_prompt_chat_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation for chat app with completion mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct prompt template selection for chat app + completion mode
|
||||
- Proper context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
assert "conversation_histories_role" in result["completion_prompt_config"]
|
||||
assert "stop" in result
|
||||
|
||||
# Verify context is included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test common prompt generation for chat app with chat mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct prompt template selection for chat app + chat mode
|
||||
- Proper context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "role" in result["chat_prompt_config"]["prompt"][0]
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify context is included
|
||||
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_completion_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation for completion app with completion mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct prompt template selection for completion app + completion mode
|
||||
- Proper context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
assert "stop" in result
|
||||
|
||||
# Verify context is included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_completion_app_chat_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation for completion app with chat mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct prompt template selection for completion app + chat mode
|
||||
- Proper context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "role" in result["chat_prompt_config"]["prompt"][0]
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify context is included
|
||||
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test common prompt generation without context.
|
||||
|
||||
This test verifies:
|
||||
- Correct handling when has_context is "false"
|
||||
- Context is not included in prompt
|
||||
- Template structure remains intact
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "false")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify context is NOT included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert CONTEXT not in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_unsupported_app_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation with unsupported app mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of unsupported app modes
|
||||
- Default empty dict return
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt("unsupported_mode", "completion", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
|
||||
def test_get_common_prompt_unsupported_model_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation with unsupported model mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of unsupported model modes
|
||||
- Default empty dict return
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "unsupported_mode", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
|
||||
def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test completion prompt generation with context.
|
||||
|
||||
This test verifies:
|
||||
- Proper context integration in completion prompts
|
||||
- Template structure preservation
|
||||
- Context placement at the beginning
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test prompt template
|
||||
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "true", CONTEXT)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify context is prepended to original text
|
||||
result_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert result_text.startswith(CONTEXT)
|
||||
assert original_text in result_text
|
||||
assert result_text == CONTEXT + original_text
|
||||
|
||||
def test_get_completion_prompt_without_context(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test completion prompt generation without context.
|
||||
|
||||
This test verifies:
|
||||
- Original template is preserved when no context
|
||||
- No modification to prompt text
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test prompt template
|
||||
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify original text is unchanged
|
||||
result_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert result_text == original_text
|
||||
assert CONTEXT not in result_text
|
||||
|
||||
def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test chat prompt generation with context.
|
||||
|
||||
This test verifies:
|
||||
- Proper context integration in chat prompts
|
||||
- Template structure preservation
|
||||
- Context placement at the beginning of first message
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test prompt template
|
||||
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "true", CONTEXT)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify context is prepended to original text
|
||||
result_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert result_text.startswith(CONTEXT)
|
||||
assert original_text in result_text
|
||||
assert result_text == CONTEXT + original_text
|
||||
|
||||
def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test chat prompt generation without context.
|
||||
|
||||
This test verifies:
|
||||
- Original template is preserved when no context
|
||||
- No modification to prompt text
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test prompt template
|
||||
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify original text is unchanged
|
||||
result_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert result_text == original_text
|
||||
assert CONTEXT not in result_text
|
||||
|
||||
def test_get_baichuan_prompt_chat_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for chat app with completion mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct Baichuan prompt template selection for chat app + completion mode
|
||||
- Proper Baichuan context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
assert "conversation_histories_role" in result["completion_prompt_config"]
|
||||
assert "stop" in result
|
||||
|
||||
# Verify Baichuan context is included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_chat_app_chat_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for chat app with chat mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct Baichuan prompt template selection for chat app + chat mode
|
||||
- Proper Baichuan context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "role" in result["chat_prompt_config"]["prompt"][0]
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify Baichuan context is included
|
||||
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_completion_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for completion app with completion mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct Baichuan prompt template selection for completion app + completion mode
|
||||
- Proper Baichuan context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
assert "stop" in result
|
||||
|
||||
# Verify Baichuan context is included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_completion_app_chat_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for completion app with chat mode.
|
||||
|
||||
This test verifies:
|
||||
- Correct Baichuan prompt template selection for completion app + chat mode
|
||||
- Proper Baichuan context integration
|
||||
- Template structure validation
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "true")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "chat_prompt_config" in result
|
||||
assert "prompt" in result["chat_prompt_config"]
|
||||
assert len(result["chat_prompt_config"]["prompt"]) > 0
|
||||
assert "role" in result["chat_prompt_config"]["prompt"][0]
|
||||
assert "text" in result["chat_prompt_config"]["prompt"][0]
|
||||
|
||||
# Verify Baichuan context is included
|
||||
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test Baichuan prompt generation without context.
|
||||
|
||||
This test verifies:
|
||||
- Correct handling when has_context is "false"
|
||||
- Baichuan context is not included in prompt
|
||||
- Template structure remains intact
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "completion_prompt_config" in result
|
||||
assert "prompt" in result["completion_prompt_config"]
|
||||
assert "text" in result["completion_prompt_config"]["prompt"]
|
||||
|
||||
# Verify Baichuan context is NOT included
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert BAICHUAN_CONTEXT not in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_unsupported_app_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation with unsupported app mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of unsupported app modes
|
||||
- Default empty dict return
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt("unsupported_mode", "completion", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
|
||||
def test_get_baichuan_prompt_unsupported_model_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation with unsupported model mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of unsupported model modes
|
||||
- Default empty dict return
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "unsupported_mode", "true")
|
||||
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
|
||||
def test_get_prompt_all_app_modes_common_model(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test prompt generation for all app modes with common model.
|
||||
|
||||
This test verifies:
|
||||
- All app modes work correctly with common models
|
||||
- Proper template selection for each combination
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test all app modes
|
||||
app_modes = [AppMode.CHAT, AppMode.COMPLETION]
|
||||
model_modes = ["completion", "chat"]
|
||||
|
||||
for app_mode in app_modes:
|
||||
for model_mode in model_modes:
|
||||
args = {
|
||||
"app_mode": app_mode,
|
||||
"model_mode": model_mode,
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify result is not empty
|
||||
assert result is not None
|
||||
assert result != {}
|
||||
|
||||
def test_get_prompt_all_app_modes_baichuan_model(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test prompt generation for all app modes with Baichuan model.
|
||||
|
||||
This test verifies:
|
||||
- All app modes work correctly with Baichuan models
|
||||
- Proper template selection for each combination
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test all app modes
|
||||
app_modes = [AppMode.CHAT, AppMode.COMPLETION]
|
||||
model_modes = ["completion", "chat"]
|
||||
|
||||
for app_mode in app_modes:
|
||||
for model_mode in model_modes:
|
||||
args = {
|
||||
"app_mode": app_mode,
|
||||
"model_mode": model_mode,
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify result is not empty
|
||||
assert result is not None
|
||||
assert result != {}
|
||||
|
||||
def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test prompt generation with edge cases.
|
||||
|
||||
This test verifies:
|
||||
- Handling of edge case inputs
|
||||
- Proper error handling
|
||||
- Consistent behavior with unusual inputs
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test edge cases
|
||||
edge_cases = [
|
||||
{"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"},
|
||||
{"app_mode": AppMode.CHAT, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
|
||||
{"app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "", "has_context": "true"},
|
||||
{
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "",
|
||||
},
|
||||
]
|
||||
|
||||
for args in edge_cases:
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify method handles edge cases gracefully
|
||||
# Should either return a valid result or empty dict, but not crash
|
||||
assert result is not None
|
||||
|
||||
def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test that original templates are not modified.
|
||||
|
||||
This test verifies:
|
||||
- Original template constants are not modified
|
||||
- Deep copy is used properly
|
||||
- Template immutability is maintained
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Store original templates
|
||||
original_chat_completion = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
original_chat_chat = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
original_completion_completion = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
|
||||
original_completion_chat = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG)
|
||||
|
||||
# Test with context
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify original templates are unchanged
|
||||
assert original_chat_completion == CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert original_chat_chat == CHAT_APP_CHAT_PROMPT_CONFIG
|
||||
assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test that original Baichuan templates are not modified.
|
||||
|
||||
This test verifies:
|
||||
- Original Baichuan template constants are not modified
|
||||
- Deep copy is used properly
|
||||
- Template immutability is maintained
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Store original templates
|
||||
original_baichuan_chat_completion = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
original_baichuan_chat_chat = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
original_baichuan_completion_completion = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
|
||||
original_baichuan_completion_chat = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG)
|
||||
|
||||
# Test with context
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify original templates are unchanged
|
||||
assert original_baichuan_chat_completion == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert original_baichuan_chat_chat == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
|
||||
assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test consistency of context integration across different scenarios.
|
||||
|
||||
This test verifies:
|
||||
- Context is always prepended correctly
|
||||
- Context integration is consistent across different templates
|
||||
- No context duplication or corruption
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test different scenarios
|
||||
test_scenarios = [
|
||||
{
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "chat",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "chat",
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"has_context": "true",
|
||||
},
|
||||
]
|
||||
|
||||
for args in test_scenarios:
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify context integration is consistent
|
||||
assert result is not None
|
||||
assert result != {}
|
||||
|
||||
# Check that context is properly integrated
|
||||
if "completion_prompt_config" in result:
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert prompt_text.startswith(CONTEXT)
|
||||
elif "chat_prompt_config" in result:
|
||||
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert prompt_text.startswith(CONTEXT)
|
||||
|
||||
def test_baichuan_context_integration_consistency(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test consistency of Baichuan context integration across different scenarios.
|
||||
|
||||
This test verifies:
|
||||
- Baichuan context is always prepended correctly
|
||||
- Context integration is consistent across different templates
|
||||
- No context duplication or corruption
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Test different scenarios
|
||||
test_scenarios = [
|
||||
{
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "chat",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "completion",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
},
|
||||
{
|
||||
"app_mode": AppMode.COMPLETION,
|
||||
"model_mode": "chat",
|
||||
"model_name": "baichuan-13b-chat",
|
||||
"has_context": "true",
|
||||
},
|
||||
]
|
||||
|
||||
for args in test_scenarios:
|
||||
# Act: Execute the method under test
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert: Verify context integration is consistent
|
||||
assert result is not None
|
||||
assert result != {}
|
||||
|
||||
# Check that Baichuan context is properly integrated
|
||||
if "completion_prompt_config" in result:
|
||||
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
|
||||
assert prompt_text.startswith(BAICHUAN_CONTEXT)
|
||||
elif "chat_prompt_config" in result:
|
||||
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
|
||||
assert prompt_text.startswith(BAICHUAN_CONTEXT)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,505 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
|
||||
|
||||
class TestAPIBasedExtensionService:
|
||||
"""Integration tests for APIBasedExtensionService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch("services.api_based_extension_service.APIBasedExtensionRequestor") as mock_requestor,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_account_feature_service.get_features.return_value.billing.enabled = False
|
||||
|
||||
# Mock successful ping response
|
||||
mock_requestor_instance = mock_requestor.return_value
|
||||
mock_requestor_instance.request.return_value = {"result": "pong"}
|
||||
|
||||
yield {
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"requestor": mock_requestor,
|
||||
"requestor_instance": mock_requestor_instance,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful saving of API-based extension.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
# Save extension
|
||||
saved_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Verify extension was saved correctly
|
||||
assert saved_extension.id is not None
|
||||
assert saved_extension.tenant_id == tenant.id
|
||||
assert saved_extension.name == extension_data.name
|
||||
assert saved_extension.api_endpoint == extension_data.api_endpoint
|
||||
assert saved_extension.api_key == extension_data.api_key # Should be decrypted when retrieved
|
||||
assert saved_extension.created_at is not None
|
||||
|
||||
# Verify extension was saved to database
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(saved_extension)
|
||||
assert saved_extension.id is not None
|
||||
|
||||
# Verify ping connection was called
|
||||
mock_external_service_dependencies["requestor_instance"].request.assert_called_once()
|
||||
|
||||
def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test validation errors when saving extension with invalid data.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
# Test empty name
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="",
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="name must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Test empty api_endpoint
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = ""
|
||||
|
||||
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Test empty api_key
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = ""
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of all extensions by tenant ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create multiple extensions
|
||||
extensions = []
|
||||
assert tenant is not None
|
||||
for i in range(3):
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=f"Extension {i}: {fake.company()}",
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
saved_extension = APIBasedExtensionService.save(extension_data)
|
||||
extensions.append(saved_extension)
|
||||
|
||||
# Get all extensions for tenant
|
||||
extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
|
||||
|
||||
# Verify results
|
||||
assert len(extension_list) == 3
|
||||
|
||||
# Verify all extensions belong to the correct tenant and are ordered by created_at desc
|
||||
for i, extension in enumerate(extension_list):
|
||||
assert extension.tenant_id == tenant.id
|
||||
assert extension.api_key is not None # Should be decrypted
|
||||
if i > 0:
|
||||
# Verify descending order (newer first)
|
||||
assert extension.created_at <= extension_list[i - 1].created_at
|
||||
|
||||
def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of extension by tenant ID and extension ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
# Create an extension
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Get extension by ID
|
||||
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
|
||||
|
||||
# Verify extension was retrieved correctly
|
||||
assert retrieved_extension is not None
|
||||
assert retrieved_extension.id == created_extension.id
|
||||
assert retrieved_extension.tenant_id == tenant.id
|
||||
assert retrieved_extension.name == extension_data.name
|
||||
assert retrieved_extension.api_endpoint == extension_data.api_endpoint
|
||||
assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted
|
||||
assert retrieved_extension.created_at is not None
|
||||
|
||||
def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test retrieval of extension when extension is not found.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
non_existent_extension_id = fake.uuid4()
|
||||
|
||||
# Try to get non-existent extension
|
||||
with pytest.raises(ValueError, match="API based extension is not found"):
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id)
|
||||
|
||||
def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of extension.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
# Create an extension first
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
extension_id = created_extension.id
|
||||
|
||||
# Delete the extension
|
||||
APIBasedExtensionService.delete(created_extension)
|
||||
|
||||
# Verify extension was deleted
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first()
|
||||
assert deleted_extension is None
|
||||
|
||||
def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test validation error when saving extension with duplicate name.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
# Create first extension
|
||||
extension_data1 = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="Test Extension",
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
APIBasedExtensionService.save(extension_data1)
|
||||
# Try to create second extension with same name
|
||||
extension_data2 = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="Test Extension", # Same name
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
|
||||
APIBasedExtensionService.save(extension_data2)
|
||||
|
||||
def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful update of existing extension.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
# Create initial extension
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Save original values for later comparison
|
||||
original_name = created_extension.name
|
||||
original_endpoint = created_extension.api_endpoint
|
||||
|
||||
# Update the extension with guaranteed different values
|
||||
new_name = fake.company()
|
||||
# Ensure new endpoint is different from original
|
||||
new_endpoint = f"https://{fake.domain_name()}/api"
|
||||
# If by chance they're the same, generate a new one
|
||||
while new_endpoint == original_endpoint:
|
||||
new_endpoint = f"https://{fake.domain_name()}/api"
|
||||
new_api_key = fake.password(length=20)
|
||||
|
||||
created_extension.name = new_name
|
||||
created_extension.api_endpoint = new_endpoint
|
||||
created_extension.api_key = new_api_key
|
||||
|
||||
updated_extension = APIBasedExtensionService.save(created_extension)
|
||||
|
||||
# Verify extension was updated correctly
|
||||
assert updated_extension.id == created_extension.id
|
||||
assert updated_extension.tenant_id == tenant.id
|
||||
assert updated_extension.name == new_name
|
||||
assert updated_extension.api_endpoint == new_endpoint
|
||||
|
||||
# Verify original values were changed
|
||||
assert updated_extension.name != original_name
|
||||
assert updated_extension.api_endpoint != original_endpoint
|
||||
|
||||
# Verify ping connection was called for both create and update
|
||||
assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2
|
||||
|
||||
# Verify the update by retrieving the extension again
|
||||
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
|
||||
assert retrieved_extension.name == new_name
|
||||
assert retrieved_extension.api_endpoint == new_endpoint
|
||||
assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved
|
||||
|
||||
def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test connection error when saving extension with invalid endpoint.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Mock connection error
|
||||
mock_external_service_dependencies["requestor_instance"].request.side_effect = ValueError(
|
||||
"connection error: request timeout"
|
||||
)
|
||||
assert tenant is not None
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint="https://invalid-endpoint.com/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
# Try to save extension with connection error
|
||||
with pytest.raises(ValueError, match="connection error: request timeout"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_save_extension_invalid_api_key_length(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test validation error when saving extension with API key that is too short.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
# Setup extension data with short API key
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key="1234", # Less than 5 characters
|
||||
)
|
||||
|
||||
# Try to save extension with short API key
|
||||
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test validation errors when saving extension with empty required fields.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant is not None
|
||||
# Test with None values
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=None, # type: ignore # why str become None here???
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="name must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Test with None api_endpoint
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = None
|
||||
|
||||
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Test with None api_key
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = None
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test retrieval of extensions when no extensions exist for tenant.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Get all extensions for tenant (none exist)
|
||||
extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
|
||||
|
||||
# Verify empty list is returned
|
||||
assert len(extension_list) == 0
|
||||
assert extension_list == []
|
||||
|
||||
def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test validation error when ping response is invalid.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Mock invalid ping response
|
||||
mock_external_service_dependencies["requestor_instance"].request.return_value = {"result": "invalid"}
|
||||
assert tenant is not None
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
# Try to save extension with invalid ping response
|
||||
with pytest.raises(ValueError, match="{'result': 'invalid'}"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test validation error when ping response is missing result field.
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Mock ping response without result field
|
||||
mock_external_service_dependencies["requestor_instance"].request.return_value = {"status": "ok"}
|
||||
assert tenant is not None
|
||||
# Setup extension data
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
# Try to save extension with missing ping result
|
||||
with pytest.raises(ValueError, match="{'status': 'ok'}"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test retrieval of extension when tenant ID doesn't match.
|
||||
"""
|
||||
fake = Faker()
|
||||
account1, tenant1 = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create second account and tenant
|
||||
account2, tenant2 = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
assert tenant1 is not None
|
||||
# Create extension in first tenant
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=tenant1.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
|
||||
# Try to get extension with wrong tenant ID
|
||||
with pytest.raises(ValueError, match="API based extension is not found"):
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)
|
||||
@@ -0,0 +1,432 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from faker import Faker
|
||||
|
||||
from models.model import App, AppModelConfig
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class TestAppDslService:
|
||||
"""Integration tests for AppDslService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.app_dsl_service.WorkflowService") as mock_workflow_service,
|
||||
patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service,
|
||||
patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service,
|
||||
patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy,
|
||||
patch("services.app_dsl_service.redis_client") as mock_redis_client,
|
||||
patch("services.app_dsl_service.app_was_created") as mock_app_was_created,
|
||||
patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated,
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.app_service.FeatureService") as mock_feature_service,
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_workflow_service.return_value.get_draft_workflow.return_value = None
|
||||
mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock()
|
||||
mock_dependencies_service.generate_latest_dependencies.return_value = []
|
||||
mock_dependencies_service.get_leaked_dependencies.return_value = []
|
||||
mock_dependencies_service.generate_dependencies.return_value = []
|
||||
mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None
|
||||
mock_ssrf_proxy.get.return_value.content = b"test content"
|
||||
mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None
|
||||
mock_redis_client.setex.return_value = None
|
||||
mock_redis_client.get.return_value = None
|
||||
mock_redis_client.delete.return_value = None
|
||||
mock_app_was_created.send.return_value = None
|
||||
mock_app_model_config_was_updated.send.return_value = None
|
||||
|
||||
# Mock ModelManager for app service
|
||||
mock_model_instance = mock_model_manager.return_value
|
||||
mock_model_instance.get_default_model_instance.return_value = None
|
||||
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||
|
||||
# Mock FeatureService and EnterpriseService
|
||||
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
|
||||
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
|
||||
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
|
||||
|
||||
yield {
|
||||
"workflow_service": mock_workflow_service,
|
||||
"dependencies_service": mock_dependencies_service,
|
||||
"draft_variable_service": mock_draft_variable_service,
|
||||
"ssrf_proxy": mock_ssrf_proxy,
|
||||
"redis_client": mock_redis_client,
|
||||
"app_was_created": mock_app_was_created,
|
||||
"app_model_config_was_updated": mock_app_model_config_was_updated,
|
||||
"model_manager": mock_model_manager,
|
||||
"feature_service": mock_feature_service,
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (app, account) - Created app and account instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
with patch("services.account_service.FeatureService") as mock_account_feature_service:
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Setup app creation arguments
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Create app
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"):
|
||||
"""
|
||||
Helper method to create simple YAML content for testing.
|
||||
"""
|
||||
yaml_data = {
|
||||
"version": "0.3.0",
|
||||
"kind": "app",
|
||||
"app": {
|
||||
"name": app_name,
|
||||
"mode": app_mode,
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FFEAD5",
|
||||
"description": "Test app description",
|
||||
"use_icon_as_answer_icon": False,
|
||||
},
|
||||
"model_config": {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
},
|
||||
"pre_prompt": "You are a helpful assistant.",
|
||||
"prompt_type": "simple",
|
||||
},
|
||||
}
|
||||
return yaml.dump(yaml_data, allow_unicode=True)
|
||||
|
||||
def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app import with missing YAML content.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Import app without YAML content
|
||||
dsl_service = AppDslService(db_session_with_containers)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
name="Missing Content App",
|
||||
)
|
||||
|
||||
# Verify import failed
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.app_id is None
|
||||
assert "yaml_content is required" in result.error
|
||||
assert result.imported_dsl_version == ""
|
||||
|
||||
# Verify no app was created in database
|
||||
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
|
||||
assert apps_count == 1 # Only the original test app
|
||||
|
||||
def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app import with missing YAML URL.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Import app without YAML URL
|
||||
dsl_service = AppDslService(db_session_with_containers)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_URL,
|
||||
name="Missing URL App",
|
||||
)
|
||||
|
||||
# Verify import failed
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.app_id is None
|
||||
assert "yaml_url is required" in result.error
|
||||
assert result.imported_dsl_version == ""
|
||||
|
||||
# Verify no app was created in database
|
||||
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
|
||||
assert apps_count == 1 # Only the original test app
|
||||
|
||||
def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app import with invalid import mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create YAML content
|
||||
yaml_content = self._create_simple_yaml_content(fake.company(), "chat")
|
||||
|
||||
# Import app with invalid mode should raise ValueError
|
||||
dsl_service = AppDslService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"):
|
||||
dsl_service.import_app(
|
||||
account=account,
|
||||
import_mode="invalid-mode",
|
||||
yaml_content=yaml_content,
|
||||
name="Invalid Mode App",
|
||||
)
|
||||
|
||||
# Verify no app was created in database
|
||||
apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count()
|
||||
assert apps_count == 1 # Only the original test app
|
||||
|
||||
def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful DSL export for chat app.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create model config for the app
|
||||
model_config = AppModelConfig()
|
||||
model_config.id = fake.uuid4()
|
||||
model_config.app_id = app.id
|
||||
model_config.provider = "openai"
|
||||
model_config.model_id = "gpt-3.5-turbo"
|
||||
model_config.model = json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}
|
||||
)
|
||||
model_config.pre_prompt = "You are a helpful assistant."
|
||||
model_config.prompt_type = "simple"
|
||||
model_config.created_by = account.id
|
||||
model_config.updated_by = account.id
|
||||
|
||||
# Set the app_model_config_id to link the config
|
||||
app.app_model_config_id = model_config.id
|
||||
|
||||
db_session_with_containers.add(model_config)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Export DSL
|
||||
exported_dsl = AppDslService.export_dsl(app, include_secret=False)
|
||||
|
||||
# Parse exported YAML
|
||||
exported_data = yaml.safe_load(exported_dsl)
|
||||
|
||||
# Verify exported data structure
|
||||
assert exported_data["kind"] == "app"
|
||||
assert exported_data["app"]["name"] == app.name
|
||||
assert exported_data["app"]["mode"] == app.mode
|
||||
assert exported_data["app"]["icon"] == app.icon
|
||||
assert exported_data["app"]["icon_background"] == app.icon_background
|
||||
assert exported_data["app"]["description"] == app.description
|
||||
|
||||
# Verify model config was exported
|
||||
assert "model_config" in exported_data
|
||||
# The exported model_config structure may be different from the database structure
|
||||
# Check that the model config exists and has the expected content
|
||||
assert exported_data["model_config"] is not None
|
||||
|
||||
# Verify dependencies were exported
|
||||
assert "dependencies" in exported_data
|
||||
assert isinstance(exported_data["dependencies"], list)
|
||||
|
||||
def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful DSL export for workflow app.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Update app to workflow mode
|
||||
app.mode = "workflow"
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock workflow service to return a workflow
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.to_dict.return_value = {
|
||||
"graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []},
|
||||
"features": {},
|
||||
"environment_variables": [],
|
||||
"conversation_variables": [],
|
||||
}
|
||||
mock_external_service_dependencies[
|
||||
"workflow_service"
|
||||
].return_value.get_draft_workflow.return_value = mock_workflow
|
||||
|
||||
# Export DSL
|
||||
exported_dsl = AppDslService.export_dsl(app, include_secret=False)
|
||||
|
||||
# Parse exported YAML
|
||||
exported_data = yaml.safe_load(exported_dsl)
|
||||
|
||||
# Verify exported data structure
|
||||
assert exported_data["kind"] == "app"
|
||||
assert exported_data["app"]["name"] == app.name
|
||||
assert exported_data["app"]["mode"] == "workflow"
|
||||
|
||||
# Verify workflow was exported
|
||||
assert "workflow" in exported_data
|
||||
assert "graph" in exported_data["workflow"]
|
||||
assert "nodes" in exported_data["workflow"]["graph"]
|
||||
|
||||
# Verify dependencies were exported
|
||||
assert "dependencies" in exported_data
|
||||
assert isinstance(exported_data["dependencies"], list)
|
||||
|
||||
# Verify workflow service was called
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
|
||||
app, None
|
||||
)
|
||||
|
||||
def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful DSL export with specific workflow ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Update app to workflow mode
|
||||
app.mode = "workflow"
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock workflow service to return a workflow when specific workflow_id is provided
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.to_dict.return_value = {
|
||||
"graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []},
|
||||
"features": {},
|
||||
"environment_variables": [],
|
||||
"conversation_variables": [],
|
||||
}
|
||||
|
||||
# Mock the get_draft_workflow method to return different workflows based on workflow_id
|
||||
def mock_get_draft_workflow(app_model, workflow_id=None):
|
||||
if workflow_id == "specific-workflow-id":
|
||||
return mock_workflow
|
||||
return None
|
||||
|
||||
mock_external_service_dependencies[
|
||||
"workflow_service"
|
||||
].return_value.get_draft_workflow.side_effect = mock_get_draft_workflow
|
||||
|
||||
# Export DSL with specific workflow ID
|
||||
exported_dsl = AppDslService.export_dsl(app, include_secret=False, workflow_id="specific-workflow-id")
|
||||
|
||||
# Parse exported YAML
|
||||
exported_data = yaml.safe_load(exported_dsl)
|
||||
|
||||
# Verify exported data structure
|
||||
assert exported_data["kind"] == "app"
|
||||
assert exported_data["app"]["name"] == app.name
|
||||
assert exported_data["app"]["mode"] == "workflow"
|
||||
|
||||
# Verify workflow was exported
|
||||
assert "workflow" in exported_data
|
||||
assert "graph" in exported_data["workflow"]
|
||||
assert "nodes" in exported_data["workflow"]["graph"]
|
||||
|
||||
# Verify dependencies were exported
|
||||
assert "dependencies" in exported_data
|
||||
assert isinstance(exported_data["dependencies"], list)
|
||||
|
||||
# Verify workflow service was called with specific workflow ID
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
|
||||
app, "specific-workflow-id"
|
||||
)
|
||||
|
||||
def test_export_dsl_with_invalid_workflow_id_raises_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that export_dsl raises error when invalid workflow ID is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Update app to workflow mode
|
||||
app.mode = "workflow"
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock workflow service to return None when invalid workflow ID is provided
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.return_value = None
|
||||
|
||||
# Export DSL with invalid workflow ID should raise ValueError
|
||||
with pytest.raises(ValueError, match="Missing draft workflow configuration, please check."):
|
||||
AppDslService.export_dsl(app, include_secret=False, workflow_id="invalid-workflow-id")
|
||||
|
||||
# Verify workflow service was called with the invalid workflow ID
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
|
||||
app, "invalid-workflow-id"
|
||||
)
|
||||
|
||||
def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful dependency checking.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Mock Redis to return dependencies
|
||||
mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}'
|
||||
mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json
|
||||
|
||||
# Check dependencies
|
||||
dsl_service = AppDslService(db_session_with_containers)
|
||||
result = dsl_service.check_dependencies(app_model=app)
|
||||
|
||||
# Verify result
|
||||
assert result.leaked_dependencies == []
|
||||
|
||||
# Verify Redis was queried
|
||||
mock_external_service_dependencies["redis_client"].get.assert_called_once_with(
|
||||
f"app_check_dependencies:{app.id}"
|
||||
)
|
||||
|
||||
# Verify dependencies service was called
|
||||
mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once()
|
||||
@@ -0,0 +1,982 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
|
||||
|
||||
|
||||
class TestAppGenerateService:
|
||||
"""Integration tests for AppGenerateService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.billing_service.BillingService") as mock_billing_service,
|
||||
patch("services.app_generate_service.WorkflowService") as mock_workflow_service,
|
||||
patch("services.app_generate_service.RateLimit") as mock_rate_limit,
|
||||
patch("services.app_generate_service.CompletionAppGenerator") as mock_completion_generator,
|
||||
patch("services.app_generate_service.ChatAppGenerator") as mock_chat_generator,
|
||||
patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator,
|
||||
patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator,
|
||||
patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch("services.app_generate_service.dify_config") as mock_dify_config,
|
||||
patch("configs.dify_config") as mock_global_dify_config,
|
||||
):
|
||||
# Setup default mock returns for billing service
|
||||
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
|
||||
"result": "success",
|
||||
"history_id": "test_history_id",
|
||||
}
|
||||
|
||||
# Setup default mock returns for workflow service
|
||||
mock_workflow_service_instance = mock_workflow_service.return_value
|
||||
mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow)
|
||||
mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow)
|
||||
mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow)
|
||||
|
||||
# Setup default mock returns for rate limiting
|
||||
mock_rate_limit_instance = mock_rate_limit.return_value
|
||||
mock_rate_limit_instance.enter.return_value = "test_request_id"
|
||||
mock_rate_limit_instance.generate.return_value = ["test_response"]
|
||||
mock_rate_limit_instance.exit.return_value = None
|
||||
|
||||
# Setup default mock returns for app generators
|
||||
mock_completion_generator_instance = mock_completion_generator.return_value
|
||||
mock_completion_generator_instance.generate.return_value = ["completion_response"]
|
||||
mock_completion_generator_instance.generate_more_like_this.return_value = ["more_like_this_response"]
|
||||
mock_completion_generator.convert_to_event_stream.return_value = ["completion_stream"]
|
||||
|
||||
mock_chat_generator_instance = mock_chat_generator.return_value
|
||||
mock_chat_generator_instance.generate.return_value = ["chat_response"]
|
||||
mock_chat_generator.convert_to_event_stream.return_value = ["chat_stream"]
|
||||
|
||||
mock_agent_chat_generator_instance = mock_agent_chat_generator.return_value
|
||||
mock_agent_chat_generator_instance.generate.return_value = ["agent_chat_response"]
|
||||
mock_agent_chat_generator.convert_to_event_stream.return_value = ["agent_chat_stream"]
|
||||
|
||||
mock_advanced_chat_generator_instance = mock_advanced_chat_generator.return_value
|
||||
mock_advanced_chat_generator_instance.generate.return_value = ["advanced_chat_response"]
|
||||
mock_advanced_chat_generator_instance.single_iteration_generate.return_value = ["single_iteration_response"]
|
||||
mock_advanced_chat_generator_instance.single_loop_generate.return_value = ["single_loop_response"]
|
||||
mock_advanced_chat_generator.convert_to_event_stream.return_value = ["advanced_chat_stream"]
|
||||
|
||||
mock_workflow_generator_instance = mock_workflow_generator.return_value
|
||||
mock_workflow_generator_instance.generate.return_value = ["workflow_response"]
|
||||
mock_workflow_generator_instance.single_iteration_generate.return_value = [
|
||||
"workflow_single_iteration_response"
|
||||
]
|
||||
mock_workflow_generator_instance.single_loop_generate.return_value = ["workflow_single_loop_response"]
|
||||
mock_workflow_generator.convert_to_event_stream.return_value = ["workflow_stream"]
|
||||
|
||||
# Setup default mock returns for account service
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Setup dify_config mock returns
|
||||
mock_dify_config.BILLING_ENABLED = False
|
||||
mock_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
|
||||
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
|
||||
|
||||
mock_global_dify_config.BILLING_ENABLED = False
|
||||
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
|
||||
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
|
||||
|
||||
yield {
|
||||
"billing_service": mock_billing_service,
|
||||
"workflow_service": mock_workflow_service,
|
||||
"rate_limit": mock_rate_limit,
|
||||
"completion_generator": mock_completion_generator,
|
||||
"chat_generator": mock_chat_generator,
|
||||
"agent_chat_generator": mock_agent_chat_generator,
|
||||
"advanced_chat_generator": mock_advanced_chat_generator,
|
||||
"workflow_generator": mock_workflow_generator,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"dify_config": mock_dify_config,
|
||||
"global_dify_config": mock_global_dify_config,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies, mode="chat"):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
mode: App mode to create
|
||||
|
||||
Returns:
|
||||
tuple: (app, account) - Created app and account instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": mode,
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
"max_active_requests": 5,
|
||||
}
|
||||
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_workflow(self, db_session_with_containers, app):
|
||||
"""
|
||||
Helper method to create a test workflow for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: App instance
|
||||
|
||||
Returns:
|
||||
Workflow: Created workflow instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
app_id=app.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
type="workflow",
|
||||
status="published",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
def test_generate_completion_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful generation for completion mode app.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify rate limiting was called
|
||||
mock_external_service_dependencies["rate_limit"].return_value.enter.assert_called_once()
|
||||
mock_external_service_dependencies["rate_limit"].return_value.generate.assert_called_once()
|
||||
|
||||
# Verify completion generator was called
|
||||
mock_external_service_dependencies["completion_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["completion_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful generation for chat mode app.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="chat"
|
||||
)
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify chat generator was called
|
||||
mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_agent_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful generation for agent chat mode app.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="agent-chat"
|
||||
)
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify agent chat generator was called
|
||||
mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_advanced_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful generation for advanced chat mode app.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
|
||||
)
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify advanced chat generator was called
|
||||
mock_external_service_dependencies["advanced_chat_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["advanced_chat_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful generation for workflow mode app.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="workflow"
|
||||
)
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify workflow generator was called
|
||||
mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test generation with a specific workflow ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
|
||||
)
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
|
||||
# Setup test arguments
|
||||
args = {
|
||||
"inputs": {"query": fake.text(max_nb_chars=50)},
|
||||
"workflow_id": workflow_id,
|
||||
"response_mode": "streaming",
|
||||
}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify workflow service was called with specific workflow ID
|
||||
mock_external_service_dependencies[
|
||||
"workflow_service"
|
||||
].return_value.get_published_workflow_by_id.assert_called_once()
|
||||
|
||||
def test_generate_with_debugger_invoke_from(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test generation with debugger invoke from.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
|
||||
)
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify draft workflow was fetched for debugger
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once()
|
||||
|
||||
def test_generate_with_non_streaming_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test generation with non-streaming mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "blocking"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=False
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify rate limit exit was called for non-streaming mode
|
||||
mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once()
|
||||
|
||||
def test_generate_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test generation with EndUser instead of Account.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
# Create end user
|
||||
end_user = EndUser(
|
||||
tenant_id=account.current_tenant.id,
|
||||
app_id=app.id,
|
||||
type="normal",
|
||||
external_user_id=fake.uuid4(),
|
||||
name=fake.name(),
|
||||
is_anonymous=False,
|
||||
session_id=fake.uuid4(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
def test_generate_with_billing_enabled_sandbox_plan(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with billing enabled and sandbox plan.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
# Set BILLING_ENABLED to True for this test
|
||||
mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
|
||||
mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify billing service was called to consume quota
|
||||
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
|
||||
|
||||
def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test generation with invalid app mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="chat"
|
||||
)
|
||||
|
||||
# Manually set invalid mode after creation
|
||||
app.mode = "invalid_mode"
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test and expect ValueError
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Invalid app mode" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_workflow_id_format_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with invalid workflow ID format.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
|
||||
)
|
||||
|
||||
# Setup test arguments with invalid workflow ID
|
||||
args = {
|
||||
"inputs": {"query": fake.text(max_nb_chars=50)},
|
||||
"workflow_id": "invalid_uuid",
|
||||
"response_mode": "streaming",
|
||||
}
|
||||
|
||||
# Execute the method under test and expect WorkflowIdFormatError
|
||||
with pytest.raises(WorkflowIdFormatError) as exc_info:
|
||||
AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Invalid workflow_id format" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_workflow_not_found_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation when workflow is not found.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
|
||||
)
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
|
||||
# Setup workflow service to return None (workflow not found)
|
||||
mock_external_service_dependencies[
|
||||
"workflow_service"
|
||||
].return_value.get_published_workflow_by_id.return_value = None
|
||||
|
||||
# Setup test arguments
|
||||
args = {
|
||||
"inputs": {"query": fake.text(max_nb_chars=50)},
|
||||
"workflow_id": workflow_id,
|
||||
"response_mode": "streaming",
|
||||
}
|
||||
|
||||
# Execute the method under test and expect WorkflowNotFoundError
|
||||
with pytest.raises(WorkflowNotFoundError) as exc_info:
|
||||
AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_workflow_not_initialized_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation when workflow is not initialized for debugger.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
|
||||
)
|
||||
|
||||
# Setup workflow service to return None (workflow not initialized)
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.return_value = None
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test and expect ValueError
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Workflow not initialized" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_workflow_not_published_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation when workflow is not published for non-debugger.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
|
||||
)
|
||||
|
||||
# Setup workflow service to return None (workflow not published)
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test and expect ValueError
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Workflow not published" in str(exc_info.value)
|
||||
|
||||
def test_generate_single_iteration_advanced_chat_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful single iteration generation for advanced chat mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
|
||||
)
|
||||
|
||||
node_id = fake.uuid4()
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate_single_iteration(
|
||||
app_model=app, user=account, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["advanced_chat_stream"]
|
||||
|
||||
# Verify advanced chat generator was called
|
||||
mock_external_service_dependencies[
|
||||
"advanced_chat_generator"
|
||||
].return_value.single_iteration_generate.assert_called_once()
|
||||
|
||||
def test_generate_single_iteration_workflow_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful single iteration generation for workflow mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="workflow"
|
||||
)
|
||||
|
||||
node_id = fake.uuid4()
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate_single_iteration(
|
||||
app_model=app, user=account, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["advanced_chat_stream"]
|
||||
|
||||
# Verify workflow generator was called
|
||||
mock_external_service_dependencies[
|
||||
"workflow_generator"
|
||||
].return_value.single_iteration_generate.assert_called_once()
|
||||
|
||||
def test_generate_single_iteration_invalid_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test single iteration generation with invalid app mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
node_id = fake.uuid4()
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}}
|
||||
|
||||
# Execute the method under test and expect ValueError
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AppGenerateService.generate_single_iteration(
|
||||
app_model=app, user=account, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Invalid app mode" in str(exc_info.value)
|
||||
|
||||
def test_generate_single_loop_advanced_chat_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful single loop generation for advanced chat mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
|
||||
)
|
||||
|
||||
node_id = fake.uuid4()
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate_single_loop(
|
||||
app_model=app, user=account, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["advanced_chat_stream"]
|
||||
|
||||
# Verify advanced chat generator was called
|
||||
mock_external_service_dependencies[
|
||||
"advanced_chat_generator"
|
||||
].return_value.single_loop_generate.assert_called_once()
|
||||
|
||||
def test_generate_single_loop_workflow_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful single loop generation for workflow mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="workflow"
|
||||
)
|
||||
|
||||
node_id = fake.uuid4()
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate_single_loop(
|
||||
app_model=app, user=account, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["advanced_chat_stream"]
|
||||
|
||||
# Verify workflow generator was called
|
||||
mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once()
|
||||
|
||||
def test_generate_single_loop_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test single loop generation with invalid app mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
node_id = fake.uuid4()
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}}
|
||||
|
||||
# Execute the method under test and expect ValueError
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AppGenerateService.generate_single_loop(
|
||||
app_model=app, user=account, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Invalid app mode" in str(exc_info.value)
|
||||
|
||||
def test_generate_more_like_this_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful more like this generation.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
message_id = fake.uuid4()
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate_more_like_this(
|
||||
app_model=app, user=account, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["more_like_this_response"]
|
||||
|
||||
# Verify completion generator was called
|
||||
mock_external_service_dependencies[
|
||||
"completion_generator"
|
||||
].return_value.generate_more_like_this.assert_called_once()
|
||||
|
||||
def test_generate_more_like_this_with_end_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test more like this generation with EndUser.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
# Create end user
|
||||
end_user = EndUser(
|
||||
tenant_id=account.current_tenant.id,
|
||||
app_id=app.id,
|
||||
type="normal",
|
||||
external_user_id=fake.uuid4(),
|
||||
name=fake.name(),
|
||||
is_anonymous=False,
|
||||
session_id=fake.uuid4(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
message_id = fake.uuid4()
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate_more_like_this(
|
||||
app_model=app, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["more_like_this_response"]
|
||||
|
||||
def test_get_max_active_requests_with_app_limit(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting max active requests with app-specific limit.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
# Set app-specific limit
|
||||
app.max_active_requests = 10
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService._get_max_active_requests(app)
|
||||
|
||||
# Verify the result (should return the smaller value between app limit and config limit)
|
||||
assert result == 10
|
||||
|
||||
def test_get_max_active_requests_with_config_limit(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting max active requests with config limit being smaller.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
# Set app-specific limit higher than config
|
||||
app.max_active_requests = 100
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService._get_max_active_requests(app)
|
||||
|
||||
# Verify the result (should return the smaller value)
|
||||
# Assuming config limit is smaller than 100
|
||||
assert result <= 100
|
||||
|
||||
def test_get_max_active_requests_with_zero_limits(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting max active requests with zero limits (infinite).
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
# Set app-specific limit to 0 (infinite)
|
||||
app.max_active_requests = 0
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService._get_max_active_requests(app)
|
||||
|
||||
# Verify the result (should return config limit when app limit is 0)
|
||||
assert result == 100 # dify_config.APP_MAX_ACTIVE_REQUESTS
|
||||
|
||||
def test_generate_with_exception_cleanup(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test that rate limit exit is called when an exception occurs.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="completion"
|
||||
)
|
||||
|
||||
# Setup completion generator to raise an exception
|
||||
mock_external_service_dependencies["completion_generator"].return_value.generate.side_effect = Exception(
|
||||
"Test exception"
|
||||
)
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test and expect exception
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify exception message
|
||||
assert "Test exception" in str(exc_info.value)
|
||||
|
||||
# Verify rate limit exit was called for cleanup
|
||||
mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once()
|
||||
|
||||
def test_generate_with_agent_mode_detection(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test generation with agent mode detection based on app configuration.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="chat"
|
||||
)
|
||||
|
||||
# Mock app to have agent mode enabled by setting the mode directly
|
||||
app.mode = "agent-chat"
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify agent chat generator was called instead of regular chat generator
|
||||
mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_with_different_invoke_from_values(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with different invoke from values.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat"
|
||||
)
|
||||
|
||||
# Test different invoke from values
|
||||
invoke_from_values = [
|
||||
InvokeFrom.SERVICE_API,
|
||||
InvokeFrom.WEB_APP,
|
||||
InvokeFrom.EXPLORE,
|
||||
InvokeFrom.DEBUGGER,
|
||||
]
|
||||
|
||||
for invoke_from in invoke_from_values:
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=invoke_from, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
def test_generate_with_complex_args(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test generation with complex arguments including files and external trace ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies, mode="workflow"
|
||||
)
|
||||
|
||||
# Setup complex test arguments
|
||||
args = {
|
||||
"inputs": {
|
||||
"query": fake.text(max_nb_chars=50),
|
||||
"context": fake.text(max_nb_chars=100),
|
||||
"parameters": {"temperature": 0.7, "max_tokens": 1000},
|
||||
},
|
||||
"files": [
|
||||
{"id": fake.uuid4(), "name": "test_file.txt", "size": 1024},
|
||||
{"id": fake.uuid4(), "name": "test_image.jpg", "size": 2048},
|
||||
],
|
||||
"external_trace_id": fake.uuid4(),
|
||||
"response_mode": "streaming",
|
||||
}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify workflow generator was called with complex args
|
||||
mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once()
|
||||
call_args = mock_external_service_dependencies["workflow_generator"].return_value.generate.call_args
|
||||
assert call_args[1]["args"] == args
|
||||
@@ -0,0 +1,954 @@
|
||||
from unittest.mock import create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from constants.model_template import default_app_templates
|
||||
from models import Account
|
||||
from models.model import App, Site
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class TestAppService:
|
||||
"""Integration tests for AppService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.app_service.FeatureService") as mock_feature_service,
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
):
|
||||
# Setup default mock returns for app service
|
||||
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
|
||||
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
|
||||
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
|
||||
|
||||
# Setup default mock returns for account service
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Mock ModelManager for model configuration
|
||||
mock_model_instance = mock_model_manager.return_value
|
||||
mock_model_instance.get_default_model_instance.return_value = None
|
||||
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||
|
||||
yield {
|
||||
"feature_service": mock_feature_service,
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app creation with basic parameters.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Setup app creation arguments
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Create app
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Verify app was created correctly
|
||||
assert app.name == app_args["name"]
|
||||
assert app.description == app_args["description"]
|
||||
assert app.mode == app_args["mode"]
|
||||
assert app.icon_type == app_args["icon_type"]
|
||||
assert app.icon == app_args["icon"]
|
||||
assert app.icon_background == app_args["icon_background"]
|
||||
assert app.tenant_id == tenant.id
|
||||
assert app.api_rph == app_args["api_rph"]
|
||||
assert app.api_rpm == app_args["api_rpm"]
|
||||
assert app.created_by == account.id
|
||||
assert app.updated_by == account.id
|
||||
assert app.status == "normal"
|
||||
assert app.enable_site is True
|
||||
assert app.enable_api is True
|
||||
assert app.is_demo is False
|
||||
assert app.is_public is False
|
||||
assert app.is_universal is False
|
||||
|
||||
def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app creation with different app modes.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Test different app modes
|
||||
# from AppMode enum in default_app_model_template
|
||||
app_modes = [v.value for v in default_app_templates]
|
||||
|
||||
for mode in app_modes:
|
||||
app_args = {
|
||||
"name": f"{fake.company()} {mode}",
|
||||
"description": f"Test app for {mode} mode",
|
||||
"mode": mode,
|
||||
"icon_type": "emoji",
|
||||
"icon": "🚀",
|
||||
"icon_background": "#4ECDC4",
|
||||
}
|
||||
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Verify app mode was set correctly
|
||||
assert app.mode == mode
|
||||
assert app.name == app_args["name"]
|
||||
assert app.tenant_id == tenant.id
|
||||
assert app.created_by == account.id
|
||||
|
||||
def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app retrieval.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🎯",
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
created_app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Get app using the service - needs current_user mock
|
||||
mock_current_user = create_autospec(Account, instance=True)
|
||||
mock_current_user.id = account.id
|
||||
mock_current_user.current_tenant_id = account.current_tenant_id
|
||||
|
||||
with patch("services.app_service.current_user", mock_current_user):
|
||||
retrieved_app = app_service.get_app(created_app)
|
||||
|
||||
# Verify retrieved app matches created app
|
||||
assert retrieved_app.id == created_app.id
|
||||
assert retrieved_app.name == created_app.name
|
||||
assert retrieved_app.description == created_app.description
|
||||
assert retrieved_app.mode == created_app.mode
|
||||
assert retrieved_app.tenant_id == created_app.tenant_id
|
||||
assert retrieved_app.created_by == created_app.created_by
|
||||
|
||||
def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful paginated app list retrieval.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create multiple apps
|
||||
app_names = [fake.company() for _ in range(5)]
|
||||
for name in app_names:
|
||||
app_args = {
|
||||
"name": name,
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "📱",
|
||||
"icon_background": "#96CEB4",
|
||||
}
|
||||
app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Get paginated apps
|
||||
args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "chat",
|
||||
}
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
|
||||
|
||||
# Verify pagination results
|
||||
assert paginated_apps is not None
|
||||
assert len(paginated_apps.items) >= 5 # Should have at least 5 apps
|
||||
assert paginated_apps.page == 1
|
||||
assert paginated_apps.per_page == 10
|
||||
|
||||
# Verify all apps belong to the correct tenant
|
||||
for app in paginated_apps.items:
|
||||
assert app.tenant_id == tenant.id
|
||||
assert app.mode == "chat"
|
||||
|
||||
def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test paginated app list with various filters.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create apps with different modes
|
||||
chat_app_args = {
|
||||
"name": "Chat App",
|
||||
"description": "A chat application",
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "💬",
|
||||
"icon_background": "#FF6B6B",
|
||||
}
|
||||
completion_app_args = {
|
||||
"name": "Completion App",
|
||||
"description": "A completion application",
|
||||
"mode": "completion",
|
||||
"icon_type": "emoji",
|
||||
"icon": "✍️",
|
||||
"icon_background": "#4ECDC4",
|
||||
}
|
||||
|
||||
chat_app = app_service.create_app(tenant.id, chat_app_args, account)
|
||||
completion_app = app_service.create_app(tenant.id, completion_app_args, account)
|
||||
|
||||
# Test filter by mode
|
||||
chat_args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "chat",
|
||||
}
|
||||
chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args)
|
||||
assert len(chat_apps.items) == 1
|
||||
assert chat_apps.items[0].mode == "chat"
|
||||
|
||||
# Test filter by name
|
||||
name_args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "chat",
|
||||
"name": "Chat",
|
||||
}
|
||||
filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args)
|
||||
assert len(filtered_apps.items) == 1
|
||||
assert "Chat" in filtered_apps.items[0].name
|
||||
|
||||
# Test filter by created_by_me
|
||||
created_by_me_args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "completion",
|
||||
"is_created_by_me": True,
|
||||
}
|
||||
my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args)
|
||||
assert len(my_apps.items) == 1
|
||||
|
||||
def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test paginated app list with tag filters.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create an app
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🏷️",
|
||||
"icon_background": "#FFEAA7",
|
||||
}
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Mock TagService to return the app ID for tag filtering
|
||||
with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
|
||||
mock_tag_service.return_value = [app.id]
|
||||
|
||||
# Test with tag filter
|
||||
args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "chat",
|
||||
"tag_ids": ["tag1", "tag2"],
|
||||
}
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
|
||||
|
||||
# Verify tag service was called
|
||||
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"])
|
||||
|
||||
# Verify results
|
||||
assert paginated_apps is not None
|
||||
assert len(paginated_apps.items) == 1
|
||||
assert paginated_apps.items[0].id == app.id
|
||||
|
||||
# Test with tag filter that returns no results
|
||||
with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
|
||||
mock_tag_service.return_value = []
|
||||
|
||||
args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "chat",
|
||||
"tag_ids": ["nonexistent_tag"],
|
||||
}
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
|
||||
|
||||
# Should return None when no apps match tag filter
|
||||
assert paginated_apps is None
|
||||
|
||||
def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app update with all fields.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🎯",
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original values
|
||||
original_name = app.name
|
||||
original_description = app.description
|
||||
original_icon = app.icon
|
||||
original_icon_background = app.icon_background
|
||||
original_use_icon_as_answer_icon = app.use_icon_as_answer_icon
|
||||
|
||||
# Update app
|
||||
update_args = {
|
||||
"name": "Updated App Name",
|
||||
"description": "Updated app description",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🔄",
|
||||
"icon_background": "#FF8C42",
|
||||
"use_icon_as_answer_icon": True,
|
||||
}
|
||||
|
||||
mock_current_user = create_autospec(Account, instance=True)
|
||||
mock_current_user.id = account.id
|
||||
mock_current_user.current_tenant_id = account.current_tenant_id
|
||||
|
||||
with patch("services.app_service.current_user", mock_current_user):
|
||||
updated_app = app_service.update_app(app, update_args)
|
||||
|
||||
# Verify updated fields
|
||||
assert updated_app.name == update_args["name"]
|
||||
assert updated_app.description == update_args["description"]
|
||||
assert updated_app.icon == update_args["icon"]
|
||||
assert updated_app.icon_background == update_args["icon_background"]
|
||||
assert updated_app.use_icon_as_answer_icon is True
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app name update.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🎯",
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original name
|
||||
original_name = app.name
|
||||
|
||||
# Update app name
|
||||
new_name = "New App Name"
|
||||
mock_current_user = create_autospec(Account, instance=True)
|
||||
mock_current_user.id = account.id
|
||||
mock_current_user.current_tenant_id = account.current_tenant_id
|
||||
|
||||
with patch("services.app_service.current_user", mock_current_user):
|
||||
updated_app = app_service.update_app_name(app, new_name)
|
||||
|
||||
assert updated_app.name == new_name
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.description == app.description
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app icon update.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🎯",
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original values
|
||||
original_icon = app.icon
|
||||
original_icon_background = app.icon_background
|
||||
|
||||
# Update app icon
|
||||
new_icon = "🌟"
|
||||
new_icon_background = "#FFD93D"
|
||||
mock_current_user = create_autospec(Account, instance=True)
|
||||
mock_current_user.id = account.id
|
||||
mock_current_user.current_tenant_id = account.current_tenant_id
|
||||
|
||||
with patch("services.app_service.current_user", mock_current_user):
|
||||
updated_app = app_service.update_app_icon(app, new_icon, new_icon_background)
|
||||
|
||||
assert updated_app.icon == new_icon
|
||||
assert updated_app.icon_background == new_icon_background
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.name == app.name
|
||||
assert updated_app.description == app.description
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app site status update.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🌐",
|
||||
"icon_background": "#74B9FF",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original site status
|
||||
original_site_status = app.enable_site
|
||||
|
||||
# Update site status to disabled
|
||||
mock_current_user = create_autospec(Account, instance=True)
|
||||
mock_current_user.id = account.id
|
||||
mock_current_user.current_tenant_id = account.current_tenant_id
|
||||
|
||||
with patch("services.app_service.current_user", mock_current_user):
|
||||
updated_app = app_service.update_app_site_status(app, False)
|
||||
assert updated_app.enable_site is False
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Update site status back to enabled
|
||||
with patch("services.app_service.current_user", mock_current_user):
|
||||
updated_app = app_service.update_app_site_status(updated_app, True)
|
||||
assert updated_app.enable_site is True
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.name == app.name
|
||||
assert updated_app.description == app.description
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app API status update.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🔌",
|
||||
"icon_background": "#A29BFE",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original API status
|
||||
original_api_status = app.enable_api
|
||||
|
||||
# Update API status to disabled
|
||||
mock_current_user = create_autospec(Account, instance=True)
|
||||
mock_current_user.id = account.id
|
||||
mock_current_user.current_tenant_id = account.current_tenant_id
|
||||
|
||||
with patch("services.app_service.current_user", mock_current_user):
|
||||
updated_app = app_service.update_app_api_status(app, False)
|
||||
assert updated_app.enable_api is False
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Update API status back to enabled
|
||||
with patch("services.app_service.current_user", mock_current_user):
|
||||
updated_app = app_service.update_app_api_status(updated_app, True)
|
||||
assert updated_app.enable_api is True
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.name == app.name
|
||||
assert updated_app.description == app.description
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app site status update when status doesn't change.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🔄",
|
||||
"icon_background": "#FD79A8",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original values
|
||||
original_site_status = app.enable_site
|
||||
original_updated_at = app.updated_at
|
||||
|
||||
# Update site status to the same value (no change)
|
||||
updated_app = app_service.update_app_site_status(app, original_site_status)
|
||||
|
||||
# Verify app is returned unchanged
|
||||
assert updated_app.id == app.id
|
||||
assert updated_app.enable_site == original_site_status
|
||||
assert updated_app.updated_at == original_updated_at
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.name == app.name
|
||||
assert updated_app.description == app.description
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app deletion.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🗑️",
|
||||
"icon_background": "#E17055",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store app ID for verification
|
||||
app_id = app.id
|
||||
|
||||
# Mock the async deletion task
|
||||
with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task:
|
||||
mock_delete_task.delay.return_value = None
|
||||
|
||||
# Delete app
|
||||
app_service.delete_app(app)
|
||||
|
||||
# Verify async deletion task was called
|
||||
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
|
||||
|
||||
# Verify app was deleted from database
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_app = db.session.query(App).filter_by(id=app_id).first()
|
||||
assert deleted_app is None
|
||||
|
||||
def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app deletion with related data cleanup.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🧹",
|
||||
"icon_background": "#00B894",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store app ID for verification
|
||||
app_id = app.id
|
||||
|
||||
# Mock webapp auth cleanup
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
].get_system_features.return_value.webapp_auth.enabled = True
|
||||
|
||||
# Mock the async deletion task
|
||||
with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task:
|
||||
mock_delete_task.delay.return_value = None
|
||||
|
||||
# Delete app
|
||||
app_service.delete_app(app)
|
||||
|
||||
# Verify webapp auth cleanup was called
|
||||
mock_external_service_dependencies["enterprise_service"].WebAppAuth.cleanup_webapp.assert_called_once_with(
|
||||
app_id
|
||||
)
|
||||
|
||||
# Verify async deletion task was called
|
||||
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
|
||||
|
||||
# Verify app was deleted from database
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_app = db.session.query(App).filter_by(id=app_id).first()
|
||||
assert deleted_app is None
|
||||
|
||||
def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app metadata retrieval.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "📊",
|
||||
"icon_background": "#6C5CE7",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Get app metadata
|
||||
app_meta = app_service.get_app_meta(app)
|
||||
|
||||
# Verify metadata contains expected fields
|
||||
assert "tool_icons" in app_meta
|
||||
# Note: get_app_meta currently only returns tool_icons
|
||||
|
||||
def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app code retrieval by app ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🔗",
|
||||
"icon_background": "#FDCB6E",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Get app code by ID
|
||||
app_code = AppService.get_app_code_by_id(app.id)
|
||||
|
||||
# Verify app code was retrieved correctly
|
||||
# Note: Site would be created when App is created, site.code is auto-generated
|
||||
assert app_code is not None
|
||||
assert len(app_code) > 0
|
||||
|
||||
def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app ID retrieval by app code.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🆔",
|
||||
"icon_background": "#E84393",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Create a site for the app
|
||||
site = Site()
|
||||
site.app_id = app.id
|
||||
site.code = fake.postalcode()
|
||||
site.title = fake.company()
|
||||
site.status = "normal"
|
||||
site.default_language = "en-US"
|
||||
site.customize_token_strategy = "uuid"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(site)
|
||||
db.session.commit()
|
||||
|
||||
# Get app ID by code
|
||||
app_id = AppService.get_app_id_by_code(site.code)
|
||||
|
||||
# Verify app ID was retrieved correctly
|
||||
assert app_id == app.id
|
||||
|
||||
def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app creation with invalid mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Setup app creation arguments with invalid mode
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "invalid_mode", # Invalid mode
|
||||
"icon_type": "emoji",
|
||||
"icon": "❌",
|
||||
"icon_background": "#D63031",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Attempt to create app with invalid mode
|
||||
with pytest.raises(ValueError, match="invalid mode value"):
|
||||
app_service.create_app(tenant.id, app_args, account)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,386 @@
|
||||
"""Unit tests for FeedbackService."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
|
||||
class TestFeedbackService:
|
||||
"""Test FeedbackService methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self, monkeypatch):
|
||||
"""Mock database session."""
|
||||
mock_session = mock.Mock()
|
||||
monkeypatch.setattr(db, "session", mock_session)
|
||||
return mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self):
|
||||
"""Create sample data for testing."""
|
||||
app_id = "test-app-id"
|
||||
|
||||
# Create mock models
|
||||
app = App(id=app_id, name="Test App")
|
||||
|
||||
conversation = Conversation(id="test-conversation-id", app_id=app_id, name="Test Conversation")
|
||||
|
||||
message = Message(
|
||||
id="test-message-id",
|
||||
conversation_id="test-conversation-id",
|
||||
query="What is AI?",
|
||||
answer="AI is artificial intelligence.",
|
||||
inputs={"query": "What is AI?"},
|
||||
created_at=datetime(2024, 1, 1, 10, 0, 0),
|
||||
)
|
||||
|
||||
# Use SimpleNamespace to avoid ORM model constructor issues
|
||||
user_feedback = SimpleNamespace(
|
||||
id="user-feedback-id",
|
||||
app_id=app_id,
|
||||
conversation_id="test-conversation-id",
|
||||
message_id="test-message-id",
|
||||
rating="like",
|
||||
from_source="user",
|
||||
content="Great answer!",
|
||||
from_end_user_id="user-123",
|
||||
from_account_id=None,
|
||||
from_account=None, # Mock account object
|
||||
created_at=datetime(2024, 1, 1, 10, 5, 0),
|
||||
)
|
||||
|
||||
admin_feedback = SimpleNamespace(
|
||||
id="admin-feedback-id",
|
||||
app_id=app_id,
|
||||
conversation_id="test-conversation-id",
|
||||
message_id="test-message-id",
|
||||
rating="dislike",
|
||||
from_source="admin",
|
||||
content="Could be more detailed",
|
||||
from_end_user_id=None,
|
||||
from_account_id="admin-456",
|
||||
from_account=SimpleNamespace(name="Admin User"), # Mock account object
|
||||
created_at=datetime(2024, 1, 1, 10, 10, 0),
|
||||
)
|
||||
|
||||
return {
|
||||
"app": app,
|
||||
"conversation": conversation,
|
||||
"message": message,
|
||||
"user_feedback": user_feedback,
|
||||
"admin_feedback": admin_feedback,
|
||||
}
|
||||
|
||||
def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback data in CSV format."""
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Test CSV export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
|
||||
# Verify response structure
|
||||
assert hasattr(result, "headers")
|
||||
assert "text/csv" in result.headers["Content-Type"]
|
||||
assert "attachment" in result.headers["Content-Disposition"]
|
||||
|
||||
# Check CSV content
|
||||
csv_content = result.get_data(as_text=True)
|
||||
# Verify essential headers exist (order may include additional columns)
|
||||
assert "feedback_id" in csv_content
|
||||
assert "app_name" in csv_content
|
||||
assert "conversation_id" in csv_content
|
||||
assert sample_data["app"].name in csv_content
|
||||
assert sample_data["message"].query in csv_content
|
||||
|
||||
def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback data in JSON format."""
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Test JSON export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
|
||||
# Verify response structure
|
||||
assert hasattr(result, "headers")
|
||||
assert "application/json" in result.headers["Content-Type"]
|
||||
assert "attachment" in result.headers["Content-Disposition"]
|
||||
|
||||
# Check JSON content
|
||||
json_content = json.loads(result.get_data(as_text=True))
|
||||
assert "export_info" in json_content
|
||||
assert "feedback_data" in json_content
|
||||
assert json_content["export_info"]["app_id"] == sample_data["app"].id
|
||||
assert json_content["export_info"]["total_records"] == 1
|
||||
|
||||
def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback with various filters."""
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Test with filters
|
||||
result = FeedbackService.export_feedbacks(
|
||||
app_id=sample_data["app"].id,
|
||||
from_source="admin",
|
||||
rating="dislike",
|
||||
has_comment=True,
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-12-31",
|
||||
format_type="csv",
|
||||
)
|
||||
|
||||
# Verify filters were applied
|
||||
assert mock_query.filter.called
|
||||
filter_calls = mock_query.filter.call_args_list
|
||||
# At least three filter invocations are expected (source, rating, comment)
|
||||
assert len(filter_calls) >= 3
|
||||
|
||||
def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback when no data exists."""
|
||||
|
||||
# Setup mock query result with no data
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = []
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
|
||||
# Should return an empty CSV with headers only
|
||||
assert hasattr(result, "headers")
|
||||
assert "text/csv" in result.headers["Content-Type"]
|
||||
csv_content = result.get_data(as_text=True)
|
||||
# Headers should exist (order can include additional columns)
|
||||
assert "feedback_id" in csv_content
|
||||
assert "app_name" in csv_content
|
||||
assert "conversation_id" in csv_content
|
||||
# No data rows expected
|
||||
assert len([line for line in csv_content.strip().splitlines() if line.strip()]) == 1
|
||||
|
||||
def test_export_feedbacks_invalid_date_format(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback with invalid date format."""
|
||||
|
||||
# Test with invalid start_date
|
||||
with pytest.raises(ValueError, match="Invalid start_date format"):
|
||||
FeedbackService.export_feedbacks(app_id=sample_data["app"].id, start_date="invalid-date-format")
|
||||
|
||||
# Test with invalid end_date
|
||||
with pytest.raises(ValueError, match="Invalid end_date format"):
|
||||
FeedbackService.export_feedbacks(app_id=sample_data["app"].id, end_date="invalid-date-format")
|
||||
|
||||
def test_export_feedbacks_invalid_format(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback with unsupported format."""
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported format"):
|
||||
FeedbackService.export_feedbacks(
|
||||
app_id=sample_data["app"].id,
|
||||
format_type="xml", # Unsupported format
|
||||
)
|
||||
|
||||
def test_export_feedbacks_long_response_truncation(self, mock_db_session, sample_data):
|
||||
"""Test that long AI responses are truncated in export."""
|
||||
|
||||
# Create message with long response
|
||||
long_message = Message(
|
||||
id="long-message-id",
|
||||
conversation_id="test-conversation-id",
|
||||
query="What is AI?",
|
||||
answer="A" * 600, # 600 character response
|
||||
inputs={"query": "What is AI?"},
|
||||
created_at=datetime(2024, 1, 1, 10, 0, 0),
|
||||
)
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
long_message,
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
|
||||
# Check JSON content
|
||||
json_content = json.loads(result.get_data(as_text=True))
|
||||
exported_answer = json_content["feedback_data"][0]["ai_response"]
|
||||
|
||||
# Should be truncated with ellipsis
|
||||
assert len(exported_answer) <= 503 # 500 + "..."
|
||||
assert exported_answer.endswith("...")
|
||||
assert len(exported_answer) > 500 # Should be close to limit
|
||||
|
||||
def test_export_feedbacks_unicode_content(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback with unicode content (Chinese characters)."""
|
||||
|
||||
# Create feedback with Chinese content (use SimpleNamespace to avoid ORM constructor constraints)
|
||||
chinese_feedback = SimpleNamespace(
|
||||
id="chinese-feedback-id",
|
||||
app_id=sample_data["app"].id,
|
||||
conversation_id="test-conversation-id",
|
||||
message_id="test-message-id",
|
||||
rating="dislike",
|
||||
from_source="user",
|
||||
content="回答不够详细,需要更多信息",
|
||||
from_end_user_id="user-123",
|
||||
from_account_id=None,
|
||||
created_at=datetime(2024, 1, 1, 10, 5, 0),
|
||||
)
|
||||
|
||||
# Create Chinese message
|
||||
chinese_message = Message(
|
||||
id="chinese-message-id",
|
||||
conversation_id="test-conversation-id",
|
||||
query="什么是人工智能?",
|
||||
answer="人工智能是模拟人类智能的技术。",
|
||||
inputs={"query": "什么是人工智能?"},
|
||||
created_at=datetime(2024, 1, 1, 10, 0, 0),
|
||||
)
|
||||
|
||||
# Setup mock query result
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
chinese_feedback,
|
||||
chinese_message,
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
None, # No account for user feedback
|
||||
)
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
|
||||
# Check that unicode content is preserved
|
||||
csv_content = result.get_data(as_text=True)
|
||||
assert "什么是人工智能?" in csv_content
|
||||
assert "回答不够详细,需要更多信息" in csv_content
|
||||
assert "人工智能是模拟人类智能的技术" in csv_content
|
||||
|
||||
def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
|
||||
"""Test that rating emojis are properly formatted in export."""
|
||||
|
||||
# Setup mock query result with both like and dislike feedback
|
||||
mock_query = mock.Mock()
|
||||
mock_query.join.return_value = mock_query
|
||||
mock_query.outerjoin.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.all.return_value = [
|
||||
(
|
||||
sample_data["user_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["user_feedback"].from_account,
|
||||
),
|
||||
(
|
||||
sample_data["admin_feedback"],
|
||||
sample_data["message"],
|
||||
sample_data["conversation"],
|
||||
sample_data["app"],
|
||||
sample_data["admin_feedback"].from_account,
|
||||
),
|
||||
]
|
||||
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
|
||||
# Check JSON content for emoji ratings
|
||||
json_content = json.loads(result.get_data(as_text=True))
|
||||
feedback_data = json_content["feedback_data"]
|
||||
|
||||
# Should have both feedback records
|
||||
assert len(feedback_data) == 2
|
||||
|
||||
# Check that emojis are properly set
|
||||
like_feedback = next(f for f in feedback_data if f["feedback_rating_raw"] == "like")
|
||||
dislike_feedback = next(f for f in feedback_data if f["feedback_rating_raw"] == "dislike")
|
||||
|
||||
assert like_feedback["feedback_rating"] == "👍"
|
||||
assert dislike_feedback["feedback_rating"] == "👎"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,775 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models.model import MessageFeedback
|
||||
from services.app_service import AppService
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
LastMessageNotExistsError,
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
class TestMessageService:
|
||||
"""Integration tests for MessageService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch("services.message_service.ModelManager") as mock_model_manager,
|
||||
patch("services.message_service.WorkflowService") as mock_workflow_service,
|
||||
patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager,
|
||||
patch("services.message_service.LLMGenerator") as mock_llm_generator,
|
||||
patch("services.message_service.TraceQueueManager") as mock_trace_manager_class,
|
||||
patch("services.message_service.TokenBufferMemory") as mock_token_buffer_memory,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_account_feature_service.get_features.return_value.billing.enabled = False
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_instance = mock_model_manager.return_value.get_default_model_instance.return_value
|
||||
mock_model_instance.get_tts_voices.return_value = [{"value": "test-voice"}]
|
||||
|
||||
# Mock get_model_instance method as well
|
||||
mock_model_manager.return_value.get_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Mock WorkflowService
|
||||
mock_workflow = mock_workflow_service.return_value.get_published_workflow.return_value
|
||||
mock_workflow_service.return_value.get_draft_workflow.return_value = mock_workflow
|
||||
|
||||
# Mock AdvancedChatAppConfigManager
|
||||
mock_app_config = mock_app_config_manager.get_app_config.return_value
|
||||
mock_app_config.additional_features.suggested_questions_after_answer = True
|
||||
|
||||
# Mock LLMGenerator
|
||||
mock_llm_generator.generate_suggested_questions_after_answer.return_value = ["Question 1", "Question 2"]
|
||||
|
||||
# Mock TraceQueueManager
|
||||
mock_trace_manager_instance = mock_trace_manager_class.return_value
|
||||
|
||||
# Mock TokenBufferMemory
|
||||
mock_memory_instance = mock_token_buffer_memory.return_value
|
||||
mock_memory_instance.get_history_prompt_text.return_value = "Mocked history prompt"
|
||||
|
||||
yield {
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"workflow_service": mock_workflow_service,
|
||||
"app_config_manager": mock_app_config_manager,
|
||||
"llm_generator": mock_llm_generator,
|
||||
"trace_manager_class": mock_trace_manager_class,
|
||||
"trace_manager_instance": mock_trace_manager_instance,
|
||||
"token_buffer_memory": mock_token_buffer_memory,
|
||||
# "current_user": mock_current_user,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (app, account) - Created app and account instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant first
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Setup app creation arguments
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "advanced-chat", # Use advanced-chat mode to use mocked workflow
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Create app
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Setup current_user mock
|
||||
self._mock_current_user(mock_external_service_dependencies, account.id, tenant.id)
|
||||
|
||||
return app, account
|
||||
|
||||
def _mock_current_user(self, mock_external_service_dependencies, account_id, tenant_id):
|
||||
"""
|
||||
Helper method to mock the current user for testing.
|
||||
"""
|
||||
# mock_external_service_dependencies["current_user"].id = account_id
|
||||
# mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
|
||||
|
||||
def _create_test_conversation(self, app, account, fake):
|
||||
"""
|
||||
Helper method to create a test conversation with all required fields.
|
||||
"""
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
mode=app.mode,
|
||||
name=fake.sentence(),
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
from_end_user_id=None,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.flush()
|
||||
return conversation
|
||||
|
||||
def _create_test_message(self, app, conversation, account, fake):
|
||||
"""
|
||||
Helper method to create a test message with all required fields.
|
||||
"""
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
model_provider=None,
|
||||
model_id="",
|
||||
override_model_configs=None,
|
||||
conversation_id=conversation.id,
|
||||
inputs={},
|
||||
query=fake.sentence(),
|
||||
message=json.dumps([{"role": "user", "text": fake.sentence()}]),
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0.001,
|
||||
answer=fake.text(max_nb_chars=200),
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0.001,
|
||||
parent_message_id=None,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
from_end_user_id=None,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
return message
|
||||
|
||||
def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful pagination by first ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and multiple messages
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
messages = []
|
||||
for i in range(5):
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
messages.append(message)
|
||||
|
||||
# Test pagination by first ID
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=account,
|
||||
conversation_id=conversation.id,
|
||||
first_id=messages[2].id, # Use middle message as first_id
|
||||
limit=2,
|
||||
order="asc",
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result.limit == 2
|
||||
assert len(result.data) == 2
|
||||
# total 5, from the middle, no more
|
||||
assert result.has_more is False
|
||||
# Verify messages are in ascending order
|
||||
assert result.data[0].created_at <= result.data[1].created_at
|
||||
|
||||
def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test pagination by first ID when no user is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Test pagination with no user
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app, user=None, conversation_id=fake.uuid4(), first_id=None, limit=10
|
||||
)
|
||||
|
||||
# Verify empty result
|
||||
assert result.limit == 10
|
||||
assert len(result.data) == 0
|
||||
assert result.has_more is False
|
||||
|
||||
def test_pagination_by_first_id_no_conversation_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by first ID when no conversation ID is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Test pagination with no conversation ID
|
||||
result = MessageService.pagination_by_first_id(
|
||||
app_model=app, user=account, conversation_id="", first_id=None, limit=10
|
||||
)
|
||||
|
||||
# Verify empty result
|
||||
assert result.limit == 10
|
||||
assert len(result.data) == 0
|
||||
assert result.has_more is False
|
||||
|
||||
def test_pagination_by_first_id_invalid_first_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by first ID with invalid first_id.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Test pagination with invalid first_id
|
||||
with pytest.raises(FirstMessageNotExistsError):
|
||||
MessageService.pagination_by_first_id(
|
||||
app_model=app,
|
||||
user=account,
|
||||
conversation_id=conversation.id,
|
||||
first_id=fake.uuid4(), # Non-existent message ID
|
||||
limit=10,
|
||||
)
|
||||
|
||||
def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful pagination by last ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and multiple messages
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
messages = []
|
||||
for i in range(5):
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
messages.append(message)
|
||||
|
||||
# Test pagination by last ID
|
||||
result = MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=messages[2].id, # Use middle message as last_id
|
||||
limit=2,
|
||||
conversation_id=conversation.id,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result.limit == 2
|
||||
assert len(result.data) == 2
|
||||
# total 5, from the middle, no more
|
||||
assert result.has_more is False
|
||||
# Verify messages are in descending order
|
||||
assert result.data[0].created_at >= result.data[1].created_at
|
||||
|
||||
def test_pagination_by_last_id_with_include_ids(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID with include_ids filter.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and multiple messages
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
messages = []
|
||||
for i in range(5):
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
messages.append(message)
|
||||
|
||||
# Test pagination with include_ids
|
||||
include_ids = [messages[0].id, messages[1].id, messages[2].id]
|
||||
result = MessageService.pagination_by_last_id(
|
||||
app_model=app, user=account, last_id=messages[1].id, limit=2, include_ids=include_ids
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result.limit == 2
|
||||
assert len(result.data) <= 2
|
||||
# Verify all returned messages are in include_ids
|
||||
for message in result.data:
|
||||
assert message.id in include_ids
|
||||
|
||||
def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test pagination by last ID when no user is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Test pagination with no user
|
||||
result = MessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
|
||||
|
||||
# Verify empty result
|
||||
assert result.limit == 10
|
||||
assert len(result.data) == 0
|
||||
assert result.has_more is False
|
||||
|
||||
def test_pagination_by_last_id_invalid_last_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID with invalid last_id.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Test pagination with invalid last_id
|
||||
with pytest.raises(LastMessageNotExistsError):
|
||||
MessageService.pagination_by_last_id(
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=fake.uuid4(), # Non-existent message ID
|
||||
limit=10,
|
||||
conversation_id=conversation.id,
|
||||
)
|
||||
|
||||
def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful creation of feedback.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Create feedback
|
||||
rating = "like"
|
||||
content = fake.text(max_nb_chars=100)
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=rating, content=content
|
||||
)
|
||||
|
||||
# Verify feedback was created correctly
|
||||
assert feedback.app_id == app.id
|
||||
assert feedback.conversation_id == conversation.id
|
||||
assert feedback.message_id == message.id
|
||||
assert feedback.rating == rating
|
||||
assert feedback.content == content
|
||||
assert feedback.from_source == "admin"
|
||||
assert feedback.from_account_id == account.id
|
||||
assert feedback.from_end_user_id is None
|
||||
|
||||
def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test creating feedback when no user is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Test creating feedback with no user
|
||||
with pytest.raises(ValueError, match="user cannot be None"):
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
|
||||
)
|
||||
|
||||
def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test updating existing feedback.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Create initial feedback
|
||||
initial_rating = "like"
|
||||
initial_content = fake.text(max_nb_chars=100)
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content
|
||||
)
|
||||
|
||||
# Update feedback
|
||||
updated_rating = "dislike"
|
||||
updated_content = fake.text(max_nb_chars=100)
|
||||
updated_feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content
|
||||
)
|
||||
|
||||
# Verify feedback was updated correctly
|
||||
assert updated_feedback.id == feedback.id
|
||||
assert updated_feedback.rating == updated_rating
|
||||
assert updated_feedback.content == updated_content
|
||||
assert updated_feedback.rating != initial_rating
|
||||
assert updated_feedback.content != initial_content
|
||||
|
||||
def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting existing feedback by setting rating to None.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Create initial feedback
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100)
|
||||
)
|
||||
|
||||
# Delete feedback by setting rating to None
|
||||
MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None)
|
||||
|
||||
# Verify feedback was deleted
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first()
|
||||
assert deleted_feedback is None
|
||||
|
||||
def test_create_feedback_no_rating_when_not_exists(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creating feedback with no rating when feedback doesn't exist.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Test creating feedback with no rating when no feedback exists
|
||||
with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"):
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=None, content=None
|
||||
)
|
||||
|
||||
def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of all message feedbacks.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create multiple conversations and messages with feedbacks
|
||||
feedbacks = []
|
||||
for i in range(3):
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app,
|
||||
message_id=message.id,
|
||||
user=account,
|
||||
rating="like" if i % 2 == 0 else "dislike",
|
||||
content=f"Feedback {i}: {fake.text(max_nb_chars=50)}",
|
||||
)
|
||||
feedbacks.append(feedback)
|
||||
|
||||
# Get all feedbacks
|
||||
result = MessageService.get_all_messages_feedbacks(app, page=1, limit=10)
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 3
|
||||
|
||||
# Verify feedbacks are ordered by created_at desc
|
||||
for i in range(len(result) - 1):
|
||||
assert result[i]["created_at"] >= result[i + 1]["created_at"]
|
||||
|
||||
def test_get_all_messages_feedbacks_pagination(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination of message feedbacks.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create multiple conversations and messages with feedbacks
|
||||
for i in range(5):
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
|
||||
)
|
||||
|
||||
# Get feedbacks with pagination
|
||||
result_page_1 = MessageService.get_all_messages_feedbacks(app, page=1, limit=3)
|
||||
result_page_2 = MessageService.get_all_messages_feedbacks(app, page=2, limit=3)
|
||||
|
||||
# Verify pagination results
|
||||
assert len(result_page_1) == 3
|
||||
assert len(result_page_2) == 2
|
||||
|
||||
# Verify no overlap between pages
|
||||
page_1_ids = {feedback["id"] for feedback in result_page_1}
|
||||
page_2_ids = {feedback["id"] for feedback in result_page_2}
|
||||
assert len(page_1_ids.intersection(page_2_ids)) == 0
|
||||
|
||||
def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of message.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Get message
|
||||
retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Verify message was retrieved correctly
|
||||
assert retrieved_message.id == message.id
|
||||
assert retrieved_message.app_id == app.id
|
||||
assert retrieved_message.conversation_id == conversation.id
|
||||
assert retrieved_message.from_source == "console"
|
||||
assert retrieved_message.from_account_id == account.id
|
||||
|
||||
def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting message that doesn't exist.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Test getting non-existent message
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4())
|
||||
|
||||
def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting message with wrong user (different account).
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Create another account
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
other_account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company())
|
||||
|
||||
# Test getting message with different user
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
MessageService.get_message(app_model=app, user=other_account, message_id=message.id)
|
||||
|
||||
def test_get_suggested_questions_after_answer_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful generation of suggested questions after answer.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Mock the LLMGenerator to return specific questions
|
||||
mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"]
|
||||
mock_external_service_dependencies[
|
||||
"llm_generator"
|
||||
].generate_suggested_questions_after_answer.return_value = mock_questions
|
||||
|
||||
# Get suggested questions
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
result = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result == mock_questions
|
||||
|
||||
# Verify LLMGenerator was called
|
||||
mock_external_service_dependencies[
|
||||
"llm_generator"
|
||||
].generate_suggested_questions_after_answer.assert_called_once()
|
||||
|
||||
# Verify TraceQueueManager was called
|
||||
mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()
|
||||
|
||||
def test_get_suggested_questions_after_answer_no_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions when no user is provided.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Test getting suggested questions with no user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
with pytest.raises(ValueError, match="user cannot be None"):
|
||||
MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app, user=None, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
|
||||
)
|
||||
|
||||
def test_get_suggested_questions_after_answer_disabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions when feature is disabled.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Mock the feature to be disabled
|
||||
mock_external_service_dependencies[
|
||||
"app_config_manager"
|
||||
].get_app_config.return_value.additional_features.suggested_questions_after_answer = False
|
||||
|
||||
# Test getting suggested questions when feature is disabled
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError):
|
||||
MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
|
||||
)
|
||||
|
||||
def test_get_suggested_questions_after_answer_no_workflow(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions when no workflow exists.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Mock no workflow
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None
|
||||
|
||||
# Get suggested questions (should return empty list)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
result = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API
|
||||
)
|
||||
|
||||
# Verify empty result
|
||||
assert result == []
|
||||
|
||||
def test_get_suggested_questions_after_answer_debugger_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions in debugger mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
|
||||
# Mock questions
|
||||
mock_questions = ["Debug question 1", "Debug question 2"]
|
||||
mock_external_service_dependencies[
|
||||
"llm_generator"
|
||||
].generate_suggested_questions_after_answer.return_value = mock_questions
|
||||
|
||||
# Get suggested questions in debugger mode
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
result = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.DEBUGGER
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result == mock_questions
|
||||
|
||||
# Verify draft workflow was used instead of published workflow
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with(
|
||||
app_model=app
|
||||
)
|
||||
|
||||
# Verify TraceQueueManager was called
|
||||
mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,475 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
from models.model import Account, Tenant
|
||||
from models.provider import LoadBalancingModelConfig, Provider, ProviderModelSetting
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
class TestModelLoadBalancingService:
|
||||
"""Integration tests for ModelLoadBalancingService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager,
|
||||
patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager,
|
||||
patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory,
|
||||
patch("services.model_load_balancing_service.encrypter") as mock_encrypter,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_provider_manager_instance = mock_provider_manager.return_value
|
||||
|
||||
# Mock provider configuration
|
||||
mock_provider_config = MagicMock()
|
||||
mock_provider_config.provider.provider = "openai"
|
||||
mock_provider_config.custom_configuration.provider = None
|
||||
|
||||
# Mock provider model setting
|
||||
mock_provider_model_setting = MagicMock()
|
||||
mock_provider_model_setting.load_balancing_enabled = False
|
||||
|
||||
mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting
|
||||
|
||||
# Mock provider configurations dict
|
||||
mock_provider_configs = {"openai": mock_provider_config}
|
||||
mock_provider_manager_instance.get_configurations.return_value = mock_provider_configs
|
||||
|
||||
# Mock LBModelManager
|
||||
mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)
|
||||
|
||||
# Mock ModelProviderFactory
|
||||
mock_model_provider_factory_instance = mock_model_provider_factory.return_value
|
||||
|
||||
# Mock credential schemas
|
||||
mock_credential_schema = MagicMock()
|
||||
mock_credential_schema.credential_form_schemas = []
|
||||
|
||||
# Mock provider configuration methods
|
||||
mock_provider_config.extract_secret_variables.return_value = []
|
||||
mock_provider_config.obfuscated_credentials.return_value = {}
|
||||
mock_provider_config._get_credential_schema.return_value = mock_credential_schema
|
||||
|
||||
yield {
|
||||
"provider_manager": mock_provider_manager,
|
||||
"lb_model_manager": mock_lb_model_manager,
|
||||
"model_provider_factory": mock_model_provider_factory,
|
||||
"encrypter": mock_encrypter,
|
||||
"provider_config": mock_provider_config,
|
||||
"provider_model_setting": mock_provider_model_setting,
|
||||
"credential_schema": mock_credential_schema,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_provider_and_setting(
|
||||
self, db_session_with_containers, tenant_id, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Helper method to create a test provider and provider model setting.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
tenant_id: Tenant ID for the provider
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (provider, provider_model_setting) - Created provider and setting instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create provider
|
||||
provider = Provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_name="openai",
|
||||
provider_type="custom",
|
||||
is_valid=True,
|
||||
)
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
# Create provider model setting
|
||||
provider_model_setting = ProviderModelSetting(
|
||||
tenant_id=tenant_id,
|
||||
provider_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
model_type="text-generation", # Use the origin model type that matches the query
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
db.session.add(provider_model_setting)
|
||||
db.session.commit()
|
||||
|
||||
return provider, provider_model_setting
|
||||
|
||||
def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful model load balancing enablement.
|
||||
|
||||
This test verifies:
|
||||
- Proper provider configuration retrieval
|
||||
- Successful enablement of model load balancing
|
||||
- Correct method calls to provider configuration
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
provider, provider_model_setting = self._create_test_provider_and_setting(
|
||||
db_session_with_containers, tenant.id, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup mocks for enable method
|
||||
mock_provider_config = mock_external_service_dependencies["provider_config"]
|
||||
mock_provider_config.enable_model_load_balancing = MagicMock()
|
||||
|
||||
# Act: Execute the method under test
|
||||
service = ModelLoadBalancingService()
|
||||
service.enable_model_load_balancing(
|
||||
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
mock_provider_config.enable_model_load_balancing.assert_called_once()
|
||||
call_args = mock_provider_config.enable_model_load_balancing.call_args
|
||||
assert call_args.kwargs["model"] == "gpt-3.5-turbo"
|
||||
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(provider)
|
||||
db.session.refresh(provider_model_setting)
|
||||
assert provider.id is not None
|
||||
assert provider_model_setting.id is not None
|
||||
|
||||
def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful model load balancing disablement.
|
||||
|
||||
This test verifies:
|
||||
- Proper provider configuration retrieval
|
||||
- Successful disablement of model load balancing
|
||||
- Correct method calls to provider configuration
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
provider, provider_model_setting = self._create_test_provider_and_setting(
|
||||
db_session_with_containers, tenant.id, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup mocks for disable method
|
||||
mock_provider_config = mock_external_service_dependencies["provider_config"]
|
||||
mock_provider_config.disable_model_load_balancing = MagicMock()
|
||||
|
||||
# Act: Execute the method under test
|
||||
service = ModelLoadBalancingService()
|
||||
service.disable_model_load_balancing(
|
||||
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
mock_provider_config.disable_model_load_balancing.assert_called_once()
|
||||
call_args = mock_provider_config.disable_model_load_balancing.call_args
|
||||
assert call_args.kwargs["model"] == "gpt-3.5-turbo"
|
||||
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(provider)
|
||||
db.session.refresh(provider_model_setting)
|
||||
assert provider.id is not None
|
||||
assert provider_model_setting.id is not None
|
||||
|
||||
def test_enable_model_load_balancing_provider_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when provider does not exist.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for non-existent provider
|
||||
- Correct exception type and message
|
||||
- No database state changes
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup mocks to return empty provider configurations
|
||||
mock_provider_manager = mock_external_service_dependencies["provider_manager"]
|
||||
mock_provider_manager_instance = mock_provider_manager.return_value
|
||||
mock_provider_manager_instance.get_configurations.return_value = {}
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
service = ModelLoadBalancingService()
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
service.enable_model_load_balancing(
|
||||
tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm"
|
||||
)
|
||||
|
||||
# Verify correct error message
|
||||
assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
|
||||
|
||||
# Verify no database state changes occurred
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.rollback()
|
||||
|
||||
def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of load balancing configurations.
|
||||
|
||||
This test verifies:
|
||||
- Proper provider configuration retrieval
|
||||
- Successful database query for load balancing configs
|
||||
- Correct return format and data structure
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
provider, provider_model_setting = self._create_test_provider_and_setting(
|
||||
db_session_with_containers, tenant.id, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create load balancing config
|
||||
from extensions.ext_database import db
|
||||
|
||||
load_balancing_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant.id,
|
||||
provider_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
model_type="text-generation", # Use the origin model type that matches the query
|
||||
name="config1",
|
||||
encrypted_config='{"api_key": "test_key"}',
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(load_balancing_config)
|
||||
db.session.commit()
|
||||
|
||||
# Verify the config was created
|
||||
db.session.refresh(load_balancing_config)
|
||||
assert load_balancing_config.id is not None
|
||||
|
||||
# Setup mocks for get_load_balancing_configs method
|
||||
mock_provider_config = mock_external_service_dependencies["provider_config"]
|
||||
mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"]
|
||||
mock_provider_model_setting.load_balancing_enabled = True
|
||||
|
||||
# Mock credential schema methods
|
||||
mock_credential_schema = mock_external_service_dependencies["credential_schema"]
|
||||
mock_credential_schema.credential_form_schemas = []
|
||||
|
||||
# Mock encrypter
|
||||
mock_encrypter = mock_external_service_dependencies["encrypter"]
|
||||
mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher")
|
||||
|
||||
# Mock _get_credential_schema method
|
||||
mock_provider_config._get_credential_schema.return_value = mock_credential_schema
|
||||
|
||||
# Mock extract_secret_variables method
|
||||
mock_provider_config.extract_secret_variables.return_value = []
|
||||
|
||||
# Mock obfuscated_credentials method
|
||||
mock_provider_config.obfuscated_credentials.return_value = {}
|
||||
|
||||
# Mock LBModelManager.get_config_in_cooldown_and_ttl
|
||||
mock_lb_model_manager = mock_external_service_dependencies["lb_model_manager"]
|
||||
mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0)
|
||||
|
||||
# Act: Execute the method under test
|
||||
service = ModelLoadBalancingService()
|
||||
is_enabled, configs = service.get_load_balancing_configs(
|
||||
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert is_enabled is True
|
||||
assert len(configs) == 1
|
||||
assert configs[0]["id"] == load_balancing_config.id
|
||||
assert configs[0]["name"] == "config1"
|
||||
assert configs[0]["enabled"] is True
|
||||
assert configs[0]["in_cooldown"] is False
|
||||
assert configs[0]["ttl"] == 0
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(load_balancing_config)
|
||||
assert load_balancing_config.id is not None
|
||||
|
||||
def test_get_load_balancing_configs_provider_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when provider does not exist in get_load_balancing_configs.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for non-existent provider
|
||||
- Correct exception type and message
|
||||
- No database state changes
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup mocks to return empty provider configurations
|
||||
mock_provider_manager = mock_external_service_dependencies["provider_manager"]
|
||||
mock_provider_manager_instance = mock_provider_manager.return_value
|
||||
mock_provider_manager_instance.get_configurations.return_value = {}
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
service = ModelLoadBalancingService()
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
service.get_load_balancing_configs(
|
||||
tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm"
|
||||
)
|
||||
|
||||
# Verify correct error message
|
||||
assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
|
||||
|
||||
# Verify no database state changes occurred
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.rollback()
|
||||
|
||||
def test_get_load_balancing_configs_with_inherit_config(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test load balancing configs retrieval with inherit configuration.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of inherit configuration
|
||||
- Correct ordering of configurations
|
||||
- Inherit config initialization when needed
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
provider, provider_model_setting = self._create_test_provider_and_setting(
|
||||
db_session_with_containers, tenant.id, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create load balancing config
|
||||
from extensions.ext_database import db
|
||||
|
||||
load_balancing_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant.id,
|
||||
provider_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
model_type="text-generation", # Use the origin model type that matches the query
|
||||
name="config1",
|
||||
encrypted_config='{"api_key": "test_key"}',
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(load_balancing_config)
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks for inherit config scenario
|
||||
mock_provider_config = mock_external_service_dependencies["provider_config"]
|
||||
mock_provider_config.custom_configuration.provider = MagicMock() # Enable custom config
|
||||
|
||||
mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"]
|
||||
mock_provider_model_setting.load_balancing_enabled = True
|
||||
|
||||
# Mock credential schema methods
|
||||
mock_credential_schema = mock_external_service_dependencies["credential_schema"]
|
||||
mock_credential_schema.credential_form_schemas = []
|
||||
|
||||
# Mock encrypter
|
||||
mock_encrypter = mock_external_service_dependencies["encrypter"]
|
||||
mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher")
|
||||
|
||||
# Act: Execute the method under test
|
||||
service = ModelLoadBalancingService()
|
||||
is_enabled, configs = service.get_load_balancing_configs(
|
||||
tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm"
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert is_enabled is True
|
||||
assert len(configs) == 2 # inherit config + existing config
|
||||
|
||||
# First config should be inherit config
|
||||
assert configs[0]["name"] == "__inherit__"
|
||||
assert configs[0]["enabled"] is True
|
||||
|
||||
# Second config should be the existing config
|
||||
assert configs[1]["id"] == load_balancing_config.id
|
||||
assert configs[1]["name"] == "config1"
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(load_balancing_config)
|
||||
assert load_balancing_config.id is not None
|
||||
|
||||
# Verify inherit config was created in database
|
||||
inherit_configs = db.session.scalars(
|
||||
select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
|
||||
).all()
|
||||
assert len(inherit_configs) == 1
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,620 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models.model import EndUser, Message
|
||||
from models.web import SavedMessage
|
||||
from services.app_service import AppService
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class TestSavedMessageService:
|
||||
"""Integration tests for SavedMessageService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.saved_message_service.MessageService") as mock_message_service,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Mock ModelManager for app creation
|
||||
mock_model_instance = mock_model_manager.return_value
|
||||
mock_model_instance.get_default_model_instance.return_value = None
|
||||
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||
|
||||
# Mock MessageService
|
||||
mock_message_service.get_message.return_value = None
|
||||
mock_message_service.pagination_by_last_id.return_value = None
|
||||
|
||||
yield {
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"message_service": mock_message_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (app, account) - Created app and account instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant first
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_end_user(self, db_session_with_containers, app):
|
||||
"""
|
||||
Helper method to create a test end user for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: App instance to associate the end user with
|
||||
|
||||
Returns:
|
||||
EndUser: Created end user instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
external_user_id=fake.uuid4(),
|
||||
name=fake.name(),
|
||||
type="normal",
|
||||
session_id=fake.uuid4(),
|
||||
is_anonymous=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
def _create_test_message(self, db_session_with_containers, app, user):
|
||||
"""
|
||||
Helper method to create a test message for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: App instance to associate the message with
|
||||
user: User instance (Account or EndUser) to associate the message with
|
||||
|
||||
Returns:
|
||||
Message: Created message instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create a simple conversation first
|
||||
from models.model import Conversation
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
from_source="account" if hasattr(user, "current_tenant") else "end_user",
|
||||
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
|
||||
from_account_id=user.id if hasattr(user, "current_tenant") else None,
|
||||
name=fake.sentence(nb_words=3),
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode="chat",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
|
||||
# Create message
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
from_source="account" if hasattr(user, "current_tenant") else "end_user",
|
||||
from_end_user_id=user.id if not hasattr(user, "current_tenant") else None,
|
||||
from_account_id=user.id if hasattr(user, "current_tenant") else None,
|
||||
inputs={},
|
||||
query=fake.sentence(nb_words=5),
|
||||
message=fake.text(max_nb_chars=100),
|
||||
answer=fake.text(max_nb_chars=200),
|
||||
message_tokens=50,
|
||||
answer_tokens=100,
|
||||
message_unit_price=0.001,
|
||||
answer_unit_price=0.002,
|
||||
total_price=0.003,
|
||||
currency="USD",
|
||||
status="success",
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
|
||||
return message
|
||||
|
||||
def test_pagination_by_last_id_success_with_account_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination by last ID with account user.
|
||||
|
||||
This test verifies:
|
||||
- Proper pagination with account user
|
||||
- Correct filtering by app_id and user
|
||||
- Proper role identification for account users
|
||||
- MessageService integration
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create test messages
|
||||
message1 = self._create_test_message(db_session_with_containers, app, account)
|
||||
message2 = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Create saved messages
|
||||
saved_message1 = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message1.id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
saved_message2 = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message2.id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add_all([saved_message1, saved_message2])
|
||||
db.session.commit()
|
||||
|
||||
# Mock MessageService.pagination_by_last_id return value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
|
||||
mock_pagination = InfiniteScrollPagination(data=[message1, message2], limit=10, has_more=False)
|
||||
mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=account, last_id=None, limit=10)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result.data == [message1, message2]
|
||||
assert result.limit == 10
|
||||
assert result.has_more is False
|
||||
|
||||
# Verify MessageService was called with correct parameters
|
||||
# Sort the IDs to handle database query order variations
|
||||
expected_include_ids = sorted([message1.id, message2.id])
|
||||
actual_call = mock_external_service_dependencies["message_service"].pagination_by_last_id.call_args
|
||||
actual_include_ids = sorted(actual_call.kwargs.get("include_ids", []))
|
||||
|
||||
assert actual_call.kwargs["app_model"] == app
|
||||
assert actual_call.kwargs["user"] == account
|
||||
assert actual_call.kwargs["last_id"] is None
|
||||
assert actual_call.kwargs["limit"] == 10
|
||||
assert actual_include_ids == expected_include_ids
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(saved_message1)
|
||||
db.session.refresh(saved_message2)
|
||||
assert saved_message1.id is not None
|
||||
assert saved_message2.id is not None
|
||||
assert saved_message1.created_by_role == "account"
|
||||
assert saved_message2.created_by_role == "account"
|
||||
|
||||
def test_pagination_by_last_id_success_with_end_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination by last ID with end user.
|
||||
|
||||
This test verifies:
|
||||
- Proper pagination with end user
|
||||
- Correct filtering by app_id and user
|
||||
- Proper role identification for end users
|
||||
- MessageService integration
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
end_user = self._create_test_end_user(db_session_with_containers, app)
|
||||
|
||||
# Create test messages
|
||||
message1 = self._create_test_message(db_session_with_containers, app, end_user)
|
||||
message2 = self._create_test_message(db_session_with_containers, app, end_user)
|
||||
|
||||
# Create saved messages
|
||||
saved_message1 = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message1.id,
|
||||
created_by_role="end_user",
|
||||
created_by=end_user.id,
|
||||
)
|
||||
saved_message2 = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message2.id,
|
||||
created_by_role="end_user",
|
||||
created_by=end_user.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add_all([saved_message1, saved_message2])
|
||||
db.session.commit()
|
||||
|
||||
# Mock MessageService.pagination_by_last_id return value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
|
||||
mock_pagination = InfiniteScrollPagination(data=[message1, message2], limit=5, has_more=True)
|
||||
mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = SavedMessageService.pagination_by_last_id(
|
||||
app_model=app, user=end_user, last_id="test_last_id", limit=5
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result.data == [message1, message2]
|
||||
assert result.limit == 5
|
||||
assert result.has_more is True
|
||||
|
||||
# Verify MessageService was called with correct parameters
|
||||
# Sort the IDs to handle database query order variations
|
||||
expected_include_ids = sorted([message1.id, message2.id])
|
||||
actual_call = mock_external_service_dependencies["message_service"].pagination_by_last_id.call_args
|
||||
actual_include_ids = sorted(actual_call.kwargs.get("include_ids", []))
|
||||
|
||||
assert actual_call.kwargs["app_model"] == app
|
||||
assert actual_call.kwargs["user"] == end_user
|
||||
assert actual_call.kwargs["last_id"] == "test_last_id"
|
||||
assert actual_call.kwargs["limit"] == 5
|
||||
assert actual_include_ids == expected_include_ids
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(saved_message1)
|
||||
db.session.refresh(saved_message2)
|
||||
assert saved_message1.id is not None
|
||||
assert saved_message2.id is not None
|
||||
assert saved_message1.created_by_role == "end_user"
|
||||
assert saved_message2.created_by_role == "end_user"
|
||||
|
||||
def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful save of a new message.
|
||||
|
||||
This test verifies:
|
||||
- Proper creation of new saved message
|
||||
- Correct database state after save
|
||||
- Proper relationship establishment
|
||||
- MessageService integration for message retrieval
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Mock MessageService.get_message return value
|
||||
mock_external_service_dependencies["message_service"].get_message.return_value = message
|
||||
|
||||
# Act: Execute the method under test
|
||||
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was created in database
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert saved_message is not None
|
||||
assert saved_message.app_id == app.id
|
||||
assert saved_message.message_id == message.id
|
||||
assert saved_message.created_by_role == "account"
|
||||
assert saved_message.created_by == account.id
|
||||
assert saved_message.created_at is not None
|
||||
|
||||
# Verify MessageService.get_message was called
|
||||
mock_external_service_dependencies["message_service"].get_message.assert_called_once_with(
|
||||
app_model=app, user=account, message_id=message.id
|
||||
)
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(saved_message)
|
||||
assert saved_message.id is not None
|
||||
|
||||
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when no user is provided.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing user
|
||||
- ValueError is raised when user is None
|
||||
- No database operations are performed
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
|
||||
|
||||
assert "User is required" in str(exc_info.value)
|
||||
|
||||
# Verify no database operations were performed
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_messages = db.session.query(SavedMessage).all()
|
||||
assert len(saved_messages) == 0
|
||||
|
||||
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when saving message with no user.
|
||||
|
||||
This test verifies:
|
||||
- Method returns early when user is None
|
||||
- No database operations are performed
|
||||
- No exceptions are raised
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Act: Execute the method under test with None user
|
||||
result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is None
|
||||
|
||||
# Verify no saved message was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert saved_message is None
|
||||
|
||||
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of an existing saved message.
|
||||
|
||||
This test verifies:
|
||||
- Proper deletion of existing saved message
|
||||
- Correct database state after deletion
|
||||
- No errors during deletion process
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Create a saved message first
|
||||
saved_message = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message.id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(saved_message)
|
||||
db.session.commit()
|
||||
|
||||
# Verify saved message exists
|
||||
assert (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
SavedMessageService.delete(app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was deleted from database
|
||||
deleted_saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert deleted_saved_message is None
|
||||
|
||||
# Verify database state
|
||||
db.session.commit()
|
||||
# The message should still exist, only the saved_message should be deleted
|
||||
assert db.session.query(Message).where(Message.id == message.id).first() is not None
|
||||
|
||||
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when no user is provided.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing user
|
||||
- ValueError is raised when user is None
|
||||
- No database operations are performed
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
|
||||
|
||||
assert "User is required" in str(exc_info.value)
|
||||
|
||||
# Verify no database operations were performed for this specific test
|
||||
# Note: We don't check total count as other tests may have created data
|
||||
# Instead, we verify that the error was properly raised
|
||||
pass
|
||||
|
||||
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when saving message with no user.
|
||||
|
||||
This test verifies:
|
||||
- Method returns early when user is None
|
||||
- No database operations are performed
|
||||
- No exceptions are raised
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Act: Execute the method under test with None user
|
||||
result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is None
|
||||
|
||||
# Verify no saved message was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert saved_message is None
|
||||
|
||||
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of an existing saved message.
|
||||
|
||||
This test verifies:
|
||||
- Proper deletion of existing saved message
|
||||
- Correct database state after deletion
|
||||
- No errors during deletion process
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Create a saved message first
|
||||
saved_message = SavedMessage(
|
||||
app_id=app.id,
|
||||
message_id=message.id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(saved_message)
|
||||
db.session.commit()
|
||||
|
||||
# Verify saved message exists
|
||||
assert (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
SavedMessageService.delete(app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was deleted from database
|
||||
deleted_saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
SavedMessage.created_by_role == "account",
|
||||
SavedMessage.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert deleted_saved_message is None
|
||||
|
||||
# Verify database state
|
||||
db.session.commit()
|
||||
# The message should still exist, only the saved_message should be deleted
|
||||
assert db.session.query(Message).where(Message.id == message.id).first() is not None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,573 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models import Account
|
||||
from models.model import Conversation, EndUser
|
||||
from models.web import PinnedConversation
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
|
||||
class TestWebConversationService:
|
||||
"""Integration tests for WebConversationService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.app_service.FeatureService") as mock_feature_service,
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
):
|
||||
# Setup default mock returns for app service
|
||||
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
|
||||
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
|
||||
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
|
||||
|
||||
# Setup default mock returns for account service
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Mock ModelManager for model configuration
|
||||
mock_model_instance = mock_model_manager.return_value
|
||||
mock_model_instance.get_default_model_instance.return_value = None
|
||||
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||
|
||||
yield {
|
||||
"feature_service": mock_feature_service,
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (app, account) - Created app and account instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_end_user(self, db_session_with_containers, app):
|
||||
"""
|
||||
Helper method to create a test end user for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: App instance
|
||||
|
||||
Returns:
|
||||
EndUser: Created end user instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
end_user = EndUser(
|
||||
session_id=fake.uuid4(),
|
||||
app_id=app.id,
|
||||
type="normal",
|
||||
is_anonymous=False,
|
||||
tenant_id=app.tenant_id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
def _create_test_conversation(self, db_session_with_containers, app, user, fake):
|
||||
"""
|
||||
Helper method to create a test conversation for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: App instance
|
||||
user: User instance (Account or EndUser)
|
||||
fake: Faker instance
|
||||
|
||||
Returns:
|
||||
Conversation: Created conversation instance
|
||||
"""
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
app_model_config_id=app.app_model_config_id,
|
||||
model_provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
mode="chat",
|
||||
name=fake.sentence(nb_words=3),
|
||||
summary=fake.text(max_nb_chars=100),
|
||||
inputs={},
|
||||
introduction=fake.text(max_nb_chars=200),
|
||||
system_instruction=fake.text(max_nb_chars=300),
|
||||
system_instruction_tokens=50,
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
from_source="console" if isinstance(user, Account) else "api",
|
||||
from_end_user_id=user.id if isinstance(user, EndUser) else None,
|
||||
from_account_id=user.id if isinstance(user, Account) else None,
|
||||
dialogue_count=0,
|
||||
is_deleted=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
|
||||
return conversation
|
||||
|
||||
def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful pagination by last ID with basic parameters.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create multiple conversations
|
||||
conversations = []
|
||||
for i in range(5):
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
conversations.append(conversation)
|
||||
|
||||
# Test pagination without pinned filter
|
||||
result = WebConversationService.pagination_by_last_id(
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=None,
|
||||
limit=3,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=None,
|
||||
sort_by="-updated_at",
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result.limit == 3
|
||||
assert len(result.data) == 3
|
||||
assert result.has_more is True
|
||||
|
||||
# Verify conversations are in descending order by updated_at
|
||||
assert result.data[0].updated_at >= result.data[1].updated_at
|
||||
assert result.data[1].updated_at >= result.data[2].updated_at
|
||||
|
||||
def test_pagination_by_last_id_with_pinned_filter(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID with pinned conversation filter.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create conversations
|
||||
conversations = []
|
||||
for i in range(5):
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
conversations.append(conversation)
|
||||
|
||||
# Pin some conversations
|
||||
pinned_conversation1 = PinnedConversation(
|
||||
app_id=app.id,
|
||||
conversation_id=conversations[0].id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
pinned_conversation2 = PinnedConversation(
|
||||
app_id=app.id,
|
||||
conversation_id=conversations[2].id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(pinned_conversation1)
|
||||
db.session.add(pinned_conversation2)
|
||||
db.session.commit()
|
||||
|
||||
# Test pagination with pinned filter
|
||||
result = WebConversationService.pagination_by_last_id(
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=None,
|
||||
limit=10,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=True,
|
||||
sort_by="-updated_at",
|
||||
)
|
||||
|
||||
# Verify only pinned conversations are returned
|
||||
assert result.limit == 10
|
||||
assert len(result.data) == 2
|
||||
assert result.has_more is False
|
||||
|
||||
# Verify the returned conversations are the pinned ones
|
||||
returned_ids = [conv.id for conv in result.data]
|
||||
expected_ids = [conversations[0].id, conversations[2].id]
|
||||
assert set(returned_ids) == set(expected_ids)
|
||||
|
||||
def test_pagination_by_last_id_with_unpinned_filter(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID with unpinned conversation filter.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create conversations
|
||||
conversations = []
|
||||
for i in range(5):
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
conversations.append(conversation)
|
||||
|
||||
# Pin one conversation
|
||||
pinned_conversation = PinnedConversation(
|
||||
app_id=app.id,
|
||||
conversation_id=conversations[0].id,
|
||||
created_by_role="account",
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(pinned_conversation)
|
||||
db.session.commit()
|
||||
|
||||
# Test pagination with unpinned filter
|
||||
result = WebConversationService.pagination_by_last_id(
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
user=account,
|
||||
last_id=None,
|
||||
limit=10,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=False,
|
||||
sort_by="-updated_at",
|
||||
)
|
||||
|
||||
# Verify unpinned conversations are returned (should be 4 out of 5)
|
||||
assert result.limit == 10
|
||||
assert len(result.data) == 4
|
||||
assert result.has_more is False
|
||||
|
||||
# Verify the pinned conversation is not in the results
|
||||
returned_ids = [conv.id for conv in result.data]
|
||||
assert conversations[0].id not in returned_ids
|
||||
|
||||
# Verify all other conversations are in the results
|
||||
expected_unpinned_ids = [conv.id for conv in conversations[1:]]
|
||||
assert set(returned_ids) == set(expected_unpinned_ids)
|
||||
|
||||
def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful pinning of a conversation.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
|
||||
# Pin the conversation
|
||||
WebConversationService.pin(app, conversation.id, account)
|
||||
|
||||
# Verify the conversation was pinned
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
PinnedConversation.created_by_role == "account",
|
||||
PinnedConversation.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert pinned_conversation is not None
|
||||
assert pinned_conversation.app_id == app.id
|
||||
assert pinned_conversation.conversation_id == conversation.id
|
||||
assert pinned_conversation.created_by_role == "account"
|
||||
assert pinned_conversation.created_by == account.id
|
||||
|
||||
def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test pinning a conversation that is already pinned (should not create duplicate).
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
|
||||
# Pin the conversation first time
|
||||
WebConversationService.pin(app, conversation.id, account)
|
||||
|
||||
# Pin the conversation again
|
||||
WebConversationService.pin(app, conversation.id, account)
|
||||
|
||||
# Verify only one pinned conversation record exists
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversations = db.session.scalars(
|
||||
select(PinnedConversation).where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
PinnedConversation.created_by_role == "account",
|
||||
PinnedConversation.created_by == account.id,
|
||||
)
|
||||
).all()
|
||||
|
||||
assert len(pinned_conversations) == 1
|
||||
|
||||
def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test pinning a conversation with an end user.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create an end user
|
||||
end_user = self._create_test_end_user(db_session_with_containers, app)
|
||||
|
||||
# Create a conversation for the end user
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, end_user, fake)
|
||||
|
||||
# Pin the conversation
|
||||
WebConversationService.pin(app, conversation.id, end_user)
|
||||
|
||||
# Verify the conversation was pinned
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
PinnedConversation.created_by_role == "end_user",
|
||||
PinnedConversation.created_by == end_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert pinned_conversation is not None
|
||||
assert pinned_conversation.app_id == app.id
|
||||
assert pinned_conversation.conversation_id == conversation.id
|
||||
assert pinned_conversation.created_by_role == "end_user"
|
||||
assert pinned_conversation.created_by == end_user.id
|
||||
|
||||
def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful unpinning of a conversation.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
|
||||
# Pin the conversation first
|
||||
WebConversationService.pin(app, conversation.id, account)
|
||||
|
||||
# Verify it was pinned
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
PinnedConversation.created_by_role == "account",
|
||||
PinnedConversation.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert pinned_conversation is not None
|
||||
|
||||
# Unpin the conversation
|
||||
WebConversationService.unpin(app, conversation.id, account)
|
||||
|
||||
# Verify it was unpinned
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
PinnedConversation.created_by_role == "account",
|
||||
PinnedConversation.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert pinned_conversation is None
|
||||
|
||||
def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test unpinning a conversation that is not pinned (should not cause error).
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
|
||||
# Try to unpin a conversation that was never pinned
|
||||
WebConversationService.unpin(app, conversation.id, account)
|
||||
|
||||
# Verify no pinned conversation record exists
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
PinnedConversation.created_by_role == "account",
|
||||
PinnedConversation.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert pinned_conversation is None
|
||||
|
||||
def test_pagination_by_last_id_user_required_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that pagination_by_last_id raises ValueError when user is None.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Test with None user
|
||||
with pytest.raises(ValueError, match="User is required"):
|
||||
WebConversationService.pagination_by_last_id(
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
user=None,
|
||||
last_id=None,
|
||||
limit=10,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
pinned=None,
|
||||
sort_by="-updated_at",
|
||||
)
|
||||
|
||||
def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test that pin method returns early when user is None.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
|
||||
# Try to pin with None user
|
||||
WebConversationService.pin(app, conversation.id, None)
|
||||
|
||||
# Verify no pinned conversation was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert pinned_conversation is None
|
||||
|
||||
def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test that unpin method returns early when user is None.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
|
||||
# Pin the conversation first
|
||||
WebConversationService.pin(app, conversation.id, account)
|
||||
|
||||
# Verify it was pinned
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
PinnedConversation.created_by_role == "account",
|
||||
PinnedConversation.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert pinned_conversation is not None
|
||||
|
||||
# Try to unpin with None user
|
||||
WebConversationService.unpin(app, conversation.id, None)
|
||||
|
||||
# Verify the conversation is still pinned
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
PinnedConversation.created_by_role == "account",
|
||||
PinnedConversation.created_by == account.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert pinned_conversation is not None
|
||||
@@ -0,0 +1,895 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from libs.password import hash_password
|
||||
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import App, Site
|
||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
|
||||
|
||||
|
||||
class TestWebAppAuthService:
|
||||
"""Integration tests for WebAppAuthService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.webapp_auth_service.PassportService") as mock_passport_service,
|
||||
patch("services.webapp_auth_service.TokenManager") as mock_token_manager,
|
||||
patch("services.webapp_auth_service.send_email_code_login_mail_task") as mock_mail_task,
|
||||
patch("services.webapp_auth_service.AppService") as mock_app_service,
|
||||
patch("services.webapp_auth_service.EnterpriseService") as mock_enterprise_service,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_passport_service.return_value.issue.return_value = "mock_jwt_token"
|
||||
mock_token_manager.generate_token.return_value = "mock_token"
|
||||
mock_token_manager.get_token_data.return_value = {"code": "123456"}
|
||||
mock_mail_task.delay.return_value = None
|
||||
mock_app_service.get_app_id_by_code.return_value = "mock_app_id"
|
||||
mock_enterprise_service.WebAppAuth.get_app_access_mode_by_id.return_value = type(
|
||||
"MockWebAppAuth", (), {"access_mode": "private"}
|
||||
)()
|
||||
# Note: get_app_access_mode_by_code method was removed in refactoring
|
||||
|
||||
yield {
|
||||
"passport_service": mock_passport_service,
|
||||
"token_manager": mock_token_manager,
|
||||
"mail_task": mock_mail_task,
|
||||
"app_service": mock_app_service,
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
import uuid
|
||||
|
||||
# Create account with unique email to avoid collisions
|
||||
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
|
||||
account = Account(
|
||||
email=unique_email,
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account with password for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant, password) - Created account, tenant and password
|
||||
"""
|
||||
fake = Faker()
|
||||
password = fake.password(length=12)
|
||||
|
||||
# Create account with password
|
||||
import uuid
|
||||
|
||||
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
|
||||
account = Account(
|
||||
email=unique_email,
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
# Hash password
|
||||
salt = b"test_salt_16_bytes"
|
||||
password_hash = hash_password(password, salt)
|
||||
|
||||
# Convert to base64 for storage
|
||||
import base64
|
||||
|
||||
account.password = base64.b64encode(password_hash).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant, password
|
||||
|
||||
def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant):
|
||||
"""
|
||||
Helper method to create a test app and site for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
tenant: Tenant instance to associate with
|
||||
|
||||
Returns:
|
||||
tuple: (app, site) - Created app and site instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create app
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
mode="chat",
|
||||
icon_type="emoji",
|
||||
icon="🤖",
|
||||
icon_background="#FF6B6B",
|
||||
api_rph=100,
|
||||
api_rpm=10,
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
|
||||
# Create site
|
||||
site = Site(
|
||||
app_id=app.id,
|
||||
title=fake.company(),
|
||||
code=fake.unique.lexify(text="??????"),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
default_language="en-US",
|
||||
status="normal",
|
||||
customize_token_strategy="not_allow",
|
||||
)
|
||||
db.session.add(site)
|
||||
db.session.commit()
|
||||
|
||||
return app, site
|
||||
|
||||
def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful authentication with valid email and password.
|
||||
|
||||
This test verifies:
|
||||
- Proper authentication with valid credentials
|
||||
- Correct account return
|
||||
- Database state consistency
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant, password = self._create_test_account_with_password(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Act: Execute authentication
|
||||
result = WebAppAuthService.authenticate(account.email, password)
|
||||
|
||||
# Assert: Verify successful authentication
|
||||
assert result is not None
|
||||
assert result.id == account.id
|
||||
assert result.email == account.email
|
||||
assert result.name == account.name
|
||||
assert result.status == AccountStatus.ACTIVE
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.password is not None
|
||||
assert result.password_salt is not None
|
||||
|
||||
def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test authentication with non-existent email.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for non-existent accounts
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: Generate a guaranteed non-existent email
|
||||
# Use UUID and timestamp to ensure uniqueness
|
||||
unique_id = str(uuid.uuid4()).replace("-", "")
|
||||
timestamp = str(int(time.time() * 1000000)) # microseconds
|
||||
non_existent_email = f"nonexistent_{unique_id}_{timestamp}@test-domain-that-never-exists.invalid"
|
||||
|
||||
# Double-check this email doesn't exist in the database
|
||||
existing_account = db_session_with_containers.query(Account).filter_by(email=non_existent_email).first()
|
||||
assert existing_account is None, f"Test email {non_existent_email} already exists in database"
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(AccountNotFoundError):
|
||||
WebAppAuthService.authenticate(non_existent_email, "any_password")
|
||||
|
||||
def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test authentication with banned account.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for banned accounts
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: Create banned account
|
||||
fake = Faker()
|
||||
password = fake.password(length=12)
|
||||
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status=AccountStatus.BANNED,
|
||||
)
|
||||
|
||||
# Hash password
|
||||
salt = b"test_salt_16_bytes"
|
||||
password_hash = hash_password(password, salt)
|
||||
|
||||
# Convert to base64 for storage
|
||||
import base64
|
||||
|
||||
account.password = base64.b64encode(password_hash).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(AccountLoginError) as exc_info:
|
||||
WebAppAuthService.authenticate(account.email, password)
|
||||
|
||||
assert "Account is banned." in str(exc_info.value)
|
||||
|
||||
def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test authentication with invalid password.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for invalid passwords
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: Create account with password
|
||||
account, tenant, correct_password = self._create_test_account_with_password(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Act & Assert: Verify proper error handling with wrong password
|
||||
with pytest.raises(AccountPasswordError) as exc_info:
|
||||
WebAppAuthService.authenticate(account.email, "wrong_password")
|
||||
|
||||
assert "Invalid email or password." in str(exc_info.value)
|
||||
|
||||
def test_authenticate_account_without_password(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test authentication for account without password.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for accounts without password
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: Create account without password
|
||||
fake = Faker()
|
||||
import uuid
|
||||
|
||||
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
|
||||
|
||||
account = Account(
|
||||
email=unique_email,
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(AccountPasswordError) as exc_info:
|
||||
WebAppAuthService.authenticate(account.email, "any_password")
|
||||
|
||||
assert "Invalid email or password." in str(exc_info.value)
|
||||
|
||||
def test_login_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful login and JWT token generation.
|
||||
|
||||
This test verifies:
|
||||
- Proper JWT token generation
|
||||
- Correct token format and content
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Create test account
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Act: Execute login
|
||||
result = WebAppAuthService.login(account)
|
||||
|
||||
# Assert: Verify successful login
|
||||
assert result is not None
|
||||
assert result == "mock_jwt_token"
|
||||
|
||||
# Verify mock service was called correctly
|
||||
mock_external_service_dependencies["passport_service"].return_value.issue.assert_called_once()
|
||||
call_args = mock_external_service_dependencies["passport_service"].return_value.issue.call_args[0][0]
|
||||
|
||||
assert call_args["sub"] == "Web API Passport"
|
||||
assert call_args["user_id"] == account.id
|
||||
assert call_args["session_id"] == account.email
|
||||
assert call_args["token_source"] == "webapp_login_token"
|
||||
assert call_args["auth_type"] == "internal"
|
||||
assert "exp" in call_args
|
||||
|
||||
def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful user retrieval through email.
|
||||
|
||||
This test verifies:
|
||||
- Proper user retrieval by email
|
||||
- Correct account return
|
||||
- Database state consistency
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Act: Execute user retrieval
|
||||
result = WebAppAuthService.get_user_through_email(account.email)
|
||||
|
||||
# Assert: Verify successful retrieval
|
||||
assert result is not None
|
||||
assert result.id == account.id
|
||||
assert result.email == account.email
|
||||
assert result.name == account.name
|
||||
assert result.status == AccountStatus.ACTIVE
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(result)
|
||||
assert result.id is not None
|
||||
|
||||
def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test user retrieval with non-existent email.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling for non-existent users
|
||||
- Correct return value (None)
|
||||
"""
|
||||
# Arrange: Use non-existent email
|
||||
fake = Faker()
|
||||
non_existent_email = fake.email()
|
||||
|
||||
# Act: Execute user retrieval
|
||||
result = WebAppAuthService.get_user_through_email(non_existent_email)
|
||||
|
||||
# Assert: Verify proper handling
|
||||
assert result is None
|
||||
|
||||
def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test user retrieval with banned account.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for banned accounts
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: Create banned account
|
||||
fake = Faker()
|
||||
import uuid
|
||||
|
||||
unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com"
|
||||
|
||||
account = Account(
|
||||
email=unique_email,
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status=AccountStatus.BANNED,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
WebAppAuthService.get_user_through_email(account.email)
|
||||
|
||||
assert "Account is banned." in str(exc_info.value)
|
||||
|
||||
def test_send_email_code_login_email_with_account(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test sending email code login email with account.
|
||||
|
||||
This test verifies:
|
||||
- Proper email code generation
|
||||
- Token generation with correct data
|
||||
- Mail task scheduling
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Create test account
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Act: Execute email code login email sending
|
||||
result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US")
|
||||
|
||||
# Assert: Verify successful email sending
|
||||
assert result is not None
|
||||
assert result == "mock_token"
|
||||
|
||||
# Verify mock services were called correctly
|
||||
mock_external_service_dependencies["token_manager"].generate_token.assert_called_once()
|
||||
mock_external_service_dependencies["mail_task"].delay.assert_called_once()
|
||||
|
||||
# Verify token generation parameters
|
||||
token_call_args = mock_external_service_dependencies["token_manager"].generate_token.call_args
|
||||
assert token_call_args[1]["account"] == account
|
||||
assert token_call_args[1]["email"] == account.email
|
||||
assert token_call_args[1]["token_type"] == "email_code_login"
|
||||
assert "code" in token_call_args[1]["additional_data"]
|
||||
|
||||
# Verify mail task parameters
|
||||
mail_call_args = mock_external_service_dependencies["mail_task"].delay.call_args
|
||||
assert mail_call_args[1]["language"] == "en-US"
|
||||
assert mail_call_args[1]["to"] == account.email
|
||||
assert "code" in mail_call_args[1]
|
||||
|
||||
def test_send_email_code_login_email_with_email_only(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test sending email code login email with email only.
|
||||
|
||||
This test verifies:
|
||||
- Proper email code generation without account
|
||||
- Token generation with email only
|
||||
- Mail task scheduling
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Use test email
|
||||
fake = Faker()
|
||||
test_email = fake.email()
|
||||
|
||||
# Act: Execute email code login email sending
|
||||
result = WebAppAuthService.send_email_code_login_email(email=test_email, language="zh-Hans")
|
||||
|
||||
# Assert: Verify successful email sending
|
||||
assert result is not None
|
||||
assert result == "mock_token"
|
||||
|
||||
# Verify mock services were called correctly
|
||||
mock_external_service_dependencies["token_manager"].generate_token.assert_called_once()
|
||||
mock_external_service_dependencies["mail_task"].delay.assert_called_once()
|
||||
|
||||
# Verify token generation parameters
|
||||
token_call_args = mock_external_service_dependencies["token_manager"].generate_token.call_args
|
||||
assert token_call_args[1]["account"] is None
|
||||
assert token_call_args[1]["email"] == test_email
|
||||
assert token_call_args[1]["token_type"] == "email_code_login"
|
||||
assert "code" in token_call_args[1]["additional_data"]
|
||||
|
||||
# Verify mail task parameters
|
||||
mail_call_args = mock_external_service_dependencies["mail_task"].delay.call_args
|
||||
assert mail_call_args[1]["language"] == "zh-Hans"
|
||||
assert mail_call_args[1]["to"] == test_email
|
||||
assert "code" in mail_call_args[1]
|
||||
|
||||
def test_send_email_code_login_email_no_email_provided(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test sending email code login email without providing email.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling when no email is provided
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: No email provided
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WebAppAuthService.send_email_code_login_email()
|
||||
|
||||
assert "Email must be provided." in str(exc_info.value)
|
||||
|
||||
def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of email code login data.
|
||||
|
||||
This test verifies:
|
||||
- Proper token data retrieval
|
||||
- Correct data format
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Setup mock return
|
||||
expected_data = {"code": "123456", "email": "test@example.com"}
|
||||
mock_external_service_dependencies["token_manager"].get_token_data.return_value = expected_data
|
||||
|
||||
# Act: Execute data retrieval
|
||||
result = WebAppAuthService.get_email_code_login_data("mock_token")
|
||||
|
||||
# Assert: Verify successful retrieval
|
||||
assert result is not None
|
||||
assert result == expected_data
|
||||
assert result["code"] == "123456"
|
||||
assert result["email"] == "test@example.com"
|
||||
|
||||
# Verify mock service was called correctly
|
||||
mock_external_service_dependencies["token_manager"].get_token_data.assert_called_once_with(
|
||||
"mock_token", "email_code_login"
|
||||
)
|
||||
|
||||
def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test email code login data retrieval when no data exists.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling when no token data exists
|
||||
- Correct return value (None)
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Setup mock return for no data
|
||||
mock_external_service_dependencies["token_manager"].get_token_data.return_value = None
|
||||
|
||||
# Act: Execute data retrieval
|
||||
result = WebAppAuthService.get_email_code_login_data("invalid_token")
|
||||
|
||||
# Assert: Verify proper handling
|
||||
assert result is None
|
||||
|
||||
# Verify mock service was called correctly
|
||||
mock_external_service_dependencies["token_manager"].get_token_data.assert_called_once_with(
|
||||
"invalid_token", "email_code_login"
|
||||
)
|
||||
|
||||
def test_revoke_email_code_login_token_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful revocation of email code login token.
|
||||
|
||||
This test verifies:
|
||||
- Proper token revocation
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Setup mock
|
||||
|
||||
# Act: Execute token revocation
|
||||
WebAppAuthService.revoke_email_code_login_token("mock_token")
|
||||
|
||||
# Assert: Verify mock service was called correctly
|
||||
mock_external_service_dependencies["token_manager"].revoke_token.assert_called_once_with(
|
||||
"mock_token", "email_code_login"
|
||||
)
|
||||
|
||||
def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful end user creation.
|
||||
|
||||
This test verifies:
|
||||
- Proper end user creation with valid app code
|
||||
- Correct database state after creation
|
||||
- Proper relationship establishment
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
app, site = self._create_test_app_and_site(
|
||||
db_session_with_containers, mock_external_service_dependencies, tenant
|
||||
)
|
||||
|
||||
# Act: Execute end user creation
|
||||
result = WebAppAuthService.create_end_user(site.code, "test@example.com")
|
||||
|
||||
# Assert: Verify successful creation
|
||||
assert result is not None
|
||||
assert result.tenant_id == app.tenant_id
|
||||
assert result.app_id == app.id
|
||||
assert result.type == "browser"
|
||||
assert result.is_anonymous is False
|
||||
assert result.session_id == "test@example.com"
|
||||
assert result.name == "enterpriseuser"
|
||||
assert result.external_user_id == "enterpriseuser"
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.created_at is not None
|
||||
assert result.updated_at is not None
|
||||
|
||||
def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test end user creation with non-existent site code.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for non-existent sites
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: Use non-existent site code
|
||||
fake = Faker()
|
||||
non_existent_code = fake.unique.lexify(text="??????")
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
WebAppAuthService.create_end_user(non_existent_code, "test@example.com")
|
||||
|
||||
assert "Site not found." in str(exc_info.value)
|
||||
|
||||
def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test end user creation when app is not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling when app is missing
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: Create site without app
|
||||
fake = Faker()
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
site = Site(
|
||||
app_id="00000000-0000-0000-0000-000000000000",
|
||||
title=fake.company(),
|
||||
code=fake.unique.lexify(text="??????"),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
default_language="en-US",
|
||||
status="normal",
|
||||
customize_token_strategy="not_allow",
|
||||
)
|
||||
db.session.add(site)
|
||||
db.session.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
WebAppAuthService.create_end_user(site.code, "test@example.com")
|
||||
|
||||
assert "App not found." in str(exc_info.value)
|
||||
|
||||
def test_is_app_require_permission_check_with_access_mode_private(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test permission check requirement for private access mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper permission check requirement for private mode
|
||||
- Correct return value
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Setup test with private access mode
|
||||
|
||||
# Act: Execute permission check requirement test
|
||||
result = WebAppAuthService.is_app_require_permission_check(access_mode="private")
|
||||
|
||||
# Assert: Verify correct result
|
||||
assert result is True
|
||||
|
||||
def test_is_app_require_permission_check_with_access_mode_public(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test permission check requirement for public access mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper permission check requirement for public mode
|
||||
- Correct return value
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Setup test with public access mode
|
||||
|
||||
# Act: Execute permission check requirement test
|
||||
result = WebAppAuthService.is_app_require_permission_check(access_mode="public")
|
||||
|
||||
# Assert: Verify correct result
|
||||
assert result is False
|
||||
|
||||
def test_is_app_require_permission_check_with_app_code(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test permission check requirement using app code.
|
||||
|
||||
This test verifies:
|
||||
- Proper permission check requirement using app code
|
||||
- Correct return value
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Setup mock for app service
|
||||
mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id"
|
||||
|
||||
# Act: Execute permission check requirement test
|
||||
result = WebAppAuthService.is_app_require_permission_check(app_code="mock_app_code")
|
||||
|
||||
# Assert: Verify correct result
|
||||
assert result is True
|
||||
|
||||
# Verify mock service was called correctly
|
||||
mock_external_service_dependencies["app_service"].get_app_id_by_code.assert_called_once_with("mock_app_code")
|
||||
mock_external_service_dependencies[
|
||||
"enterprise_service"
|
||||
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id")
|
||||
|
||||
def test_is_app_require_permission_check_no_parameters(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test permission check requirement with no parameters.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling when no parameters provided
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: No parameters provided
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WebAppAuthService.is_app_require_permission_check()
|
||||
|
||||
assert "Either app_code or app_id must be provided." in str(exc_info.value)
|
||||
|
||||
def test_get_app_auth_type_with_access_mode_public(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test app authentication type for public access mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper authentication type determination for public mode
|
||||
- Correct return value
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Setup test with public access mode
|
||||
|
||||
# Act: Execute authentication type determination
|
||||
result = WebAppAuthService.get_app_auth_type(access_mode="public")
|
||||
|
||||
# Assert: Verify correct result
|
||||
assert result == WebAppAuthType.PUBLIC
|
||||
|
||||
def test_get_app_auth_type_with_access_mode_private(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test app authentication type for private access mode.
|
||||
|
||||
This test verifies:
|
||||
- Proper authentication type determination for private mode
|
||||
- Correct return value
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Setup test with private access mode
|
||||
|
||||
# Act: Execute authentication type determination
|
||||
result = WebAppAuthService.get_app_auth_type(access_mode="private")
|
||||
|
||||
# Assert: Verify correct result
|
||||
assert result == WebAppAuthType.INTERNAL
|
||||
|
||||
def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app authentication type using app code.
|
||||
|
||||
This test verifies:
|
||||
- Proper authentication type determination using app code
|
||||
- Correct return value
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Setup mock for enterprise service
|
||||
mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id"
|
||||
setting = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})()
|
||||
mock_external_service_dependencies[
|
||||
"enterprise_service"
|
||||
].WebAppAuth.get_app_access_mode_by_id.return_value = setting
|
||||
|
||||
# Act: Execute authentication type determination
|
||||
result: WebAppAuthType = WebAppAuthService.get_app_auth_type(app_code="mock_app_code")
|
||||
|
||||
# Assert: Verify correct result
|
||||
assert result == WebAppAuthType.EXTERNAL
|
||||
|
||||
# Verify mock service was called correctly
|
||||
mock_external_service_dependencies[
|
||||
"enterprise_service"
|
||||
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id")
|
||||
|
||||
def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app authentication type with no parameters.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling when no parameters provided
|
||||
- Correct exception type and message
|
||||
"""
|
||||
# Arrange: No parameters provided
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WebAppAuthService.get_app_auth_type()
|
||||
|
||||
assert "Either app_code or access_mode must be provided." in str(exc_info.value)
|
||||
@@ -0,0 +1,571 @@
|
||||
import json
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from flask import Flask
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from models.enums import AppTriggerStatus, AppTriggerType
|
||||
from models.model import App
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
from models.workflow import Workflow
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class TestWebhookService:
|
||||
"""Integration tests for WebhookService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_dependencies(self):
|
||||
"""Mock external service dependencies."""
|
||||
with (
|
||||
patch("services.trigger.webhook_service.AsyncWorkflowService") as mock_async_service,
|
||||
patch("services.trigger.webhook_service.ToolFileManager") as mock_tool_file_manager,
|
||||
patch("services.trigger.webhook_service.file_factory") as mock_file_factory,
|
||||
patch("services.account_service.FeatureService") as mock_feature_service,
|
||||
):
|
||||
# Mock ToolFileManager
|
||||
mock_tool_file_instance = MagicMock()
|
||||
mock_tool_file_manager.return_value = mock_tool_file_instance
|
||||
|
||||
# Mock file creation
|
||||
mock_tool_file = MagicMock()
|
||||
mock_tool_file.id = "test_file_id"
|
||||
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
|
||||
|
||||
# Mock file factory
|
||||
mock_file_obj = MagicMock()
|
||||
mock_file_factory.build_from_mapping.return_value = mock_file_obj
|
||||
|
||||
# Mock feature service
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
|
||||
|
||||
yield {
|
||||
"async_service": mock_async_service,
|
||||
"tool_file_manager": mock_tool_file_manager,
|
||||
"file_factory": mock_file_factory,
|
||||
"tool_file": mock_tool_file,
|
||||
"file_obj": mock_file_obj,
|
||||
"feature_service": mock_feature_service,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def test_data(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Create test data for webhook service tests."""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
assert tenant is not None
|
||||
|
||||
# Create app
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(),
|
||||
mode="workflow",
|
||||
icon="",
|
||||
icon_background="",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create workflow
|
||||
workflow_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "webhook_node",
|
||||
"type": "webhook",
|
||||
"data": {
|
||||
"title": "Test Webhook",
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
"headers": [
|
||||
{"name": "Authorization", "required": True},
|
||||
{"name": "Content-Type", "required": False},
|
||||
],
|
||||
"params": [{"name": "version", "required": True}, {"name": "format", "required": False}],
|
||||
"body": [
|
||||
{"name": "message", "type": "string", "required": True},
|
||||
{"name": "count", "type": "number", "required": False},
|
||||
{"name": "upload", "type": "file", "required": False},
|
||||
],
|
||||
"status_code": 200,
|
||||
"response_body": '{"status": "success"}',
|
||||
"timeout": 30,
|
||||
},
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
workflow = Workflow(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
graph=json.dumps(workflow_data),
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
version="1.0",
|
||||
)
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create webhook trigger
|
||||
webhook_id = fake.uuid4()[:16]
|
||||
webhook_trigger = WorkflowWebhookTrigger(
|
||||
app_id=app.id,
|
||||
node_id="webhook_node",
|
||||
tenant_id=tenant.id,
|
||||
webhook_id=str(webhook_id),
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(webhook_trigger)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create app trigger (required for non-debug mode)
|
||||
app_trigger = AppTrigger(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
node_id="webhook_node",
|
||||
trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
|
||||
provider_name="webhook",
|
||||
title="Test Webhook",
|
||||
status=AppTriggerStatus.ENABLED,
|
||||
)
|
||||
db_session_with_containers.add(app_trigger)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return {
|
||||
"tenant": tenant,
|
||||
"account": account,
|
||||
"app": app,
|
||||
"workflow": workflow,
|
||||
"webhook_trigger": webhook_trigger,
|
||||
"webhook_id": webhook_id,
|
||||
"app_trigger": app_trigger,
|
||||
}
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers):
|
||||
"""Test successful retrieval of webhook trigger and workflow."""
|
||||
webhook_id = test_data["webhook_id"]
|
||||
|
||||
with flask_app_with_containers.app_context():
|
||||
webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)
|
||||
|
||||
assert webhook_trigger is not None
|
||||
assert webhook_trigger.webhook_id == webhook_id
|
||||
assert workflow is not None
|
||||
assert workflow.app_id == test_data["app"].id
|
||||
assert node_config is not None
|
||||
assert node_config["id"] == "webhook_node"
|
||||
assert node_config["data"]["title"] == "Test Webhook"
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers):
|
||||
"""Test webhook trigger not found scenario."""
|
||||
with flask_app_with_containers.app_context():
|
||||
with pytest.raises(ValueError, match="Webhook not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow("nonexistent_webhook")
|
||||
|
||||
def test_extract_webhook_data_json(self):
|
||||
"""Test webhook data extraction from JSON request."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json", "Authorization": "Bearer token"},
|
||||
query_string="version=1&format=json",
|
||||
json={"message": "hello", "count": 42},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
assert webhook_data["method"] == "POST"
|
||||
assert webhook_data["headers"]["Authorization"] == "Bearer token"
|
||||
assert webhook_data["query_params"]["version"] == "1"
|
||||
assert webhook_data["query_params"]["format"] == "json"
|
||||
assert webhook_data["body"]["message"] == "hello"
|
||||
assert webhook_data["body"]["count"] == 42
|
||||
assert webhook_data["files"] == {}
|
||||
|
||||
def test_extract_webhook_data_form_urlencoded(self):
|
||||
"""Test webhook data extraction from form URL encoded request."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={"username": "test", "password": "secret"},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
assert webhook_data["method"] == "POST"
|
||||
assert webhook_data["body"]["username"] == "test"
|
||||
assert webhook_data["body"]["password"] == "secret"
|
||||
|
||||
def test_extract_webhook_data_multipart_with_files(self, mock_external_dependencies):
|
||||
"""Test webhook data extraction from multipart form with files."""
|
||||
app = Flask(__name__)
|
||||
|
||||
# Create a mock file
|
||||
file_content = b"test file content"
|
||||
file_storage = FileStorage(stream=BytesIO(file_content), filename="test.txt", content_type="text/plain")
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "multipart/form-data"},
|
||||
data={"message": "test", "upload": file_storage},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
assert webhook_data["method"] == "POST"
|
||||
assert webhook_data["body"]["message"] == "test"
|
||||
assert "upload" in webhook_data["files"]
|
||||
|
||||
# Verify file processing was called
|
||||
mock_external_dependencies["tool_file_manager"].assert_called_once()
|
||||
mock_external_dependencies["file_factory"].build_from_mapping.assert_called_once()
|
||||
|
||||
def test_extract_webhook_data_raw_text(self):
|
||||
"""Test webhook data extraction from raw text request."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook", method="POST", headers={"Content-Type": "text/plain"}, data="raw text content"
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
assert webhook_data["method"] == "POST"
|
||||
assert webhook_data["body"]["raw"] == "raw text content"
|
||||
|
||||
def test_extract_and_validate_webhook_request_success(self):
|
||||
"""Test successful webhook request validation and type conversion."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json", "Authorization": "Bearer token"},
|
||||
query_string="version=1",
|
||||
json={"message": "hello"},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
node_config = {
|
||||
"data": {
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
"headers": [
|
||||
{"name": "Authorization", "required": True},
|
||||
{"name": "Content-Type", "required": False},
|
||||
],
|
||||
"params": [{"name": "version", "required": True}],
|
||||
"body": [{"name": "message", "type": "string", "required": True}],
|
||||
}
|
||||
}
|
||||
|
||||
result = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
assert result["headers"]["Authorization"] == "Bearer token"
|
||||
assert result["query_params"]["version"] == "1"
|
||||
assert result["body"]["message"] == "hello"
|
||||
|
||||
def test_extract_and_validate_webhook_request_method_mismatch(self):
|
||||
"""Test webhook validation with HTTP method mismatch."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="GET",
|
||||
headers={"Content-Type": "application/json"},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
node_config = {"data": {"method": "post", "content_type": "application/json"}}
|
||||
|
||||
with pytest.raises(ValueError, match="HTTP method mismatch"):
|
||||
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
def test_extract_and_validate_webhook_request_missing_required_header(self):
|
||||
"""Test webhook validation with missing required header."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
node_config = {
|
||||
"data": {
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
"headers": [{"name": "Authorization", "required": True}],
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Required header missing: Authorization"):
|
||||
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
def test_extract_and_validate_webhook_request_case_insensitive_headers(self):
|
||||
"""Test webhook validation with case-insensitive header matching."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json", "authorization": "Bearer token"},
|
||||
json={"message": "hello"},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
node_config = {
|
||||
"data": {
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
"headers": [{"name": "Authorization", "required": True}],
|
||||
"body": [{"name": "message", "type": "string", "required": True}],
|
||||
}
|
||||
}
|
||||
|
||||
result = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
assert result["headers"].get("Authorization") == "Bearer token"
|
||||
|
||||
def test_extract_and_validate_webhook_request_missing_required_param(self):
|
||||
"""Test webhook validation with missing required query parameter."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={"message": "hello"},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
node_config = {
|
||||
"data": {
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
"params": [{"name": "version", "required": True}],
|
||||
"body": [{"name": "message", "type": "string", "required": True}],
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Required parameter missing: version"):
|
||||
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
def test_extract_and_validate_webhook_request_missing_required_body_param(self):
|
||||
"""Test webhook validation with missing required body parameter."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
node_config = {
|
||||
"data": {
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
"body": [{"name": "message", "type": "string", "required": True}],
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Required body parameter missing: message"):
|
||||
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
def test_extract_and_validate_webhook_request_missing_required_file(self):
|
||||
"""Test webhook validation when required file is missing from multipart request."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
data={"note": "test"},
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "tenant"
|
||||
webhook_trigger.created_by = "user"
|
||||
node_config = {
|
||||
"data": {
|
||||
"method": "post",
|
||||
"content_type": "multipart/form-data",
|
||||
"body": [{"name": "upload", "type": "file", "required": True}],
|
||||
}
|
||||
}
|
||||
|
||||
result = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
assert result["files"] == {}
|
||||
|
||||
def test_trigger_workflow_execution_success(self, test_data, mock_external_dependencies, flask_app_with_containers):
|
||||
"""Test successful workflow execution trigger."""
|
||||
webhook_data = {
|
||||
"method": "POST",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
"query_params": {"version": "1"},
|
||||
"body": {"message": "hello"},
|
||||
"files": {},
|
||||
}
|
||||
|
||||
with flask_app_with_containers.app_context():
|
||||
# Mock tenant owner lookup to return the test account
|
||||
with patch("services.trigger.webhook_service.select") as mock_select:
|
||||
mock_query = MagicMock()
|
||||
mock_select.return_value.join.return_value.where.return_value = mock_query
|
||||
|
||||
# Mock the session to return our test account
|
||||
with patch("services.trigger.webhook_service.Session") as mock_session:
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.scalar.return_value = test_data["account"]
|
||||
|
||||
# Should not raise any exceptions
|
||||
WebhookService.trigger_workflow_execution(
|
||||
test_data["webhook_trigger"], webhook_data, test_data["workflow"]
|
||||
)
|
||||
|
||||
# Verify AsyncWorkflowService was called
|
||||
mock_external_dependencies["async_service"].trigger_workflow_async.assert_called_once()
|
||||
|
||||
def test_trigger_workflow_execution_end_user_service_failure(
|
||||
self, test_data, mock_external_dependencies, flask_app_with_containers
|
||||
):
|
||||
"""Test workflow execution trigger when EndUserService fails."""
|
||||
webhook_data = {"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}}
|
||||
|
||||
with flask_app_with_containers.app_context():
|
||||
# Mock EndUserService to raise an exception
|
||||
with patch(
|
||||
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type"
|
||||
) as mock_end_user:
|
||||
mock_end_user.side_effect = ValueError("Failed to create end user")
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to create end user"):
|
||||
WebhookService.trigger_workflow_execution(
|
||||
test_data["webhook_trigger"], webhook_data, test_data["workflow"]
|
||||
)
|
||||
|
||||
def test_generate_webhook_response_default(self):
|
||||
"""Test webhook response generation with default values."""
|
||||
node_config = {"data": {}}
|
||||
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status_code == 200
|
||||
assert response_data["status"] == "success"
|
||||
assert "Webhook processed successfully" in response_data["message"]
|
||||
|
||||
def test_generate_webhook_response_custom_json(self):
|
||||
"""Test webhook response generation with custom JSON response."""
|
||||
node_config = {"data": {"status_code": 201, "response_body": '{"result": "created", "id": 123}'}}
|
||||
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status_code == 201
|
||||
assert response_data["result"] == "created"
|
||||
assert response_data["id"] == 123
|
||||
|
||||
def test_generate_webhook_response_custom_text(self):
|
||||
"""Test webhook response generation with custom text response."""
|
||||
node_config = {"data": {"status_code": 202, "response_body": "Request accepted for processing"}}
|
||||
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status_code == 202
|
||||
assert response_data["message"] == "Request accepted for processing"
|
||||
|
||||
def test_generate_webhook_response_invalid_json(self):
|
||||
"""Test webhook response generation with invalid JSON response."""
|
||||
node_config = {"data": {"status_code": 400, "response_body": '{"invalid": json}'}}
|
||||
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status_code == 400
|
||||
assert response_data["message"] == '{"invalid": json}'
|
||||
|
||||
def test_process_file_uploads_success(self, mock_external_dependencies):
|
||||
"""Test successful file upload processing."""
|
||||
# Create mock files
|
||||
files = {
|
||||
"file1": MagicMock(filename="test1.txt", content_type="text/plain"),
|
||||
"file2": MagicMock(filename="test2.jpg", content_type="image/jpeg"),
|
||||
}
|
||||
|
||||
# Mock file reads
|
||||
files["file1"].read.return_value = b"content1"
|
||||
files["file2"].read.return_value = b"content2"
|
||||
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
result = WebhookService._process_file_uploads(files, webhook_trigger)
|
||||
|
||||
assert len(result) == 2
|
||||
assert "file1" in result
|
||||
assert "file2" in result
|
||||
|
||||
# Verify file processing was called for each file
|
||||
assert mock_external_dependencies["tool_file_manager"].call_count == 2
|
||||
assert mock_external_dependencies["file_factory"].build_from_mapping.call_count == 2
|
||||
|
||||
def test_process_file_uploads_with_errors(self, mock_external_dependencies):
|
||||
"""Test file upload processing with errors."""
|
||||
# Create mock files, one will fail
|
||||
files = {
|
||||
"good_file": MagicMock(filename="test.txt", content_type="text/plain"),
|
||||
"bad_file": MagicMock(filename="test.bad", content_type="text/plain"),
|
||||
}
|
||||
|
||||
files["good_file"].read.return_value = b"content"
|
||||
files["bad_file"].read.side_effect = Exception("Read error")
|
||||
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
result = WebhookService._process_file_uploads(files, webhook_trigger)
|
||||
|
||||
# Should process the good file and skip the bad one
|
||||
assert len(result) == 1
|
||||
assert "good_file" in result
|
||||
assert "bad_file" not in result
|
||||
|
||||
def test_process_file_uploads_empty_filename(self, mock_external_dependencies):
|
||||
"""Test file upload processing with empty filename."""
|
||||
files = {
|
||||
"no_filename": MagicMock(filename="", content_type="text/plain"),
|
||||
"none_filename": MagicMock(filename=None, content_type="text/plain"),
|
||||
}
|
||||
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
result = WebhookService._process_file_uploads(files, webhook_trigger)
|
||||
|
||||
# Should skip files without filenames
|
||||
assert len(result) == 0
|
||||
mock_external_dependencies["tool_file_manager"].assert_not_called()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,841 @@
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.variables.segments import StringSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from models import App, Workflow
|
||||
from models.enums import DraftVariableType
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import (
|
||||
UpdateNotSupportedError,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
|
||||
|
||||
def _get_random_variable_name(fake: Faker):
|
||||
return "".join(fake.random_letters(length=10))
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableService:
|
||||
"""
|
||||
Comprehensive integration tests for WorkflowDraftVariableService using testcontainers.
|
||||
|
||||
This test class covers all major functionality of the WorkflowDraftVariableService:
|
||||
- CRUD operations for workflow draft variables (Create, Read, Update, Delete)
|
||||
- Variable listing and filtering by type (conversation, system, node)
|
||||
- Variable updates and resets with proper validation
|
||||
- Variable deletion operations at different scopes
|
||||
- Special functionality like prefill and conversation ID retrieval
|
||||
- Error handling for various edge cases and invalid operations
|
||||
|
||||
All tests use the testcontainers infrastructure to ensure proper database isolation
|
||||
and realistic testing environment with actual database interactions.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""
|
||||
Mock setup for external service dependencies.
|
||||
|
||||
WorkflowDraftVariableService doesn't have external dependencies that need mocking,
|
||||
so this fixture returns an empty dictionary to maintain consistency with other test classes.
|
||||
This ensures the test structure remains consistent across different service test files.
|
||||
"""
|
||||
# WorkflowDraftVariableService doesn't have external dependencies that need mocking
|
||||
return {}
|
||||
|
||||
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None):
|
||||
"""
|
||||
Helper method to create a test app with realistic data for testing.
|
||||
|
||||
This method creates a complete App instance with all required fields populated
|
||||
using Faker for generating realistic test data. The app is configured for
|
||||
workflow mode to support workflow draft variable testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies (unused in this service)
|
||||
fake: Faker instance for generating test data, creates new instance if not provided
|
||||
|
||||
Returns:
|
||||
App: Created test app instance with all required fields populated
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
app = App()
|
||||
app.id = fake.uuid4()
|
||||
app.tenant_id = fake.uuid4()
|
||||
app.name = fake.company()
|
||||
app.description = fake.text()
|
||||
app.mode = "workflow"
|
||||
app.icon_type = "emoji"
|
||||
app.icon = "🤖"
|
||||
app.icon_background = "#FFEAD5"
|
||||
app.enable_site = True
|
||||
app.enable_api = True
|
||||
app.created_by = fake.uuid4()
|
||||
app.updated_by = app.created_by
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
return app
|
||||
|
||||
def _create_test_workflow(self, db_session_with_containers, app, fake=None):
|
||||
"""
|
||||
Helper method to create a test workflow associated with an app.
|
||||
|
||||
This method creates a Workflow instance using the proper factory method
|
||||
to ensure all required fields are set correctly. The workflow is configured
|
||||
as a draft version with basic graph structure for testing workflow variables.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: The app to associate the workflow with
|
||||
fake: Faker instance for generating test data, creates new instance if not provided
|
||||
|
||||
Returns:
|
||||
Workflow: Created test workflow instance with proper configuration
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
workflow = Workflow.new(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features="{}",
|
||||
created_by=app.created_by,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
return workflow
|
||||
|
||||
def _create_test_variable(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
app_id,
|
||||
node_id,
|
||||
name,
|
||||
value,
|
||||
variable_type: DraftVariableType = DraftVariableType.CONVERSATION,
|
||||
fake=None,
|
||||
):
|
||||
"""
|
||||
Helper method to create a test workflow draft variable with proper configuration.
|
||||
|
||||
This method creates different types of variables (conversation, system, node) using
|
||||
the appropriate factory methods to ensure proper initialization. Each variable type
|
||||
has specific requirements and this method handles the creation logic for all types.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app_id: ID of the app to associate the variable with
|
||||
node_id: ID of the node (or special constants like CONVERSATION_VARIABLE_NODE_ID)
|
||||
name: Name of the variable for identification
|
||||
value: StringSegment value for the variable content
|
||||
variable_type: Type of variable ("conversation", "system", "node") determining creation method
|
||||
fake: Faker instance for generating test data, creates new instance if not provided
|
||||
|
||||
Returns:
|
||||
WorkflowDraftVariable: Created test variable instance with proper type configuration
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
if variable_type == "conversation":
|
||||
# Create conversation variable using the appropriate factory method
|
||||
variable = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=app_id,
|
||||
name=name,
|
||||
value=value,
|
||||
description=fake.text(max_nb_chars=20),
|
||||
)
|
||||
elif variable_type == "system":
|
||||
# Create system variable with editable flag and execution context
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=app_id,
|
||||
name=name,
|
||||
value=value,
|
||||
node_execution_id=fake.uuid4(),
|
||||
editable=True,
|
||||
)
|
||||
else: # node variable
|
||||
# Create node variable with visibility and editability settings
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
name=name,
|
||||
value=value,
|
||||
node_execution_id=fake.uuid4(),
|
||||
visible=True,
|
||||
editable=True,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(variable)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting a single variable by ID successfully.
|
||||
|
||||
This test verifies that the service can retrieve a specific variable
|
||||
by its ID and that the returned variable contains the correct data.
|
||||
It ensures the basic CRUD read operation works correctly for workflow draft variables.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
test_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variable = service.get_variable(variable.id)
|
||||
assert retrieved_variable is not None
|
||||
assert retrieved_variable.id == variable.id
|
||||
assert retrieved_variable.name == "test_var"
|
||||
assert retrieved_variable.app_id == app.id
|
||||
assert retrieved_variable.get_value().value == test_value.value
|
||||
|
||||
def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting a variable that doesn't exist.
|
||||
|
||||
This test verifies that the service returns None when trying to
|
||||
retrieve a variable with a non-existent ID. This ensures proper
|
||||
handling of missing data scenarios.
|
||||
"""
|
||||
fake = Faker()
|
||||
non_existent_id = fake.uuid4()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variable = service.get_variable(non_existent_id)
|
||||
assert retrieved_variable is None
|
||||
|
||||
def test_get_draft_variables_by_selectors_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting variables by selectors successfully.
|
||||
|
||||
This test verifies that the service can retrieve multiple variables
|
||||
using selector pairs (node_id, variable_name) and returns the correct
|
||||
variables for each selector. This is useful for bulk variable retrieval
|
||||
operations in workflow execution contexts.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
var1_value = StringSegment(value=fake.word())
|
||||
var2_value = StringSegment(value=fake.word())
|
||||
var3_value = StringSegment(value=fake.word())
|
||||
var1 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var1", var1_value, fake=fake
|
||||
)
|
||||
var2 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var2", var2_value, fake=fake
|
||||
)
|
||||
var3 = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
"test_node_1",
|
||||
"var3",
|
||||
var3_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
selectors = [
|
||||
[CONVERSATION_VARIABLE_NODE_ID, "var1"],
|
||||
[CONVERSATION_VARIABLE_NODE_ID, "var2"],
|
||||
["test_node_1", "var3"],
|
||||
]
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors)
|
||||
assert len(retrieved_variables) == 3
|
||||
var_names = [var.name for var in retrieved_variables]
|
||||
assert "var1" in var_names
|
||||
assert "var2" in var_names
|
||||
assert "var3" in var_names
|
||||
for var in retrieved_variables:
|
||||
if var.name == "var1":
|
||||
assert var.get_value().value == var1_value.value
|
||||
elif var.name == "var2":
|
||||
assert var.get_value().value == var2_value.value
|
||||
elif var.name == "var3":
|
||||
assert var.get_value().value == var3_value.value
|
||||
|
||||
def test_list_variables_without_values_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test listing variables without values successfully with pagination.
|
||||
|
||||
This test verifies that the service can list variables with pagination
|
||||
and that the returned variables don't include their values (for performance).
|
||||
This is important for scenarios where only variable metadata is needed
|
||||
without loading the actual content.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
for i in range(5):
|
||||
test_value = StringSegment(value=fake.numerify("value######"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
_get_random_variable_name(fake),
|
||||
test_value,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_variables_without_values(app.id, page=1, limit=3)
|
||||
assert result.total == 5
|
||||
assert len(result.variables) == 3
|
||||
assert result.variables[0].created_at >= result.variables[1].created_at
|
||||
assert result.variables[1].created_at >= result.variables[2].created_at
|
||||
for var in result.variables:
|
||||
assert var.name is not None
|
||||
assert var.app_id == app.id
|
||||
|
||||
def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test listing variables for a specific node successfully.
|
||||
|
||||
This test verifies that the service can filter and return only
|
||||
variables associated with a specific node ID. This is crucial for
|
||||
workflow execution where variables need to be scoped to specific nodes.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
node_id = fake.word()
|
||||
var1_value = StringSegment(value=fake.word())
|
||||
var2_value = StringSegment(value=fake.word())
|
||||
var3_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
node_id,
|
||||
"var1",
|
||||
var1_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
node_id,
|
||||
"var2",
|
||||
var3_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
"other_node",
|
||||
"var3",
|
||||
var2_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_node_variables(app.id, node_id)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == node_id
|
||||
assert var.app_id == app.id
|
||||
var_names = [var.name for var in result.variables]
|
||||
assert "var1" in var_names
|
||||
assert "var2" in var_names
|
||||
assert "var3" not in var_names
|
||||
|
||||
def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test listing conversation variables successfully.
|
||||
|
||||
This test verifies that the service can filter and return only
|
||||
conversation variables, excluding system and node variables.
|
||||
Conversation variables are user-facing variables that can be
|
||||
modified during conversation flows.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
conv_var1_value = StringSegment(value=fake.word())
|
||||
conv_var2_value = StringSegment(value=fake.word())
|
||||
conv_var1 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var1", conv_var1_value, fake=fake
|
||||
)
|
||||
conv_var2 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var2", conv_var2_value, fake=fake
|
||||
)
|
||||
sys_var_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
"sys_var",
|
||||
sys_var_value,
|
||||
variable_type=DraftVariableType.SYS,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_conversation_variables(app.id)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
assert var.app_id == app.id
|
||||
assert var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
var_names = [var.name for var in result.variables]
|
||||
assert "conv_var1" in var_names
|
||||
assert "conv_var2" in var_names
|
||||
assert "sys_var" not in var_names
|
||||
|
||||
def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test updating a variable's name and value successfully.
|
||||
|
||||
This test verifies that the service can update both the name and value
|
||||
of an editable variable and that the changes are persisted correctly.
|
||||
It also checks that the last_edited_at timestamp is updated appropriately.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
original_value = StringSegment(value=fake.word())
|
||||
new_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
"original_name",
|
||||
original_value,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
updated_variable = service.update_variable(variable, name="new_name", value=new_value)
|
||||
assert updated_variable.name == "new_name"
|
||||
assert updated_variable.get_value().value == new_value.value
|
||||
assert updated_variable.last_edited_at is not None
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(variable)
|
||||
assert variable.name == "new_name"
|
||||
assert variable.get_value().value == new_value.value
|
||||
assert variable.last_edited_at is not None
|
||||
|
||||
def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test that updating a non-editable variable raises an exception.
|
||||
|
||||
This test verifies that the service properly prevents updates to
|
||||
variables that are not marked as editable. This is important for
|
||||
maintaining data integrity and preventing unauthorized modifications
|
||||
to system-controlled variables.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
original_value = StringSegment(value=fake.word())
|
||||
new_value = StringSegment(value=fake.word())
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=app.id,
|
||||
name=fake.word(), # This is typically not editable
|
||||
value=original_value,
|
||||
node_execution_id=fake.uuid4(),
|
||||
editable=False, # Set as non-editable
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(variable)
|
||||
db.session.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
with pytest.raises(UpdateNotSupportedError) as exc_info:
|
||||
service.update_variable(variable, name="new_name", value=new_value)
|
||||
assert "variable not support updating" in str(exc_info.value)
|
||||
assert variable.id in str(exc_info.value)
|
||||
|
||||
def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test resetting conversation variable successfully.
|
||||
|
||||
This test verifies that the service can reset a conversation variable
|
||||
to its default value and clear the last_edited_at timestamp.
|
||||
This functionality is useful for reverting user modifications
|
||||
back to the original workflow configuration.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake)
|
||||
from core.variables.variables import StringVariable
|
||||
|
||||
conv_var = StringVariable(
|
||||
id=fake.uuid4(),
|
||||
name="test_conv_var",
|
||||
value="default_value",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"],
|
||||
)
|
||||
workflow.conversation_variables = [conv_var]
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
modified_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
"test_conv_var",
|
||||
modified_value,
|
||||
fake=fake,
|
||||
)
|
||||
variable.last_edited_at = fake.date_time()
|
||||
db.session.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
reset_variable = service.reset_variable(workflow, variable)
|
||||
assert reset_variable is not None
|
||||
assert reset_variable.get_value().value == "default_value"
|
||||
assert reset_variable.last_edited_at is None
|
||||
db.session.refresh(variable)
|
||||
assert variable.get_value().value == "default_value"
|
||||
assert variable.last_edited_at is None
|
||||
|
||||
def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting a single variable successfully.
|
||||
|
||||
This test verifies that the service can delete a specific variable
|
||||
and that it's properly removed from the database. It ensures that
|
||||
the deletion operation is atomic and complete.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
test_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_variable(variable)
|
||||
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
|
||||
|
||||
def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting all variables for a workflow successfully.
|
||||
|
||||
This test verifies that the service can delete all variables
|
||||
associated with a specific app/workflow. This is useful for
|
||||
cleanup operations when workflows are deleted or reset.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
for i in range(3):
|
||||
test_value = StringSegment(value=fake.numerify("value######"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
_get_random_variable_name(fake),
|
||||
test_value,
|
||||
fake=fake,
|
||||
)
|
||||
other_app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
other_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
other_app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
_get_random_variable_name(fake),
|
||||
other_value,
|
||||
fake=fake,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
assert len(app_variables) == 3
|
||||
assert len(other_app_variables) == 1
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_workflow_variables(app.id)
|
||||
app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
assert len(app_variables_after) == 0
|
||||
assert len(other_app_variables_after) == 1
|
||||
|
||||
def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting all variables for a specific node successfully.
|
||||
|
||||
This test verifies that the service can delete all variables
|
||||
associated with a specific node while preserving variables
|
||||
for other nodes and conversation variables. This is important
|
||||
for node-specific cleanup operations in workflow management.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
node_id = fake.word()
|
||||
for i in range(2):
|
||||
test_value = StringSegment(value=fake.numerify("node_value######"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
node_id,
|
||||
_get_random_variable_name(fake),
|
||||
test_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
other_node_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
"other_node",
|
||||
_get_random_variable_name(fake),
|
||||
other_node_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
conv_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
_get_random_variable_name(fake),
|
||||
conv_value,
|
||||
fake=fake,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
other_node_variables = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
|
||||
)
|
||||
conv_variables = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
assert len(target_node_variables) == 2
|
||||
assert len(other_node_variables) == 1
|
||||
assert len(conv_variables) == 1
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_node_variables(app.id, node_id)
|
||||
target_node_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
)
|
||||
other_node_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
|
||||
)
|
||||
conv_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
assert len(target_node_variables_after) == 0
|
||||
assert len(other_node_variables_after) == 1
|
||||
assert len(conv_variables_after) == 1
|
||||
|
||||
def test_prefill_conversation_variable_default_values_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test prefill conversation variable default values successfully.
|
||||
|
||||
This test verifies that the service can automatically create
|
||||
conversation variables with default values based on the workflow
|
||||
configuration when none exist. This is important for initializing
|
||||
workflow variables with proper defaults from the workflow definition.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake)
|
||||
from core.variables.variables import StringVariable
|
||||
|
||||
conv_var1 = StringVariable(
|
||||
id=fake.uuid4(),
|
||||
name="conv_var1",
|
||||
value="default_value1",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var1"],
|
||||
)
|
||||
conv_var2 = StringVariable(
|
||||
id=fake.uuid4(),
|
||||
name="conv_var2",
|
||||
value="default_value2",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"],
|
||||
)
|
||||
workflow.conversation_variables = [conv_var1, conv_var2]
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.prefill_conversation_variable_default_values(workflow)
|
||||
draft_variables = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
assert len(draft_variables) == 2
|
||||
var_names = [var.name for var in draft_variables]
|
||||
assert "conv_var1" in var_names
|
||||
assert "conv_var2" in var_names
|
||||
for var in draft_variables:
|
||||
assert var.app_id == app.id
|
||||
assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
assert var.editable is True
|
||||
assert var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
|
||||
def test_get_conversation_id_from_draft_variable_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting conversation ID from draft variable successfully.
|
||||
|
||||
This test verifies that the service can extract the conversation ID
|
||||
from a system variable named "conversation_id". This is important
|
||||
for maintaining conversation context across workflow executions.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
conversation_id = fake.uuid4()
|
||||
conv_id_value = StringSegment(value=conversation_id)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
"conversation_id",
|
||||
conv_id_value,
|
||||
variable_type=DraftVariableType.SYS,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
|
||||
assert retrieved_conv_id == conversation_id
|
||||
|
||||
def test_get_conversation_id_from_draft_variable_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting conversation ID when it doesn't exist.
|
||||
|
||||
This test verifies that the service returns None when no
|
||||
conversation_id variable exists for the app. This ensures
|
||||
proper handling of missing conversation context scenarios.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
|
||||
assert retrieved_conv_id is None
|
||||
|
||||
def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test listing system variables successfully.
|
||||
|
||||
This test verifies that the service can filter and return only
|
||||
system variables, excluding conversation and node variables.
|
||||
System variables are internal variables used by the workflow
|
||||
engine for maintaining state and context.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
sys_var1_value = StringSegment(value=fake.word())
|
||||
sys_var2_value = StringSegment(value=fake.word())
|
||||
sys_var1 = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
"sys_var1",
|
||||
sys_var1_value,
|
||||
variable_type=DraftVariableType.SYS,
|
||||
fake=fake,
|
||||
)
|
||||
sys_var2 = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
"sys_var2",
|
||||
sys_var2_value,
|
||||
variable_type=DraftVariableType.SYS,
|
||||
fake=fake,
|
||||
)
|
||||
conv_var_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_system_variables(app.id)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == SYSTEM_VARIABLE_NODE_ID
|
||||
assert var.app_id == app.id
|
||||
assert var.get_variable_type() == DraftVariableType.SYS
|
||||
var_names = [var.name for var in result.variables]
|
||||
assert "sys_var1" in var_names
|
||||
assert "sys_var2" in var_names
|
||||
assert "conv_var" not in var_names
|
||||
|
||||
def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting variables by name successfully for different types.
|
||||
|
||||
This test verifies that the service can retrieve variables by name
|
||||
for different variable types (conversation, system, node). This
|
||||
functionality is important for variable lookup operations during
|
||||
workflow execution and user interactions.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
test_value = StringSegment(value=fake.word())
|
||||
conv_var = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_conv_var", test_value, fake=fake
|
||||
)
|
||||
sys_var = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
"test_sys_var",
|
||||
test_value,
|
||||
variable_type=DraftVariableType.SYS,
|
||||
fake=fake,
|
||||
)
|
||||
node_var = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
"test_node",
|
||||
"test_node_var",
|
||||
test_value,
|
||||
variable_type=DraftVariableType.NODE,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var")
|
||||
assert retrieved_conv_var is not None
|
||||
assert retrieved_conv_var.name == "test_conv_var"
|
||||
assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var")
|
||||
assert retrieved_sys_var is not None
|
||||
assert retrieved_sys_var.name == "test_sys_var"
|
||||
assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID
|
||||
retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var")
|
||||
assert retrieved_node_var is not None
|
||||
assert retrieved_node_var.name == "test_node_var"
|
||||
assert retrieved_node_var.node_id == "test_node"
|
||||
|
||||
def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting variables by name when they don't exist.
|
||||
|
||||
This test verifies that the service returns None when trying to
|
||||
retrieve variables by name that don't exist. This ensures proper
|
||||
handling of missing variable scenarios for all variable types.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var")
|
||||
assert retrieved_conv_var is None
|
||||
retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var")
|
||||
assert retrieved_sys_var is None
|
||||
retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var")
|
||||
assert retrieved_node_var is None
|
||||
@@ -0,0 +1,713 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import (
|
||||
Message,
|
||||
)
|
||||
from models.workflow import WorkflowRun
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
class TestWorkflowRunService:
|
||||
"""Integration tests for WorkflowRunService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.app_service.FeatureService") as mock_feature_service,
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
):
|
||||
# Setup default mock returns for app service
|
||||
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
|
||||
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
|
||||
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
|
||||
|
||||
# Setup default mock returns for account service
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Mock ModelManager for model configuration
|
||||
mock_model_instance = mock_model_manager.return_value
|
||||
mock_model_instance.get_default_model_instance.return_value = None
|
||||
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||
|
||||
yield {
|
||||
"feature_service": mock_feature_service,
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (app, account) - Created app and account instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_workflow_run(
|
||||
self, db_session_with_containers, app, account, triggered_from="debugging", offset_minutes=0
|
||||
):
|
||||
"""
|
||||
Helper method to create a test workflow run for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: App instance
|
||||
account: Account instance
|
||||
triggered_from: Trigger source for workflow run
|
||||
|
||||
Returns:
|
||||
WorkflowRun: Created workflow run instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create workflow run with offset timestamp
|
||||
base_time = datetime.now(UTC)
|
||||
created_time = base_time - timedelta(minutes=offset_minutes)
|
||||
|
||||
workflow_run = WorkflowRun(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=str(uuid.uuid4()),
|
||||
type="chat",
|
||||
triggered_from=triggered_from,
|
||||
version="1.0.0",
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
inputs=json.dumps({"input": "test"}),
|
||||
status="succeeded",
|
||||
outputs=json.dumps({"output": "test result"}),
|
||||
elapsed_time=1.5,
|
||||
total_tokens=100,
|
||||
total_steps=3,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=created_time,
|
||||
finished_at=created_time,
|
||||
)
|
||||
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _create_test_message(self, db_session_with_containers, app, account, workflow_run):
|
||||
"""
|
||||
Helper method to create a test message for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: App instance
|
||||
account: Account instance
|
||||
workflow_run: WorkflowRun instance
|
||||
|
||||
Returns:
|
||||
Message: Created message instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create conversation first (required for message)
|
||||
from models.model import Conversation
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
name=fake.sentence(),
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode="chat",
|
||||
from_source=CreatorUserRole.ACCOUNT,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
|
||||
# Create message
|
||||
message = Message()
|
||||
message.app_id = app.id
|
||||
message.conversation_id = conversation.id
|
||||
message.query = fake.text(max_nb_chars=100)
|
||||
message.message = {"type": "text", "content": fake.text(max_nb_chars=100)}
|
||||
message.answer = fake.text(max_nb_chars=200)
|
||||
message.message_tokens = 50
|
||||
message.answer_tokens = 100
|
||||
message.message_unit_price = 0.001
|
||||
message.answer_unit_price = 0.002
|
||||
message.message_price_unit = 0.001
|
||||
message.answer_price_unit = 0.001
|
||||
message.currency = "USD"
|
||||
message.status = "normal"
|
||||
message.from_source = CreatorUserRole.ACCOUNT
|
||||
message.from_account_id = account.id
|
||||
message.workflow_run_id = workflow_run.id
|
||||
message.inputs = {"input": "test input"}
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
|
||||
return message
|
||||
|
||||
def test_get_paginate_workflow_runs_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful pagination of workflow runs with debugging trigger.
|
||||
|
||||
This test verifies:
|
||||
- Proper pagination of workflow runs
|
||||
- Correct filtering by triggered_from
|
||||
- Proper limit and last_id handling
|
||||
- Repository method calls
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create multiple workflow runs
|
||||
workflow_runs = []
|
||||
for i in range(5):
|
||||
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
|
||||
workflow_runs.append(workflow_run)
|
||||
|
||||
# Act: Execute the method under test
|
||||
workflow_run_service = WorkflowRunService()
|
||||
args = {"limit": 3, "last_id": None}
|
||||
result = workflow_run_service.get_paginate_workflow_runs(app, args)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert hasattr(result, "data")
|
||||
assert len(result.data) == 3 # Should return 3 items due to limit
|
||||
|
||||
# Verify pagination properties
|
||||
assert hasattr(result, "has_more")
|
||||
assert hasattr(result, "limit")
|
||||
|
||||
# Verify all returned items are debugging runs
|
||||
for workflow_run in result.data:
|
||||
assert workflow_run.triggered_from == "debugging"
|
||||
assert workflow_run.app_id == app.id
|
||||
assert workflow_run.tenant_id == app.tenant_id
|
||||
|
||||
def test_get_paginate_workflow_runs_with_last_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination of workflow runs with last_id parameter.
|
||||
|
||||
This test verifies:
|
||||
- Proper pagination with last_id parameter
|
||||
- Correct handling of pagination state
|
||||
- Repository method calls with proper parameters
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create multiple workflow runs with different timestamps
|
||||
workflow_runs = []
|
||||
for i in range(5):
|
||||
workflow_run = self._create_test_workflow_run(
|
||||
db_session_with_containers, app, account, "debugging", offset_minutes=i
|
||||
)
|
||||
workflow_runs.append(workflow_run)
|
||||
|
||||
# Act: Execute the method under test with last_id
|
||||
workflow_run_service = WorkflowRunService()
|
||||
args = {"limit": 2, "last_id": workflow_runs[1].id}
|
||||
result = workflow_run_service.get_paginate_workflow_runs(app, args)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert hasattr(result, "data")
|
||||
assert len(result.data) == 2 # Should return 2 items due to limit
|
||||
|
||||
# Verify pagination properties
|
||||
assert hasattr(result, "has_more")
|
||||
assert hasattr(result, "limit")
|
||||
|
||||
# Verify all returned items are debugging runs
|
||||
for workflow_run in result.data:
|
||||
assert workflow_run.triggered_from == "debugging"
|
||||
assert workflow_run.app_id == app.id
|
||||
assert workflow_run.tenant_id == app.tenant_id
|
||||
|
||||
def test_get_paginate_workflow_runs_default_limit(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination of workflow runs with default limit.
|
||||
|
||||
This test verifies:
|
||||
- Default limit of 20 when not specified
|
||||
- Proper handling of missing limit parameter
|
||||
- Repository method calls with default values
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create workflow runs
|
||||
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
|
||||
|
||||
# Act: Execute the method under test without limit
|
||||
workflow_run_service = WorkflowRunService()
|
||||
args = {} # No limit specified
|
||||
result = workflow_run_service.get_paginate_workflow_runs(app, args)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert hasattr(result, "data")
|
||||
|
||||
# Verify pagination properties
|
||||
assert hasattr(result, "has_more")
|
||||
assert hasattr(result, "limit")
|
||||
|
||||
# Verify the returned workflow run
|
||||
if result.data:
|
||||
workflow_run_result = result.data[0]
|
||||
assert workflow_run_result.triggered_from == "debugging"
|
||||
assert workflow_run_result.app_id == app.id
|
||||
assert workflow_run_result.tenant_id == app.tenant_id
|
||||
|
||||
def test_get_paginate_advanced_chat_workflow_runs_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination of advanced chat workflow runs with message information.
|
||||
|
||||
This test verifies:
|
||||
- Proper pagination of advanced chat workflow runs
|
||||
- Correct filtering by triggered_from
|
||||
- Message information enrichment
|
||||
- WorkflowWithMessage wrapper functionality
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create workflow runs with messages
|
||||
workflow_runs = []
|
||||
for i in range(3):
|
||||
workflow_run = self._create_test_workflow_run(
|
||||
db_session_with_containers, app, account, "debugging", offset_minutes=i
|
||||
)
|
||||
message = self._create_test_message(db_session_with_containers, app, account, workflow_run)
|
||||
workflow_runs.append(workflow_run)
|
||||
|
||||
# Act: Execute the method under test
|
||||
workflow_run_service = WorkflowRunService()
|
||||
args = {"limit": 2, "last_id": None}
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app, args)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert hasattr(result, "data")
|
||||
assert len(result.data) == 2 # Should return 2 items due to limit
|
||||
|
||||
# Verify pagination properties
|
||||
assert hasattr(result, "has_more")
|
||||
assert hasattr(result, "limit")
|
||||
|
||||
# Verify all returned items have message information
|
||||
for workflow_run in result.data:
|
||||
assert hasattr(workflow_run, "message_id")
|
||||
assert hasattr(workflow_run, "conversation_id")
|
||||
assert workflow_run.app_id == app.id
|
||||
assert workflow_run.tenant_id == app.tenant_id
|
||||
|
||||
def test_get_workflow_run_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of workflow run by ID.
|
||||
|
||||
This test verifies:
|
||||
- Proper workflow run retrieval by ID
|
||||
- Correct tenant and app isolation
|
||||
- Repository method calls with proper parameters
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create workflow run
|
||||
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
|
||||
|
||||
# Act: Execute the method under test
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_run(app, workflow_run.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result.id == workflow_run.id
|
||||
assert result.tenant_id == app.tenant_id
|
||||
assert result.app_id == app.id
|
||||
assert result.triggered_from == "debugging"
|
||||
assert result.status == "succeeded"
|
||||
assert result.type == "chat"
|
||||
assert result.version == "1.0.0"
|
||||
|
||||
def test_get_workflow_run_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test workflow run retrieval when run ID does not exist.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of non-existent workflow run IDs
|
||||
- Repository method calls with proper parameters
|
||||
- Return value for missing records
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Use a non-existent UUID
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
|
||||
# Act: Execute the method under test
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_run(app, non_existent_id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is None
|
||||
|
||||
def test_get_workflow_run_node_executions_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of workflow run node executions.
|
||||
|
||||
This test verifies:
|
||||
- Proper node execution retrieval for workflow run
|
||||
- Correct tenant and app isolation
|
||||
- Repository method calls with proper parameters
|
||||
- Context setup for plugin tool providers
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create workflow run
|
||||
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
|
||||
|
||||
# Create node executions
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
node_executions = []
|
||||
for i in range(3):
|
||||
node_execution = WorkflowNodeExecutionModel(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=workflow_run.id,
|
||||
index=i,
|
||||
node_id=f"node_{i}",
|
||||
node_type="llm" if i == 0 else "tool",
|
||||
title=f"Node {i}",
|
||||
inputs=json.dumps({"input": f"test_input_{i}"}),
|
||||
process_data=json.dumps({"process": f"test_process_{i}"}),
|
||||
status="succeeded",
|
||||
elapsed_time=0.5,
|
||||
execution_metadata=json.dumps({"tokens": 50}),
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(node_execution)
|
||||
node_executions.append(node_execution)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_run_node_executions(app, workflow_run.id, account)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert len(result) == 3
|
||||
|
||||
# Verify node execution properties
|
||||
for node_execution in result:
|
||||
assert node_execution.tenant_id == app.tenant_id
|
||||
assert node_execution.app_id == app.id
|
||||
assert node_execution.workflow_run_id == workflow_run.id
|
||||
assert node_execution.index in [0, 1, 2] # Check that index is one of the expected values
|
||||
assert node_execution.node_id.startswith("node_") # Check that node_id starts with "node_"
|
||||
assert node_execution.status == "succeeded"
|
||||
|
||||
def test_get_workflow_run_node_executions_empty(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting node executions for a workflow run with no executions.
|
||||
|
||||
This test verifies:
|
||||
- Empty result when no node executions exist
|
||||
- Proper handling of empty data
|
||||
- No errors when querying non-existent executions
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
account_service = AccountService()
|
||||
tenant_service = TenantService()
|
||||
app_service = AppService()
|
||||
workflow_run_service = WorkflowRunService()
|
||||
|
||||
# Create account and tenant
|
||||
account = account_service.create_account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password="password123",
|
||||
interface_language="en-US",
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app
|
||||
app_args = {
|
||||
"name": "Test App",
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🚀",
|
||||
"icon_background": "#4ECDC4",
|
||||
}
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Create workflow run without node executions
|
||||
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
|
||||
|
||||
# Act: Get node executions
|
||||
result = workflow_run_service.get_workflow_run_node_executions(
|
||||
app_model=app,
|
||||
run_id=workflow_run.id,
|
||||
user=account,
|
||||
)
|
||||
|
||||
# Assert: Verify empty result
|
||||
assert result is not None
|
||||
assert len(result) == 0
|
||||
|
||||
def test_get_workflow_run_node_executions_invalid_workflow_run_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting node executions with invalid workflow run ID.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of invalid workflow run ID
|
||||
- Empty result when workflow run doesn't exist
|
||||
- No errors when querying with invalid ID
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
account_service = AccountService()
|
||||
tenant_service = TenantService()
|
||||
app_service = AppService()
|
||||
workflow_run_service = WorkflowRunService()
|
||||
|
||||
# Create account and tenant
|
||||
account = account_service.create_account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password="password123",
|
||||
interface_language="en-US",
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app
|
||||
app_args = {
|
||||
"name": "Test App",
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🚀",
|
||||
"icon_background": "#4ECDC4",
|
||||
}
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Use invalid workflow run ID
|
||||
invalid_workflow_run_id = str(uuid.uuid4())
|
||||
|
||||
# Act: Get node executions with invalid ID
|
||||
result = workflow_run_service.get_workflow_run_node_executions(
|
||||
app_model=app,
|
||||
run_id=invalid_workflow_run_id,
|
||||
user=account,
|
||||
)
|
||||
|
||||
# Assert: Verify empty result
|
||||
assert result is not None
|
||||
assert len(result) == 0
|
||||
|
||||
def test_get_workflow_run_node_executions_database_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting node executions when database encounters an error.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling when database operations fail
|
||||
- Graceful degradation in error scenarios
|
||||
- Error propagation to calling code
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
account_service = AccountService()
|
||||
tenant_service = TenantService()
|
||||
app_service = AppService()
|
||||
workflow_run_service = WorkflowRunService()
|
||||
|
||||
# Create account and tenant
|
||||
account = account_service.create_account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password="password123",
|
||||
interface_language="en-US",
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app
|
||||
app_args = {
|
||||
"name": "Test App",
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🚀",
|
||||
"icon_background": "#4ECDC4",
|
||||
}
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Create workflow run
|
||||
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
|
||||
|
||||
# Mock database error by closing the session
|
||||
db_session_with_containers.close()
|
||||
|
||||
# Act & Assert: Verify error handling
|
||||
with pytest.raises((Exception, RuntimeError)):
|
||||
workflow_run_service.get_workflow_run_node_executions(
|
||||
app_model=app,
|
||||
run_id=workflow_run.id,
|
||||
user=account,
|
||||
)
|
||||
|
||||
def test_get_workflow_run_node_executions_end_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test node execution retrieval for end user.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of end user vs account user
|
||||
- Correct tenant ID extraction for end users
|
||||
- Repository method calls with proper parameters
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create workflow run
|
||||
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
|
||||
|
||||
# Create end user
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
type="web_app",
|
||||
is_anonymous=False,
|
||||
session_id=str(uuid.uuid4()),
|
||||
external_user_id=str(uuid.uuid4()),
|
||||
name=fake.name(),
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
# Create node execution
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
node_execution = WorkflowNodeExecutionModel(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=workflow_run.id,
|
||||
index=0,
|
||||
node_id="node_0",
|
||||
node_type="llm",
|
||||
title="Node 0",
|
||||
inputs=json.dumps({"input": "test_input"}),
|
||||
process_data=json.dumps({"process": "test_process"}),
|
||||
status="succeeded",
|
||||
elapsed_time=0.5,
|
||||
execution_metadata=json.dumps({"tokens": 50}),
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by=end_user.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(node_execution)
|
||||
db.session.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_run_node_executions(app, workflow_run.id, end_user)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
|
||||
# Verify node execution properties
|
||||
node_exec = result[0]
|
||||
assert node_exec.tenant_id == app.tenant_id
|
||||
assert node_exec.app_id == app.id
|
||||
assert node_exec.workflow_run_id == workflow_run.id
|
||||
assert node_exec.created_by == end_user.id
|
||||
assert node_exec.created_by_role == CreatorUserRole.END_USER
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,529 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
|
||||
class TestWorkspaceService:
|
||||
"""Integration tests for WorkspaceService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.workspace_service.FeatureService") as mock_feature_service,
|
||||
patch("services.workspace_service.TenantService") as mock_tenant_service,
|
||||
patch("services.workspace_service.dify_config") as mock_dify_config,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_feature_service.get_features.return_value.can_replace_logo = True
|
||||
mock_tenant_service.has_roles.return_value = True
|
||||
mock_dify_config.FILES_URL = "https://example.com/files"
|
||||
|
||||
yield {
|
||||
"feature_service": mock_feature_service,
|
||||
"tenant_service": mock_tenant_service,
|
||||
"dify_config": mock_dify_config,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
plan="basic",
|
||||
custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}',
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join with owner role
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of tenant information with all features enabled.
|
||||
|
||||
This test verifies:
|
||||
- Proper tenant info retrieval with all required fields
|
||||
- Correct role assignment from TenantAccountJoin
|
||||
- Custom config handling when features are enabled
|
||||
- Logo replacement functionality for privileged users
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup mocks for feature service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["id"] == tenant.id
|
||||
assert result["name"] == tenant.name
|
||||
assert result["plan"] == tenant.plan
|
||||
assert result["status"] == tenant.status
|
||||
assert result["role"] == TenantAccountRole.OWNER
|
||||
assert result["created_at"] == tenant.created_at
|
||||
assert result["trial_end_reason"] is None
|
||||
|
||||
# Verify custom config is included for privileged users
|
||||
assert "custom_config" in result
|
||||
assert result["custom_config"]["remove_webapp_brand"] is False
|
||||
assert "replace_webapp_logo" in result["custom_config"]
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_without_custom_config(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval when custom config features are disabled.
|
||||
|
||||
This test verifies:
|
||||
- Tenant info retrieval without custom config when features are disabled
|
||||
- Proper handling of disabled logo replacement functionality
|
||||
- Role assignment still works correctly
|
||||
- Basic tenant information is complete
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup mocks to disable custom config features
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["id"] == tenant.id
|
||||
assert result["name"] == tenant.name
|
||||
assert result["plan"] == tenant.plan
|
||||
assert result["status"] == tenant.status
|
||||
assert result["role"] == TenantAccountRole.OWNER
|
||||
assert result["created_at"] == tenant.created_at
|
||||
assert result["trial_end_reason"] is None
|
||||
|
||||
# Verify custom config is not included when features are disabled
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_normal_user_role(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for normal user role without privileged features.
|
||||
|
||||
This test verifies:
|
||||
- Tenant info retrieval for non-privileged users
|
||||
- Role assignment for normal users
|
||||
- Custom config is not accessible for normal users
|
||||
- Proper handling of different user roles
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Update the join to have normal role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join.role = TenantAccountRole.NORMAL
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks for feature service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["id"] == tenant.id
|
||||
assert result["name"] == tenant.name
|
||||
assert result["plan"] == tenant.plan
|
||||
assert result["status"] == tenant.status
|
||||
assert result["role"] == TenantAccountRole.NORMAL
|
||||
assert result["created_at"] == tenant.created_at
|
||||
assert result["trial_end_reason"] is None
|
||||
|
||||
# Verify custom config is not included for normal users
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_admin_role_and_logo_replacement(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for admin role with logo replacement enabled.
|
||||
|
||||
This test verifies:
|
||||
- Admin role can access custom config features
|
||||
- Logo replacement functionality works for admin users
|
||||
- Proper URL construction for logo replacement
|
||||
- Custom config handling for admin role
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Update the join to have admin role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join.role = TenantAccountRole.ADMIN
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks for feature service and tenant service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["role"] == TenantAccountRole.ADMIN
|
||||
|
||||
# Verify custom config is included for admin users
|
||||
assert "custom_config" in result
|
||||
assert result["custom_config"]["remove_webapp_brand"] is False
|
||||
assert "replace_webapp_logo" in result["custom_config"]
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test tenant info retrieval when tenant parameter is None.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of None tenant parameter
|
||||
- Method returns None for invalid input
|
||||
- No exceptions are raised for None input
|
||||
- Graceful degradation for invalid data
|
||||
"""
|
||||
# Arrange: No test data needed for this test
|
||||
|
||||
# Act: Execute the method under test with None tenant
|
||||
result = WorkspaceService.get_tenant_info(None)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is None
|
||||
|
||||
def test_get_tenant_info_with_custom_config_variations(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval with various custom config configurations.
|
||||
|
||||
This test verifies:
|
||||
- Different custom config combinations work correctly
|
||||
- Logo replacement URL construction with various configs
|
||||
- Brand removal functionality
|
||||
- Edge cases in custom config handling
|
||||
"""
|
||||
# Arrange: Create test data with different custom configs
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Test different custom config combinations
|
||||
test_configs = [
|
||||
# Case 1: Both logo and brand removal enabled
|
||||
{"replace_webapp_logo": True, "remove_webapp_brand": True},
|
||||
# Case 2: Only logo replacement enabled
|
||||
{"replace_webapp_logo": True, "remove_webapp_brand": False},
|
||||
# Case 3: Only brand removal enabled
|
||||
{"replace_webapp_logo": False, "remove_webapp_brand": True},
|
||||
# Case 4: Neither enabled
|
||||
{"replace_webapp_logo": False, "remove_webapp_brand": False},
|
||||
]
|
||||
|
||||
for config in test_configs:
|
||||
# Update tenant custom config
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
tenant.custom_config = json.dumps(config)
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = "https://files.example.com"
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "custom_config" in result
|
||||
|
||||
if config["replace_webapp_logo"]:
|
||||
assert "replace_webapp_logo" in result["custom_config"]
|
||||
if config["replace_webapp_logo"]:
|
||||
expected_url = f"https://files.example.com/files/workspaces/{tenant.id}/webapp-logo"
|
||||
assert result["custom_config"]["replace_webapp_logo"] == expected_url
|
||||
else:
|
||||
assert result["custom_config"]["replace_webapp_logo"] is None
|
||||
|
||||
assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_editor_role_and_limited_permissions(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for editor role with limited permissions.
|
||||
|
||||
This test verifies:
|
||||
- Editor role has limited access to custom config features
|
||||
- Proper role-based permission checking
|
||||
- Custom config handling for different role levels
|
||||
- Role hierarchy and permission boundaries
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Update the join to have editor role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join.role = TenantAccountRole.EDITOR
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks for feature service and tenant service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
# Editor role should not have admin/owner permissions
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["role"] == TenantAccountRole.EDITOR
|
||||
|
||||
# Verify custom config is not included for editor users without admin privileges
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_dataset_operator_role(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for dataset operator role.
|
||||
|
||||
This test verifies:
|
||||
- Dataset operator role handling
|
||||
- Role assignment for specialized roles
|
||||
- Permission boundaries for dataset operators
|
||||
- Custom config access for dataset operators
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Update the join to have dataset operator role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join.role = TenantAccountRole.DATASET_OPERATOR
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks for feature service and tenant service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
# Dataset operator should not have admin/owner permissions
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = "https://cdn.example.com"
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result["role"] == TenantAccountRole.DATASET_OPERATOR
|
||||
|
||||
# Verify custom config is not included for dataset operators without admin privileges
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_complex_custom_config_scenarios(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval with complex custom config scenarios.
|
||||
|
||||
This test verifies:
|
||||
- Complex custom config combinations
|
||||
- Edge cases in custom config handling
|
||||
- URL construction with various configs
|
||||
- Error handling for malformed configs
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Test complex custom config scenarios
|
||||
test_configs = [
|
||||
# Case 1: Empty custom config
|
||||
{},
|
||||
# Case 2: Custom config with only logo replacement
|
||||
{"replace_webapp_logo": True},
|
||||
# Case 3: Custom config with only brand removal
|
||||
{"remove_webapp_brand": True},
|
||||
# Case 4: Custom config with additional fields
|
||||
{
|
||||
"replace_webapp_logo": True,
|
||||
"remove_webapp_brand": False,
|
||||
"custom_field": "custom_value",
|
||||
"nested_config": {"key": "value"},
|
||||
},
|
||||
# Case 5: Custom config with null values
|
||||
{"replace_webapp_logo": None, "remove_webapp_brand": None},
|
||||
]
|
||||
|
||||
for config in test_configs:
|
||||
# Update tenant custom config
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
tenant.custom_config = json.dumps(config)
|
||||
db.session.commit()
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = "https://files.example.com"
|
||||
|
||||
# Mock current_user for flask_login
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
# Act: Execute the method under test
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert "custom_config" in result
|
||||
|
||||
# Verify logo replacement handling
|
||||
if config.get("replace_webapp_logo"):
|
||||
assert "replace_webapp_logo" in result["custom_config"]
|
||||
expected_url = f"https://files.example.com/files/workspaces/{tenant.id}/webapp-logo"
|
||||
assert result["custom_config"]["replace_webapp_logo"] == expected_url
|
||||
else:
|
||||
assert result["custom_config"]["replace_webapp_logo"] is None
|
||||
|
||||
# Verify brand removal handling
|
||||
if "remove_webapp_brand" in config:
|
||||
assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
|
||||
else:
|
||||
assert result["custom_config"]["remove_webapp_brand"] is False
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
@@ -0,0 +1,550 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models import Account, Tenant
|
||||
from models.tools import ApiToolProvider
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
|
||||
|
||||
class TestApiToolManageService:
|
||||
"""Integration tests for ApiToolManageService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.tools.api_tools_manage_service.ToolLabelManager") as mock_tool_label_manager,
|
||||
patch("services.tools.api_tools_manage_service.create_tool_provider_encrypter") as mock_encrypter,
|
||||
patch("services.tools.api_tools_manage_service.ApiToolProviderController") as mock_provider_controller,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_tool_label_manager.update_tool_labels.return_value = None
|
||||
mock_encrypter.return_value = (mock_encrypter, None)
|
||||
mock_encrypter.encrypt.return_value = {"encrypted": "credentials"}
|
||||
mock_provider_controller.from_db.return_value = mock_provider_controller
|
||||
mock_provider_controller.load_bundled_tools.return_value = None
|
||||
|
||||
yield {
|
||||
"tool_label_manager": mock_tool_label_manager,
|
||||
"encrypter": mock_encrypter,
|
||||
"provider_controller": mock_provider_controller,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_openapi_schema(self):
|
||||
"""Helper method to create a test OpenAPI schema."""
|
||||
return """
|
||||
{
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": "Test API",
|
||||
"version": "1.0.0",
|
||||
"description": "Test API for testing purposes"
|
||||
},
|
||||
"servers": [
|
||||
{
|
||||
"url": "https://api.example.com",
|
||||
"description": "Production server"
|
||||
}
|
||||
],
|
||||
"paths": {
|
||||
"/test": {
|
||||
"get": {
|
||||
"operationId": "testOperation",
|
||||
"summary": "Test operation",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def test_parser_api_schema_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful parsing of API schema.
|
||||
|
||||
This test verifies:
|
||||
- Proper schema parsing with valid OpenAPI schema
|
||||
- Correct credentials schema generation
|
||||
- Proper warning handling
|
||||
- Return value structure
|
||||
"""
|
||||
# Arrange: Create test schema
|
||||
schema = self._create_test_openapi_schema()
|
||||
|
||||
# Act: Parse the schema
|
||||
result = ApiToolManageService.parser_api_schema(schema)
|
||||
|
||||
# Assert: Verify the result structure
|
||||
assert result is not None
|
||||
assert "schema_type" in result
|
||||
assert "parameters_schema" in result
|
||||
assert "credentials_schema" in result
|
||||
assert "warning" in result
|
||||
|
||||
# Verify credentials schema structure
|
||||
credentials_schema = result["credentials_schema"]
|
||||
assert len(credentials_schema) == 3
|
||||
|
||||
# Check auth_type field
|
||||
auth_type_field = next(field for field in credentials_schema if field["name"] == "auth_type")
|
||||
assert auth_type_field["required"] is True
|
||||
assert auth_type_field["default"] == "none"
|
||||
assert len(auth_type_field["options"]) == 2
|
||||
|
||||
# Check api_key_header field
|
||||
api_key_header_field = next(field for field in credentials_schema if field["name"] == "api_key_header")
|
||||
assert api_key_header_field["required"] is False
|
||||
assert api_key_header_field["default"] == "api_key"
|
||||
|
||||
# Check api_key_value field
|
||||
api_key_value_field = next(field for field in credentials_schema if field["name"] == "api_key_value")
|
||||
assert api_key_value_field["required"] is False
|
||||
assert api_key_value_field["default"] == ""
|
||||
|
||||
def test_parser_api_schema_invalid_schema(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test parsing of invalid API schema.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for invalid schemas
|
||||
- Correct exception type and message
|
||||
- Error propagation from underlying parser
|
||||
"""
|
||||
# Arrange: Create invalid schema
|
||||
invalid_schema = "invalid json schema"
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.parser_api_schema(invalid_schema)
|
||||
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_parser_api_schema_malformed_json(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test parsing of malformed JSON schema.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for malformed JSON
|
||||
- Correct exception type and message
|
||||
- Error propagation from JSON parsing
|
||||
"""
|
||||
# Arrange: Create malformed JSON schema
|
||||
malformed_schema = '{"openapi": "3.0.0", "info": {"title": "Test", "version": "1.0.0"}, "paths": {}}'
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.parser_api_schema(malformed_schema)
|
||||
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_convert_schema_to_tool_bundles_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of schema to tool bundles.
|
||||
|
||||
This test verifies:
|
||||
- Proper schema conversion with valid OpenAPI schema
|
||||
- Correct tool bundles generation
|
||||
- Proper schema type detection
|
||||
- Return value structure
|
||||
"""
|
||||
# Arrange: Create test schema
|
||||
schema = self._create_test_openapi_schema()
|
||||
|
||||
# Act: Convert schema to tool bundles
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema)
|
||||
|
||||
# Assert: Verify the result structure
|
||||
assert tool_bundles is not None
|
||||
assert isinstance(tool_bundles, list)
|
||||
assert len(tool_bundles) > 0
|
||||
assert schema_type is not None
|
||||
assert isinstance(schema_type, str)
|
||||
|
||||
# Verify tool bundle structure
|
||||
tool_bundle = tool_bundles[0]
|
||||
assert hasattr(tool_bundle, "operation_id")
|
||||
assert tool_bundle.operation_id == "testOperation"
|
||||
|
||||
def test_convert_schema_to_tool_bundles_with_extra_info(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of schema to tool bundles with extra info.
|
||||
|
||||
This test verifies:
|
||||
- Proper schema conversion with extra info parameter
|
||||
- Correct tool bundles generation
|
||||
- Extra info handling
|
||||
- Return value structure
|
||||
"""
|
||||
# Arrange: Create test schema and extra info
|
||||
schema = self._create_test_openapi_schema()
|
||||
extra_info = {"description": "Custom description", "version": "2.0.0"}
|
||||
|
||||
# Act: Convert schema to tool bundles with extra info
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
# Assert: Verify the result structure
|
||||
assert tool_bundles is not None
|
||||
assert isinstance(tool_bundles, list)
|
||||
assert len(tool_bundles) > 0
|
||||
assert schema_type is not None
|
||||
assert isinstance(schema_type, str)
|
||||
|
||||
def test_convert_schema_to_tool_bundles_invalid_schema(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test conversion of invalid schema to tool bundles.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for invalid schemas
|
||||
- Correct exception type and message
|
||||
- Error propagation from underlying parser
|
||||
"""
|
||||
# Arrange: Create invalid schema
|
||||
invalid_schema = "invalid schema content"
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.convert_schema_to_tool_bundles(invalid_schema)
|
||||
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful creation of API tool provider.
|
||||
|
||||
This test verifies:
|
||||
- Proper provider creation with valid parameters
|
||||
- Correct database state after creation
|
||||
- Proper relationship establishment
|
||||
- External service integration
|
||||
- Return value correctness
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
|
||||
schema_type = "openapi"
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
labels = ["test", "api"]
|
||||
|
||||
# Act: Create API tool provider
|
||||
result = ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
# Assert: Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert provider is not None
|
||||
assert provider.name == provider_name
|
||||
assert provider.tenant_id == tenant.id
|
||||
assert provider.user_id == account.id
|
||||
assert provider.schema_type_str == schema_type
|
||||
assert provider.privacy_policy == privacy_policy
|
||||
assert provider.custom_disclaimer == custom_disclaimer
|
||||
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
|
||||
mock_external_service_dependencies["encrypter"].assert_called_once()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
|
||||
mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once()
|
||||
|
||||
def test_create_api_tool_provider_duplicate_name(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creation of API tool provider with duplicate name.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for duplicate provider names
|
||||
- Correct exception type and message
|
||||
- Database constraint enforcement
|
||||
"""
|
||||
# Arrange: Create test data and existing provider
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {"auth_type": "none"}
|
||||
schema_type = "openapi"
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
labels = ["test"]
|
||||
|
||||
# Create first provider
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
# Act & Assert: Try to create duplicate provider
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
assert f"provider {provider_name} already exists" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_invalid_schema_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creation of API tool provider with invalid schema type.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for invalid schema types
|
||||
- Correct exception type and message
|
||||
- Schema type validation
|
||||
"""
|
||||
# Arrange: Create test data with invalid schema type
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {"auth_type": "none"}
|
||||
schema_type = "invalid_type"
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
labels = ["test"]
|
||||
|
||||
# Act & Assert: Try to create provider with invalid schema type
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
assert "invalid schema type" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creation of API tool provider with missing auth type.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing auth type
|
||||
- Correct exception type and message
|
||||
- Credentials validation
|
||||
"""
|
||||
# Arrange: Create test data with missing auth type
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔧"}
|
||||
credentials = {} # Missing auth_type
|
||||
schema_type = "openapi"
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
labels = ["test"]
|
||||
|
||||
# Act & Assert: Try to create provider with missing auth type
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
assert "auth_type is required" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_with_api_key_auth(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful creation of API tool provider with API key authentication.
|
||||
|
||||
This test verifies:
|
||||
- Proper provider creation with API key auth
|
||||
- Correct credentials handling
|
||||
- Proper authentication type processing
|
||||
"""
|
||||
# Arrange: Create test data with API key auth
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_name = fake.company()
|
||||
icon = {"type": "emoji", "value": "🔑"}
|
||||
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
|
||||
schema_type = "openapi"
|
||||
schema = self._create_test_openapi_schema()
|
||||
privacy_policy = "https://example.com/privacy"
|
||||
custom_disclaimer = "Custom disclaimer text"
|
||||
labels = ["api_key", "secure"]
|
||||
|
||||
# Act: Create API tool provider
|
||||
result = ApiToolManageService.create_api_tool_provider(
|
||||
user_id=account.id,
|
||||
tenant_id=tenant.id,
|
||||
provider_name=provider_name,
|
||||
icon=icon,
|
||||
credentials=credentials,
|
||||
schema_type=schema_type,
|
||||
schema=schema,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
# Assert: Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert provider is not None
|
||||
assert provider.name == provider_name
|
||||
assert provider.tenant_id == tenant.id
|
||||
assert provider.user_id == account.id
|
||||
assert provider.schema_type_str == schema_type
|
||||
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["encrypter"].assert_called_once()
|
||||
mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,788 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class TestToolTransformService:
|
||||
"""Integration tests for ToolTransformService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with patch("services.tools.tools_transform_service.dify_config") as mock_dify_config:
|
||||
with patch("services.plugin.plugin_service.dify_config", new=mock_dify_config):
|
||||
# Setup default mock returns
|
||||
mock_dify_config.CONSOLE_API_URL = "https://console.example.com"
|
||||
|
||||
yield {
|
||||
"dify_config": mock_dify_config,
|
||||
}
|
||||
|
||||
def _create_test_tool_provider(
|
||||
self, db_session_with_containers, mock_external_service_dependencies, provider_type="api"
|
||||
):
|
||||
"""
|
||||
Helper method to create a test tool provider for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
provider_type: Type of provider to create
|
||||
|
||||
Returns:
|
||||
Tool provider instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
if provider_type == "api":
|
||||
provider = ApiToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon_dark='{"background": "#252525", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
credentials={"auth_type": "api_key_header", "api_key": "test_key"},
|
||||
provider_type="api",
|
||||
)
|
||||
elif provider_type == "builtin":
|
||||
provider = BuiltinToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon="🔧",
|
||||
icon_dark="🔧",
|
||||
tenant_id="test_tenant_id",
|
||||
provider="test_provider",
|
||||
credential_type="api_key",
|
||||
credentials={"api_key": "test_key"},
|
||||
)
|
||||
elif provider_type == "workflow":
|
||||
provider = WorkflowToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon_dark='{"background": "#252525", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
workflow_id="test_workflow_id",
|
||||
)
|
||||
elif provider_type == "mcp":
|
||||
provider = MCPToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
provider_icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
server_url="https://mcp.example.com",
|
||||
server_identifier="test_server",
|
||||
tools='[{"name": "test_tool", "description": "Test tool"}]',
|
||||
authed=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
return provider
|
||||
|
||||
def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful plugin icon URL generation.
|
||||
|
||||
This test verifies:
|
||||
- Proper URL construction for plugin icons
|
||||
- Correct tenant_id and filename handling
|
||||
- URL format compliance
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
tenant_id = fake.uuid4()
|
||||
filename = "test_icon.png"
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = PluginService.get_plugin_icon_url(str(tenant_id), filename)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "console/api/workspaces/current/plugin/icon" in result
|
||||
assert str(tenant_id) in result
|
||||
assert filename in result
|
||||
assert result.startswith("https://console.example.com")
|
||||
|
||||
# Verify URL structure
|
||||
expected_url = f"https://console.example.com/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}"
|
||||
assert result == expected_url
|
||||
|
||||
def test_get_plugin_icon_url_with_empty_console_url(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test plugin icon URL generation when CONSOLE_API_URL is empty.
|
||||
|
||||
This test verifies:
|
||||
- Fallback to relative URL when CONSOLE_API_URL is None
|
||||
- Proper URL construction with relative path
|
||||
"""
|
||||
# Arrange: Setup mock with empty console URL
|
||||
mock_external_service_dependencies["dify_config"].CONSOLE_API_URL = None
|
||||
fake = Faker()
|
||||
tenant_id = fake.uuid4()
|
||||
filename = "test_icon.png"
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = PluginService.get_plugin_icon_url(str(tenant_id), filename)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert result.startswith("/console/api/workspaces/current/plugin/icon")
|
||||
assert str(tenant_id) in result
|
||||
assert filename in result
|
||||
|
||||
# Verify URL structure
|
||||
expected_url = f"/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}"
|
||||
assert result == expected_url
|
||||
|
||||
def test_get_tool_provider_icon_url_builtin_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful tool provider icon URL generation for builtin providers.
|
||||
|
||||
This test verifies:
|
||||
- Proper URL construction for builtin tool providers
|
||||
- Correct provider type handling
|
||||
- URL format compliance
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
provider_type = ToolProviderType.BUILT_IN
|
||||
provider_name = fake.company()
|
||||
icon = "🔧"
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "console/api/workspaces/current/tool-provider/builtin" in result
|
||||
# Note: provider_name may contain spaces that get URL encoded
|
||||
assert provider_name.replace(" ", "%20") in result or provider_name in result
|
||||
assert result.endswith("/icon")
|
||||
assert result.startswith("https://console.example.com")
|
||||
|
||||
# Verify URL structure (accounting for URL encoding)
|
||||
# The actual result will have URL-encoded spaces (%20), so we need to compare accordingly
|
||||
expected_url = (
|
||||
f"https://console.example.com/console/api/workspaces/current/tool-provider/builtin/{provider_name}/icon"
|
||||
)
|
||||
# Convert expected URL to match the actual URL encoding
|
||||
expected_encoded = expected_url.replace(" ", "%20")
|
||||
assert result == expected_encoded
|
||||
|
||||
def test_get_tool_provider_icon_url_api_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful tool provider icon URL generation for API providers.
|
||||
|
||||
This test verifies:
|
||||
- Proper icon handling for API tool providers
|
||||
- JSON string parsing for icon data
|
||||
- Fallback icon when parsing fails
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
provider_type = ToolProviderType.API
|
||||
provider_name = fake.company()
|
||||
icon = '{"background": "#FF6B6B", "content": "🔧"}'
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert isinstance(result, dict)
|
||||
assert result["background"] == "#FF6B6B"
|
||||
assert result["content"] == "🔧"
|
||||
|
||||
def test_get_tool_provider_icon_url_api_invalid_json(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tool provider icon URL generation for API providers with invalid JSON.
|
||||
|
||||
This test verifies:
|
||||
- Proper fallback when JSON parsing fails
|
||||
- Default icon structure when exception occurs
|
||||
"""
|
||||
# Arrange: Setup test data with invalid JSON
|
||||
fake = Faker()
|
||||
provider_type = ToolProviderType.API
|
||||
provider_name = fake.company()
|
||||
icon = '{"invalid": json}'
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert isinstance(result, dict)
|
||||
assert result["background"] == "#252525"
|
||||
# Note: emoji characters may be represented as Unicode escape sequences
|
||||
assert result["content"] == "😁" or result["content"] == "\ud83d\ude01"
|
||||
|
||||
def test_get_tool_provider_icon_url_workflow_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful tool provider icon URL generation for workflow providers.
|
||||
|
||||
This test verifies:
|
||||
- Proper icon handling for workflow tool providers
|
||||
- Direct icon return for workflow type
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
provider_type = ToolProviderType.WORKFLOW
|
||||
provider_name = fake.company()
|
||||
icon = {"background": "#FF6B6B", "content": "🔧"}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert isinstance(result, dict)
|
||||
assert result["background"] == "#FF6B6B"
|
||||
assert result["content"] == "🔧"
|
||||
|
||||
def test_get_tool_provider_icon_url_mcp_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful tool provider icon URL generation for MCP providers.
|
||||
|
||||
This test verifies:
|
||||
- Direct icon return for MCP type
|
||||
- No URL transformation for MCP providers
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
provider_type = ToolProviderType.MCP
|
||||
provider_name = fake.company()
|
||||
icon = {"background": "#FF6B6B", "content": "🔧"}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert isinstance(result, dict)
|
||||
assert result["background"] == "#FF6B6B"
|
||||
assert result["content"] == "🔧"
|
||||
|
||||
def test_get_tool_provider_icon_url_unknown_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tool provider icon URL generation for unknown provider types.
|
||||
|
||||
This test verifies:
|
||||
- Empty string return for unknown provider types
|
||||
- Proper handling of unsupported types
|
||||
"""
|
||||
# Arrange: Setup test data with unknown type
|
||||
fake = Faker()
|
||||
provider_type = "unknown_type"
|
||||
provider_name = fake.company()
|
||||
icon = "🔧"
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result == ""
|
||||
|
||||
def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful provider repacking with dictionary input.
|
||||
|
||||
This test verifies:
|
||||
- Proper icon URL generation for dictionary providers
|
||||
- Correct provider type handling
|
||||
- Icon transformation for different provider types
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
tenant_id = fake.uuid4()
|
||||
provider = {"type": ToolProviderType.BUILT_IN, "name": fake.company(), "icon": "🔧"}
|
||||
|
||||
# Act: Execute the method under test
|
||||
ToolTransformService.repack_provider(str(tenant_id), provider)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert "icon" in provider
|
||||
assert isinstance(provider["icon"], str)
|
||||
assert "console/api/workspaces/current/tool-provider/builtin" in provider["icon"]
|
||||
# Note: provider name may contain spaces that get URL encoded
|
||||
assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"]
|
||||
|
||||
def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful provider repacking with ToolProviderApiEntity input.
|
||||
|
||||
This test verifies:
|
||||
- Proper icon URL generation for entity providers
|
||||
- Plugin icon handling when plugin_id is present
|
||||
- Regular icon handling when plugin_id is not present
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
tenant_id = fake.uuid4()
|
||||
|
||||
# Create provider entity with plugin_id
|
||||
provider = ToolProviderApiEntity(
|
||||
id=str(fake.uuid4()),
|
||||
author=fake.name(),
|
||||
name=fake.company(),
|
||||
description=I18nObject(en_US=fake.text(max_nb_chars=100)),
|
||||
icon="test_icon.png",
|
||||
icon_dark="test_icon_dark.png",
|
||||
label=I18nObject(en_US=fake.company()),
|
||||
type=ToolProviderType.API,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
plugin_id="test_plugin_id",
|
||||
tools=[],
|
||||
labels=[],
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
ToolTransformService.repack_provider(tenant_id, provider)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert provider.icon is not None
|
||||
assert isinstance(provider.icon, str)
|
||||
assert "console/api/workspaces/current/plugin/icon" in provider.icon
|
||||
assert str(tenant_id) in provider.icon
|
||||
assert "test_icon.png" in provider.icon
|
||||
|
||||
# Verify dark icon handling
|
||||
assert provider.icon_dark is not None
|
||||
assert isinstance(provider.icon_dark, str)
|
||||
assert "console/api/workspaces/current/plugin/icon" in provider.icon_dark
|
||||
assert str(tenant_id) in provider.icon_dark
|
||||
assert "test_icon_dark.png" in provider.icon_dark
|
||||
|
||||
def test_repack_provider_entity_no_plugin_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful provider repacking with ToolProviderApiEntity input without plugin_id.
|
||||
|
||||
This test verifies:
|
||||
- Proper icon URL generation for non-plugin providers
|
||||
- Regular tool provider icon handling
|
||||
- Dark icon handling when present
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
tenant_id = fake.uuid4()
|
||||
|
||||
# Create provider entity without plugin_id
|
||||
provider = ToolProviderApiEntity(
|
||||
id=fake.uuid4(),
|
||||
author=fake.name(),
|
||||
name=fake.company(),
|
||||
description=I18nObject(en_US=fake.text(max_nb_chars=100)),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon_dark='{"background": "#252525", "content": "🔧"}',
|
||||
label=I18nObject(en_US=fake.company()),
|
||||
type=ToolProviderType.API,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
plugin_id=None,
|
||||
tools=[],
|
||||
labels=[],
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
ToolTransformService.repack_provider(str(tenant_id), provider)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert provider.icon is not None
|
||||
assert isinstance(provider.icon, dict)
|
||||
assert provider.icon["background"] == "#FF6B6B"
|
||||
assert provider.icon["content"] == "🔧"
|
||||
|
||||
# Verify dark icon handling
|
||||
assert provider.icon_dark is not None
|
||||
assert isinstance(provider.icon_dark, dict)
|
||||
assert provider.icon_dark["background"] == "#252525"
|
||||
assert provider.icon_dark["content"] == "🔧"
|
||||
|
||||
def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test provider repacking with ToolProviderApiEntity input without dark icon.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling when icon_dark is None or empty
|
||||
- No errors when dark icon is not present
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
tenant_id = fake.uuid4()
|
||||
|
||||
# Create provider entity without dark icon
|
||||
provider = ToolProviderApiEntity(
|
||||
id=fake.uuid4(),
|
||||
author=fake.name(),
|
||||
name=fake.company(),
|
||||
description=I18nObject(en_US=fake.text(max_nb_chars=100)),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon_dark="",
|
||||
label=I18nObject(en_US=fake.company()),
|
||||
type=ToolProviderType.API,
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
plugin_id=None,
|
||||
tools=[],
|
||||
labels=[],
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
ToolTransformService.repack_provider(tenant_id, provider)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert provider.icon is not None
|
||||
assert isinstance(provider.icon, dict)
|
||||
assert provider.icon["background"] == "#FF6B6B"
|
||||
assert provider.icon["content"] == "🔧"
|
||||
|
||||
# Verify dark icon remains empty string
|
||||
assert provider.icon_dark == ""
|
||||
|
||||
def test_builtin_provider_to_user_provider_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of builtin provider to user provider.
|
||||
|
||||
This test verifies:
|
||||
- Proper entity creation with all required fields
|
||||
- Credentials schema handling
|
||||
- Team authorization setup
|
||||
- Plugin ID handling
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
|
||||
# Create mock provider controller
|
||||
mock_controller = Mock()
|
||||
mock_controller.entity.identity.name = fake.company()
|
||||
mock_controller.entity.identity.author = fake.name()
|
||||
mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
|
||||
mock_controller.entity.identity.icon = "🔧"
|
||||
mock_controller.entity.identity.icon_dark = "🔧"
|
||||
mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
|
||||
mock_controller.plugin_id = None
|
||||
mock_controller.plugin_unique_identifier = None
|
||||
mock_controller.tool_labels = ["label1", "label2"]
|
||||
mock_controller.need_credentials = True
|
||||
|
||||
# Mock credentials schema
|
||||
mock_credential = Mock()
|
||||
mock_credential.to_basic_provider_config.return_value.name = "api_key"
|
||||
mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
|
||||
|
||||
# Create mock database provider
|
||||
mock_db_provider = Mock()
|
||||
mock_db_provider.credential_type = "api-key"
|
||||
mock_db_provider.tenant_id = fake.uuid4()
|
||||
mock_db_provider.credentials = {"api_key": "encrypted_key"}
|
||||
|
||||
# Mock encryption
|
||||
with patch("services.tools.tools_transform_service.create_provider_encrypter") as mock_encrypter:
|
||||
mock_encrypter_instance = Mock()
|
||||
mock_encrypter_instance.decrypt.return_value = {"api_key": "decrypted_key"}
|
||||
mock_encrypter_instance.mask_plugin_credentials.return_value = {"api_key": ""}
|
||||
mock_encrypter.return_value = (mock_encrypter_instance, None)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.builtin_provider_to_user_provider(
|
||||
mock_controller, mock_db_provider, decrypt_credentials=True
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result.id == mock_controller.entity.identity.name
|
||||
assert result.author == mock_controller.entity.identity.author
|
||||
assert result.name == mock_controller.entity.identity.name
|
||||
assert result.description == mock_controller.entity.identity.description
|
||||
assert result.icon == mock_controller.entity.identity.icon
|
||||
assert result.icon_dark == mock_controller.entity.identity.icon_dark
|
||||
assert result.label == mock_controller.entity.identity.label
|
||||
assert result.type == ToolProviderType.BUILT_IN
|
||||
assert result.is_team_authorization is True
|
||||
assert result.plugin_id is None
|
||||
assert result.tools == []
|
||||
assert result.labels == ["label1", "label2"]
|
||||
assert result.masked_credentials == {"api_key": ""}
|
||||
assert result.original_credentials == {"api_key": "decrypted_key"}
|
||||
|
||||
def test_builtin_provider_to_user_provider_plugin_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of builtin provider to user provider with plugin.
|
||||
|
||||
This test verifies:
|
||||
- Plugin ID and unique identifier handling
|
||||
- Proper entity creation for plugin providers
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
|
||||
# Create mock provider controller with plugin
|
||||
mock_controller = Mock()
|
||||
mock_controller.entity.identity.name = fake.company()
|
||||
mock_controller.entity.identity.author = fake.name()
|
||||
mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
|
||||
mock_controller.entity.identity.icon = "🔧"
|
||||
mock_controller.entity.identity.icon_dark = "🔧"
|
||||
mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
|
||||
mock_controller.plugin_id = "test_plugin_id"
|
||||
mock_controller.plugin_unique_identifier = "test_unique_id"
|
||||
mock_controller.tool_labels = ["label1"]
|
||||
mock_controller.need_credentials = False
|
||||
|
||||
# Mock credentials schema
|
||||
mock_credential = Mock()
|
||||
mock_credential.to_basic_provider_config.return_value.name = "api_key"
|
||||
mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.builtin_provider_to_user_provider(
|
||||
mock_controller, None, decrypt_credentials=False
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
# Note: The method checks isinstance(provider_controller, PluginToolProviderController)
|
||||
# Since we're using a Mock, this check will fail, so plugin_id will remain None
|
||||
# In a real test with actual PluginToolProviderController, this would work
|
||||
assert result.is_team_authorization is True
|
||||
assert result.allow_delete is False
|
||||
|
||||
def test_builtin_provider_to_user_provider_no_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test conversion of builtin provider to user provider without credentials.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling when no credentials are needed
|
||||
- Team authorization setup for no-credentials providers
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
|
||||
# Create mock provider controller
|
||||
mock_controller = Mock()
|
||||
mock_controller.entity.identity.name = fake.company()
|
||||
mock_controller.entity.identity.author = fake.name()
|
||||
mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
|
||||
mock_controller.entity.identity.icon = "🔧"
|
||||
mock_controller.entity.identity.icon_dark = "🔧"
|
||||
mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
|
||||
mock_controller.plugin_id = None
|
||||
mock_controller.plugin_unique_identifier = None
|
||||
mock_controller.tool_labels = []
|
||||
mock_controller.need_credentials = False
|
||||
|
||||
# Mock credentials schema
|
||||
mock_credential = Mock()
|
||||
mock_credential.to_basic_provider_config.return_value.name = "api_key"
|
||||
mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.builtin_provider_to_user_provider(
|
||||
mock_controller, None, decrypt_credentials=False
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result.is_team_authorization is True
|
||||
assert result.allow_delete is False
|
||||
assert result.masked_credentials == {"api_key": ""}
|
||||
|
||||
def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful conversion of API provider to controller.
|
||||
|
||||
This test verifies:
|
||||
- Proper controller creation from database provider
|
||||
- Auth type handling for different credential types
|
||||
- Backward compatibility for auth types
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
|
||||
# Create API tool provider with api_key_header auth
|
||||
provider = ApiToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
tenant_id=fake.uuid4(),
|
||||
user_id=fake.uuid4(),
|
||||
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
tools_str="[]",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.api_provider_to_controller(provider)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert hasattr(result, "from_db")
|
||||
# Additional assertions would depend on the actual controller implementation
|
||||
|
||||
def test_api_provider_to_controller_api_key_query(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test conversion of API provider to controller with api_key_query auth type.
|
||||
|
||||
This test verifies:
|
||||
- Proper auth type handling for query parameter authentication
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
|
||||
# Create API tool provider with api_key_query auth
|
||||
provider = ApiToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
tenant_id=fake.uuid4(),
|
||||
user_id=fake.uuid4(),
|
||||
credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}',
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
tools_str="[]",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.api_provider_to_controller(provider)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert hasattr(result, "from_db")
|
||||
|
||||
def test_api_provider_to_controller_backward_compatibility(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test conversion of API provider to controller with backward compatibility auth types.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling of legacy auth type values
|
||||
- Backward compatibility for api_key and api_key_header
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
|
||||
# Create API tool provider with legacy auth type
|
||||
provider = ApiToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
tenant_id=fake.uuid4(),
|
||||
user_id=fake.uuid4(),
|
||||
credentials_str='{"auth_type": "api_key", "api_key": "test_key"}',
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
tools_str="[]",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.api_provider_to_controller(provider)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert hasattr(result, "from_db")
|
||||
|
||||
def test_workflow_provider_to_controller_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of workflow provider to controller.
|
||||
|
||||
This test verifies:
|
||||
- Proper controller creation from workflow provider
|
||||
- Workflow-specific controller handling
|
||||
"""
|
||||
# Arrange: Setup test data
|
||||
fake = Faker()
|
||||
|
||||
# Create workflow tool provider
|
||||
provider = WorkflowToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
tenant_id=fake.uuid4(),
|
||||
user_id=fake.uuid4(),
|
||||
app_id=fake.uuid4(),
|
||||
label="Test Workflow",
|
||||
version="1.0.0",
|
||||
parameter_configuration="[]",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
# Mock the WorkflowToolProviderController.from_db method to avoid app dependency
|
||||
with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db:
|
||||
mock_controller = Mock()
|
||||
mock_from_db.return_value = mock_controller
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.workflow_provider_to_controller(provider)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert result == mock_controller
|
||||
mock_from_db.assert_called_once_with(provider)
|
||||
@@ -0,0 +1,716 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow as WorkflowModel
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
|
||||
class TestWorkflowToolManageService:
|
||||
"""Integration tests for WorkflowToolManageService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.app_service.FeatureService") as mock_feature_service,
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch(
|
||||
"services.tools.workflow_tools_manage_service.WorkflowToolProviderController"
|
||||
) as mock_workflow_tool_provider_controller,
|
||||
patch("services.tools.workflow_tools_manage_service.ToolLabelManager") as mock_tool_label_manager,
|
||||
patch("services.tools.workflow_tools_manage_service.ToolTransformService") as mock_tool_transform_service,
|
||||
):
|
||||
# Setup default mock returns for app service
|
||||
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
|
||||
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
|
||||
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
|
||||
|
||||
# Setup default mock returns for account service
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Mock ModelManager for model configuration
|
||||
mock_model_instance = mock_model_manager.return_value
|
||||
mock_model_instance.get_default_model_instance.return_value = None
|
||||
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||
|
||||
# Mock WorkflowToolProviderController
|
||||
mock_workflow_tool_provider_controller.from_db.return_value = None
|
||||
|
||||
# Mock ToolLabelManager
|
||||
mock_tool_label_manager.update_tool_labels.return_value = None
|
||||
|
||||
# Mock ToolTransformService
|
||||
mock_tool_transform_service.workflow_provider_to_controller.return_value = None
|
||||
|
||||
yield {
|
||||
"feature_service": mock_feature_service,
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"workflow_tool_provider_controller": mock_workflow_tool_provider_controller,
|
||||
"tool_label_manager": mock_tool_label_manager,
|
||||
"tool_transform_service": mock_tool_transform_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (app, account, workflow) - Created app, account and workflow instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Setup mocks for account creation
|
||||
mock_external_service_dependencies[
|
||||
"account_feature_service"
|
||||
].get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Create account and tenant
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "workflow",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Create workflow for the app
|
||||
workflow = WorkflowModel(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
version="1.0.0",
|
||||
graph=json.dumps({}),
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
|
||||
# Update app to reference the workflow
|
||||
app.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
|
||||
return app, account, workflow
|
||||
|
||||
def _create_test_workflow_tool_parameters(self):
|
||||
"""Helper method to create valid workflow tool parameters."""
|
||||
return [
|
||||
{
|
||||
"name": "input_text",
|
||||
"description": "Input text for processing",
|
||||
"form": "form",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"description": "Output format specification",
|
||||
"form": "form",
|
||||
"type": "select",
|
||||
"required": False,
|
||||
},
|
||||
]
|
||||
|
||||
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful workflow tool creation with valid parameters.
|
||||
|
||||
This test verifies:
|
||||
- Proper workflow tool creation with all required fields
|
||||
- Correct database state after creation
|
||||
- Proper relationship establishment
|
||||
- External service integration
|
||||
- Return value correctness
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup workflow tool creation parameters
|
||||
tool_name = fake.word()
|
||||
tool_label = fake.word()
|
||||
tool_icon = {"type": "emoji", "emoji": "🔧"}
|
||||
tool_description = fake.text(max_nb_chars=200)
|
||||
tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
tool_privacy_policy = fake.text(max_nb_chars=100)
|
||||
tool_labels = ["automation", "workflow"]
|
||||
|
||||
# Execute the method under test
|
||||
result = WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=tool_name,
|
||||
label=tool_label,
|
||||
icon=tool_icon,
|
||||
description=tool_description,
|
||||
parameters=tool_parameters,
|
||||
privacy_policy=tool_privacy_policy,
|
||||
labels=tool_labels,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Check if workflow tool provider was created
|
||||
created_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.app_id == app.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
assert created_tool_provider is not None
|
||||
assert created_tool_provider.name == tool_name
|
||||
assert created_tool_provider.label == tool_label
|
||||
assert created_tool_provider.icon == json.dumps(tool_icon)
|
||||
assert created_tool_provider.description == tool_description
|
||||
assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters)
|
||||
assert created_tool_provider.privacy_policy == tool_privacy_policy
|
||||
assert created_tool_provider.version == workflow.version
|
||||
assert created_tool_provider.user_id == account.id
|
||||
assert created_tool_provider.tenant_id == account.current_tenant.id
|
||||
assert created_tool_provider.app_id == app.id
|
||||
|
||||
# Verify external service calls
|
||||
mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.assert_called_once()
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once()
|
||||
mock_external_service_dependencies[
|
||||
"tool_transform_service"
|
||||
].workflow_provider_to_controller.assert_called_once()
|
||||
|
||||
def test_create_workflow_tool_duplicate_name_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when name already exists.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for duplicate tool names
|
||||
- Database constraint enforcement
|
||||
- Correct error message
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create first workflow tool
|
||||
first_tool_name = fake.word()
|
||||
first_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=first_tool_name,
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=first_tool_parameters,
|
||||
)
|
||||
|
||||
# Attempt to create second workflow tool with same name
|
||||
second_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=first_tool_name, # Same name
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "⚙️"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=second_tool_parameters,
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
|
||||
|
||||
# Verify only one tool was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
assert tool_count == 1
|
||||
|
||||
def test_create_workflow_tool_invalid_app_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when app does not exist.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for non-existent apps
|
||||
- Correct error message
|
||||
- No database changes when app is invalid
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Generate non-existent app ID
|
||||
non_existent_app_id = fake.uuid4()
|
||||
|
||||
# Attempt to create workflow tool with non-existent app
|
||||
tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=non_existent_app_id, # Non-existent app ID
|
||||
name=fake.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=tool_parameters,
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert f"App {non_existent_app_id} not found" in str(exc_info.value)
|
||||
|
||||
# Verify no workflow tool was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
assert tool_count == 0
|
||||
|
||||
def test_create_workflow_tool_invalid_parameters_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when parameters are invalid.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for invalid parameter configurations
|
||||
- Parameter validation enforcement
|
||||
- Correct error message
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Setup invalid workflow tool parameters (missing required fields)
|
||||
invalid_parameters = [
|
||||
{
|
||||
"name": "input_text",
|
||||
# Missing description and form fields
|
||||
"type": "string",
|
||||
"required": True,
|
||||
}
|
||||
]
|
||||
|
||||
# Attempt to create workflow tool with invalid parameters
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=fake.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=invalid_parameters,
|
||||
)
|
||||
|
||||
# Verify error message contains validation error
|
||||
assert "validation error" in str(exc_info.value).lower()
|
||||
|
||||
# Verify no workflow tool was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
assert tool_count == 0
|
||||
|
||||
def test_create_workflow_tool_duplicate_app_id_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when app_id already exists.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for duplicate app_id
|
||||
- Database constraint enforcement for app_id uniqueness
|
||||
- Correct error message
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create first workflow tool
|
||||
first_tool_name = fake.word()
|
||||
first_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=first_tool_name,
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=first_tool_parameters,
|
||||
)
|
||||
|
||||
# Attempt to create second workflow tool with same app_id but different name
|
||||
second_tool_name = fake.word()
|
||||
second_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id, # Same app_id
|
||||
name=second_tool_name, # Different name
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "⚙️"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=second_tool_parameters,
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
|
||||
|
||||
# Verify only one tool was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
assert tool_count == 1
|
||||
|
||||
def test_create_workflow_tool_workflow_not_found_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when app has no workflow.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for apps without workflows
|
||||
- Correct error message
|
||||
- No database changes when workflow is missing
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data but without workflow
|
||||
app, account, _ = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Remove workflow reference from app
|
||||
from extensions.ext_database import db
|
||||
|
||||
app.workflow_id = None
|
||||
db.session.commit()
|
||||
|
||||
# Attempt to create workflow tool for app without workflow
|
||||
tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=fake.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=tool_parameters,
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert f"Workflow not found for app {app.id}" in str(exc_info.value)
|
||||
|
||||
# Verify no workflow tool was created
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
assert tool_count == 0
|
||||
|
||||
def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful workflow tool update with valid parameters.
|
||||
|
||||
This test verifies:
|
||||
- Proper workflow tool update with all required fields
|
||||
- Correct database state after update
|
||||
- Proper relationship maintenance
|
||||
- External service integration
|
||||
- Return value correctness
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create initial workflow tool
|
||||
initial_tool_name = fake.word()
|
||||
initial_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=initial_tool_name,
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=initial_tool_parameters,
|
||||
)
|
||||
|
||||
# Get the created tool
|
||||
from extensions.ext_database import db
|
||||
|
||||
created_tool = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.app_id == app.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Setup update parameters
|
||||
updated_tool_name = fake.word()
|
||||
updated_tool_label = fake.word()
|
||||
updated_tool_icon = {"type": "emoji", "emoji": "⚙️"}
|
||||
updated_tool_description = fake.text(max_nb_chars=200)
|
||||
updated_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
updated_tool_privacy_policy = fake.text(max_nb_chars=100)
|
||||
updated_tool_labels = ["automation", "updated"]
|
||||
|
||||
# Execute the update method
|
||||
result = WorkflowToolManageService.update_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_tool_id=created_tool.id,
|
||||
name=updated_tool_name,
|
||||
label=updated_tool_label,
|
||||
icon=updated_tool_icon,
|
||||
description=updated_tool_description,
|
||||
parameters=updated_tool_parameters,
|
||||
privacy_policy=updated_tool_privacy_policy,
|
||||
labels=updated_tool_labels,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify database state was updated
|
||||
db.session.refresh(created_tool)
|
||||
assert created_tool.name == updated_tool_name
|
||||
assert created_tool.label == updated_tool_label
|
||||
assert created_tool.icon == json.dumps(updated_tool_icon)
|
||||
assert created_tool.description == updated_tool_description
|
||||
assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters)
|
||||
assert created_tool.privacy_policy == updated_tool_privacy_policy
|
||||
assert created_tool.version == workflow.version
|
||||
assert created_tool.updated_at is not None
|
||||
|
||||
# Verify external service calls
|
||||
mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.assert_called()
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called()
|
||||
mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called()
|
||||
|
||||
def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test workflow tool update fails when tool does not exist.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for non-existent tools
|
||||
- Correct error message
|
||||
- No database changes when tool is invalid
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Generate non-existent tool ID
|
||||
non_existent_tool_id = fake.uuid4()
|
||||
|
||||
# Attempt to update non-existent workflow tool
|
||||
tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.update_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_tool_id=non_existent_tool_id, # Non-existent tool ID
|
||||
name=fake.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=tool_parameters,
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value)
|
||||
|
||||
# Verify no workflow tool was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
assert tool_count == 0
|
||||
|
||||
def test_update_workflow_tool_same_name_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool update succeeds when keeping the same name.
|
||||
|
||||
This test verifies:
|
||||
- Proper handling when updating tool with same name
|
||||
- Database state maintenance
|
||||
- Update timestamp is set
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create first workflow tool
|
||||
first_tool_name = fake.word()
|
||||
first_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=first_tool_name,
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=first_tool_parameters,
|
||||
)
|
||||
|
||||
# Get the created tool
|
||||
from extensions.ext_database import db
|
||||
|
||||
created_tool = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.app_id == app.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Attempt to update tool with same name (should not fail)
|
||||
result = WorkflowToolManageService.update_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_tool_id=created_tool.id,
|
||||
name=first_tool_name, # Same name
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "⚙️"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=first_tool_parameters,
|
||||
)
|
||||
|
||||
# Verify update was successful
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify tool still exists with the same name
|
||||
db.session.refresh(created_tool)
|
||||
assert created_tool.name == first_tool_name
|
||||
assert created_tool.updated_at is not None
|
||||
@@ -0,0 +1,554 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
ExternalDataVariableEntity,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
VariableEntityType,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from models import Account, Tenant
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.workflow import Workflow
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
|
||||
class TestWorkflowConverter:
|
||||
"""Integration tests for WorkflowConverter using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.workflow.workflow_converter.encrypter") as mock_encrypter,
|
||||
patch("services.workflow.workflow_converter.SimplePromptTransform") as mock_prompt_transform,
|
||||
patch("services.workflow.workflow_converter.AgentChatAppConfigManager") as mock_agent_chat_config_manager,
|
||||
patch("services.workflow.workflow_converter.ChatAppConfigManager") as mock_chat_config_manager,
|
||||
patch("services.workflow.workflow_converter.CompletionAppConfigManager") as mock_completion_config_manager,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_encrypter.decrypt_token.return_value = "decrypted_api_key"
|
||||
mock_prompt_transform.return_value.get_prompt_template.return_value = {
|
||||
"prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"),
|
||||
"prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"},
|
||||
}
|
||||
mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()
|
||||
mock_chat_config_manager.get_app_config.return_value = self._create_mock_app_config()
|
||||
mock_completion_config_manager.get_app_config.return_value = self._create_mock_app_config()
|
||||
|
||||
yield {
|
||||
"encrypter": mock_encrypter,
|
||||
"prompt_transform": mock_prompt_transform,
|
||||
"agent_chat_config_manager": mock_agent_chat_config_manager,
|
||||
"chat_config_manager": mock_chat_config_manager,
|
||||
"completion_config_manager": mock_completion_config_manager,
|
||||
}
|
||||
|
||||
def _create_mock_app_config(self):
|
||||
"""Helper method to create a mock app config."""
|
||||
mock_config = type("obj", (object,), {})()
|
||||
mock_config.variables = [
|
||||
VariableEntity(
|
||||
variable="text_input",
|
||||
label="Text Input",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
)
|
||||
]
|
||||
mock_config.model = ModelConfigEntity(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
mode=LLMMode.CHAT,
|
||||
parameters={},
|
||||
stop=[],
|
||||
)
|
||||
mock_config.prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="You are a helpful assistant {{text_input}}",
|
||||
)
|
||||
mock_config.dataset = None
|
||||
mock_config.external_data_variables = []
|
||||
mock_config.additional_features = type("obj", (object,), {"file_upload": None})()
|
||||
mock_config.app_model_config_dict = {}
|
||||
return mock_config
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
tuple: (account, tenant) - Created account and tenant instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant, account):
|
||||
"""
|
||||
Helper method to create a test app for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
tenant: Tenant instance
|
||||
account: Account instance
|
||||
|
||||
Returns:
|
||||
App: Created app instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create app
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
mode=AppMode.CHAT,
|
||||
icon_type="emoji",
|
||||
icon="🤖",
|
||||
icon_background="#FF6B6B",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=10,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
configs={},
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
# Link app model config to app
|
||||
app.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
|
||||
return app
|
||||
|
||||
def test_convert_to_workflow_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful conversion of app to workflow.
|
||||
|
||||
This test verifies:
|
||||
- Proper app to workflow conversion
|
||||
- Correct database state after conversion
|
||||
- Proper relationship establishment
|
||||
- Workflow creation with correct configuration
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant, account)
|
||||
|
||||
# Act: Execute the conversion
|
||||
workflow_converter = WorkflowConverter()
|
||||
new_app = workflow_converter.convert_to_workflow(
|
||||
app_model=app,
|
||||
account=account,
|
||||
name="Test Workflow App",
|
||||
icon_type="emoji",
|
||||
icon="🚀",
|
||||
icon_background="#4CAF50",
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert new_app is not None
|
||||
assert new_app.name == "Test Workflow App"
|
||||
assert new_app.mode == AppMode.ADVANCED_CHAT
|
||||
assert new_app.icon_type == "emoji"
|
||||
assert new_app.icon == "🚀"
|
||||
assert new_app.icon_background == "#4CAF50"
|
||||
assert new_app.tenant_id == app.tenant_id
|
||||
assert new_app.created_by == account.id
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(new_app)
|
||||
assert new_app.id is not None
|
||||
|
||||
# Verify workflow was created
|
||||
workflow = db.session.query(Workflow).where(Workflow.app_id == new_app.id).first()
|
||||
assert workflow is not None
|
||||
assert workflow.tenant_id == app.tenant_id
|
||||
assert workflow.type == "chat"
|
||||
|
||||
def test_convert_to_workflow_without_app_model_config_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when app model config is missing.
|
||||
|
||||
This test verifies:
|
||||
- Proper error handling for missing app model config
|
||||
- Correct exception type and message
|
||||
- Database state remains unchanged
|
||||
"""
|
||||
# Arrange: Create test data without app model config
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
mode=AppMode.CHAT,
|
||||
icon_type="emoji",
|
||||
icon="🤖",
|
||||
icon_background="#FF6B6B",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=10,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
# Check initial state
|
||||
initial_workflow_count = db.session.query(Workflow).count()
|
||||
|
||||
with pytest.raises(ValueError, match="App model config is required"):
|
||||
workflow_converter.convert_to_workflow(
|
||||
app_model=app,
|
||||
account=account,
|
||||
name="Test Workflow App",
|
||||
icon_type="emoji",
|
||||
icon="🚀",
|
||||
icon_background="#4CAF50",
|
||||
)
|
||||
|
||||
# Verify database state remains unchanged
|
||||
# The workflow creation happens in convert_app_model_config_to_workflow
|
||||
# which is called before the app_model_config check, so we need to clean up
|
||||
db.session.rollback()
|
||||
final_workflow_count = db.session.query(Workflow).count()
|
||||
assert final_workflow_count == initial_workflow_count
|
||||
|
||||
def test_convert_app_model_config_to_workflow_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of app model config to workflow.
|
||||
|
||||
This test verifies:
|
||||
- Proper app model config to workflow conversion
|
||||
- Correct workflow graph structure
|
||||
- Proper node creation and configuration
|
||||
- Database state management
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant, account)
|
||||
|
||||
# Act: Execute the conversion
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow = workflow_converter.convert_app_model_config_to_workflow(
|
||||
app_model=app,
|
||||
app_model_config=app.app_model_config,
|
||||
account_id=account.id,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert workflow is not None
|
||||
assert workflow.tenant_id == app.tenant_id
|
||||
assert workflow.app_id == app.id
|
||||
assert workflow.type == "chat"
|
||||
assert workflow.version == Workflow.VERSION_DRAFT
|
||||
assert workflow.created_by == account.id
|
||||
|
||||
# Verify workflow graph structure
|
||||
graph = json.loads(workflow.graph)
|
||||
assert "nodes" in graph
|
||||
assert "edges" in graph
|
||||
assert len(graph["nodes"]) > 0
|
||||
assert len(graph["edges"]) > 0
|
||||
|
||||
# Verify start node exists
|
||||
start_node = next((node for node in graph["nodes"] if node["data"]["type"] == "start"), None)
|
||||
assert start_node is not None
|
||||
assert start_node["id"] == "start"
|
||||
|
||||
# Verify LLM node exists
|
||||
llm_node = next((node for node in graph["nodes"] if node["data"]["type"] == "llm"), None)
|
||||
assert llm_node is not None
|
||||
assert llm_node["id"] == "llm"
|
||||
|
||||
# Verify answer node exists for chat mode
|
||||
answer_node = next((node for node in graph["nodes"] if node["data"]["type"] == "answer"), None)
|
||||
assert answer_node is not None
|
||||
assert answer_node["id"] == "answer"
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(workflow)
|
||||
assert workflow.id is not None
|
||||
|
||||
# Verify features were set
|
||||
features = json.loads(workflow._features) if workflow._features else {}
|
||||
assert isinstance(features, dict)
|
||||
|
||||
def test_convert_to_start_node_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful conversion to start node.
|
||||
|
||||
This test verifies:
|
||||
- Proper start node creation with variables
|
||||
- Correct node structure and data
|
||||
- Variable encoding and formatting
|
||||
"""
|
||||
# Arrange: Create test variables
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="text_input",
|
||||
label="Text Input",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
),
|
||||
VariableEntity(
|
||||
variable="number_input",
|
||||
label="Number Input",
|
||||
type=VariableEntityType.NUMBER,
|
||||
),
|
||||
]
|
||||
|
||||
# Act: Execute the conversion
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(variables=variables)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert start_node is not None
|
||||
assert start_node["id"] == "start"
|
||||
assert start_node["data"]["title"] == "START"
|
||||
assert start_node["data"]["type"] == "start"
|
||||
assert len(start_node["data"]["variables"]) == 2
|
||||
|
||||
# Verify variable encoding
|
||||
first_variable = start_node["data"]["variables"][0]
|
||||
assert first_variable["variable"] == "text_input"
|
||||
assert first_variable["label"] == "Text Input"
|
||||
assert first_variable["type"] == "text-input"
|
||||
|
||||
second_variable = start_node["data"]["variables"][1]
|
||||
assert second_variable["variable"] == "number_input"
|
||||
assert second_variable["label"] == "Number Input"
|
||||
assert second_variable["type"] == "number"
|
||||
|
||||
def test_convert_to_http_request_node_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful conversion to HTTP request node.
|
||||
|
||||
This test verifies:
|
||||
- Proper HTTP request node creation
|
||||
- Correct API configuration and authorization
|
||||
- Code node creation for response parsing
|
||||
- External data variable mapping
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant, account)
|
||||
|
||||
# Create API based extension
|
||||
api_based_extension = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="Test API Extension",
|
||||
api_key="encrypted_api_key",
|
||||
api_endpoint="https://api.example.com/test",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(api_based_extension)
|
||||
db.session.commit()
|
||||
|
||||
# Mock encrypter
|
||||
mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key"
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="user_input",
|
||||
label="User Input",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
)
|
||||
]
|
||||
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(
|
||||
variable="external_data", type="api", config={"api_based_extension_id": api_based_extension.id}
|
||||
)
|
||||
]
|
||||
|
||||
# Act: Execute the conversion
|
||||
workflow_converter = WorkflowConverter()
|
||||
nodes, external_data_variable_node_mapping = workflow_converter._convert_to_http_request_node(
|
||||
app_model=app,
|
||||
variables=variables,
|
||||
external_data_variables=external_data_variables,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert len(nodes) == 2 # HTTP request node + code node
|
||||
assert len(external_data_variable_node_mapping) == 1
|
||||
|
||||
# Verify HTTP request node
|
||||
http_request_node = nodes[0]
|
||||
assert http_request_node["data"]["type"] == "http-request"
|
||||
assert http_request_node["data"]["method"] == "post"
|
||||
assert http_request_node["data"]["url"] == api_based_extension.api_endpoint
|
||||
assert http_request_node["data"]["authorization"]["type"] == "api-key"
|
||||
assert http_request_node["data"]["authorization"]["config"]["type"] == "bearer"
|
||||
assert http_request_node["data"]["authorization"]["config"]["api_key"] == "decrypted_api_key"
|
||||
|
||||
# Verify code node
|
||||
code_node = nodes[1]
|
||||
assert code_node["data"]["type"] == "code"
|
||||
assert code_node["data"]["code_language"] == "python3"
|
||||
assert "response_json" in code_node["data"]["variables"][0]["variable"]
|
||||
|
||||
# Verify mapping
|
||||
assert external_data_variable_node_mapping["external_data"] == code_node["id"]
|
||||
|
||||
def test_convert_to_knowledge_retrieval_node_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion to knowledge retrieval node.
|
||||
|
||||
This test verifies:
|
||||
- Proper knowledge retrieval node creation
|
||||
- Correct dataset configuration
|
||||
- Model configuration integration
|
||||
- Query variable selector setup
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create dataset config
|
||||
dataset_config = DatasetEntity(
|
||||
dataset_ids=["dataset_1", "dataset_2"],
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=10,
|
||||
score_threshold=0.8,
|
||||
reranking_model={"provider": "cohere", "model": "rerank-v2"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ModelConfigEntity(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
mode=LLMMode.CHAT,
|
||||
parameters={"temperature": 0.7},
|
||||
stop=[],
|
||||
)
|
||||
|
||||
# Act: Execute the conversion for advanced chat mode
|
||||
workflow_converter = WorkflowConverter()
|
||||
node = workflow_converter._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=AppMode.ADVANCED_CHAT,
|
||||
dataset_config=dataset_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert node is not None
|
||||
assert node["data"]["type"] == "knowledge-retrieval"
|
||||
assert node["data"]["title"] == "KNOWLEDGE RETRIEVAL"
|
||||
assert node["data"]["dataset_ids"] == ["dataset_1", "dataset_2"]
|
||||
assert node["data"]["retrieval_mode"] == "multiple"
|
||||
assert node["data"]["query_variable_selector"] == ["sys", "query"]
|
||||
|
||||
# Verify multiple retrieval config
|
||||
multiple_config = node["data"]["multiple_retrieval_config"]
|
||||
assert multiple_config["top_k"] == 10
|
||||
assert multiple_config["score_threshold"] == 0.8
|
||||
assert multiple_config["reranking_model"]["provider"] == "cohere"
|
||||
assert multiple_config["reranking_model"]["model"] == "rerank-v2"
|
||||
|
||||
# Verify single retrieval config is None for multiple strategy
|
||||
assert node["data"]["single_retrieval_config"] is None
|
||||
Reference in New Issue
Block a user