2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1 @@
# API authentication service test module

View File

@@ -0,0 +1,49 @@
import pytest
from services.auth.api_key_auth_base import ApiKeyAuthBase
class ConcreteApiKeyAuth(ApiKeyAuthBase):
"""Concrete implementation for testing abstract base class"""
def validate_credentials(self):
return True
class TestApiKeyAuthBase:
def test_should_store_credentials_on_init(self):
"""Test that credentials are properly stored during initialization"""
credentials = {"api_key": "test_key", "auth_type": "bearer"}
auth = ConcreteApiKeyAuth(credentials)
assert auth.credentials == credentials
def test_should_not_instantiate_abstract_class(self):
"""Test that ApiKeyAuthBase cannot be instantiated directly"""
credentials = {"api_key": "test_key"}
with pytest.raises(TypeError) as exc_info:
ApiKeyAuthBase(credentials)
assert "Can't instantiate abstract class" in str(exc_info.value)
assert "validate_credentials" in str(exc_info.value)
def test_should_allow_subclass_implementation(self):
"""Test that subclasses can properly implement the abstract method"""
credentials = {"api_key": "test_key", "auth_type": "bearer"}
auth = ConcreteApiKeyAuth(credentials)
# Should not raise any exception
result = auth.validate_credentials()
assert result is True
def test_should_handle_empty_credentials(self):
"""Test initialization with empty credentials"""
credentials = {}
auth = ConcreteApiKeyAuth(credentials)
assert auth.credentials == {}
def test_should_handle_none_credentials(self):
"""Test initialization with None credentials"""
credentials = None
auth = ConcreteApiKeyAuth(credentials)
assert auth.credentials is None
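
For orientation, the contract these tests pin down implies a base class along the following lines. This is a sketch inferred from the assertions above, not the actual services.auth.api_key_auth_base source:

# Sketch only: reconstructed from the test assertions above.
from abc import ABC, abstractmethod


class ApiKeyAuthBase(ABC):
    def __init__(self, credentials):
        # Stored as-is; the tests pass dicts, {} and None.
        self.credentials = credentials

    @abstractmethod
    def validate_credentials(self):
        """Return True if the stored credentials are valid."""
        raise NotImplementedError

Instantiating this directly raises TypeError ("Can't instantiate abstract class ... validate_credentials"), which is exactly what test_should_not_instantiate_abstract_class asserts.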

View File

@@ -0,0 +1,81 @@
from unittest.mock import MagicMock, patch
import pytest
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.auth_type import AuthType
class TestApiKeyAuthFactory:
"""Test cases for ApiKeyAuthFactory"""
@pytest.mark.parametrize(
("provider", "auth_class_path"),
[
(AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
(AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
(AuthType.JINA, "services.auth.jina.jina.JinaAuth"),
],
)
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
"""Test getting auth factory for all valid providers"""
with patch(auth_class_path) as mock_auth:
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
assert auth_class == mock_auth
@pytest.mark.parametrize(
"invalid_provider",
[
"invalid_provider",
"",
None,
123,
"UNSUPPORTED",
],
)
def test_get_apikey_auth_factory_invalid_providers(self, invalid_provider):
"""Test getting auth factory with various invalid providers"""
with pytest.raises(ValueError) as exc_info:
ApiKeyAuthFactory.get_apikey_auth_factory(invalid_provider)
assert str(exc_info.value) == "Invalid provider"
@pytest.mark.parametrize(
("credentials_return_value", "expected_result"),
[
(True, True),
(False, False),
],
)
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
def test_validate_credentials_delegates_to_auth_instance(
self, mock_get_factory, credentials_return_value, expected_result
):
"""Test that validate_credentials delegates to auth instance correctly"""
# Arrange
mock_auth_instance = MagicMock()
mock_auth_instance.validate_credentials.return_value = credentials_return_value
mock_auth_class = MagicMock(return_value=mock_auth_instance)
mock_get_factory.return_value = mock_auth_class
# Act
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
result = factory.validate_credentials()
# Assert
assert result is expected_result
mock_auth_instance.validate_credentials.assert_called_once()
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
def test_validate_credentials_propagates_exceptions(self, mock_get_factory):
"""Test that exceptions from auth instance are propagated"""
# Arrange
mock_auth_instance = MagicMock()
mock_auth_instance.validate_credentials.side_effect = Exception("Authentication error")
mock_auth_class = MagicMock(return_value=mock_auth_instance)
mock_get_factory.return_value = mock_auth_class
# Act & Assert
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
with pytest.raises(Exception) as exc_info:
factory.validate_credentials()
assert str(exc_info.value) == "Authentication error"
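
Taken together, the patch targets and assertions above imply a factory shaped roughly like this sketch (the import paths are the ones the tests patch; the real implementation may differ):

# Sketch only: shape inferred from the tests above.
from services.auth.auth_type import AuthType


class ApiKeyAuthFactory:
    def __init__(self, provider, credentials):
        auth_class = self.get_apikey_auth_factory(provider)
        self.auth = auth_class(credentials)

    def validate_credentials(self):
        # Pure delegation: the provider-specific instance does the work.
        return self.auth.validate_credentials()

    @staticmethod
    def get_apikey_auth_factory(provider):
        # Imports are deferred to call time, which is also why the tests can
        # patch e.g. services.auth.firecrawl.firecrawl.FirecrawlAuth and have
        # the factory return the mock.
        if provider == AuthType.FIRECRAWL:
            from services.auth.firecrawl.firecrawl import FirecrawlAuth
            return FirecrawlAuth
        if provider == AuthType.WATERCRAWL:
            from services.auth.watercrawl.watercrawl import WatercrawlAuth
            return WatercrawlAuth
        if provider == AuthType.JINA:
            from services.auth.jina.jina import JinaAuth
            return JinaAuth
        raise ValueError("Invalid provider")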

View File

@@ -0,0 +1,387 @@
import json
from unittest.mock import Mock, patch
import pytest
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_service import ApiKeyAuthService
class TestApiKeyAuthService:
"""API key authentication service security tests"""
def setup_method(self):
"""Setup test fixtures"""
self.tenant_id = "test_tenant_123"
self.category = "search"
self.provider = "google"
self.binding_id = "binding_123"
self.mock_credentials = {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}}
self.mock_args = {"category": self.category, "provider": self.provider, "credentials": self.mock_credentials}
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_success(self, mock_session):
"""Test get provider auth list - success scenario"""
# Mock database query result
mock_binding = Mock()
mock_binding.tenant_id = self.tenant_id
mock_binding.provider = self.provider
mock_binding.disabled = False
mock_session.scalars.return_value.all.return_value = [mock_binding]
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
assert len(result) == 1
assert result[0].tenant_id == self.tenant_id
assert mock_session.scalars.call_count == 1
select_arg = mock_session.scalars.call_args[0][0]
assert "data_source_api_key_auth_binding" in str(select_arg).lower()
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_empty(self, mock_session):
"""Test get provider auth list - empty result"""
mock_session.scalars.return_value.all.return_value = []
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
assert result == []
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_filters_disabled(self, mock_session):
"""Test get provider auth list - filters disabled items"""
mock_session.scalars.return_value.all.return_value = []
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
select_stmt = mock_session.scalars.call_args[0][0]
where_clauses = list(getattr(select_stmt, "_where_criteria", []) or [])
# Ensure both tenant filter and disabled filter exist
where_strs = [str(c).lower() for c in where_clauses]
assert any("tenant_id" in s for s in where_strs)
assert any("disabled" in s for s in where_strs)
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_success(self, mock_encrypter, mock_factory, mock_session):
"""Test create provider auth - success scenario"""
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
# Mock encryption
encrypted_key = "encrypted_test_key_123"
mock_encrypter.encrypt_token.return_value = encrypted_key
# Mock database operations
mock_session.add = Mock()
mock_session.commit = Mock()
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
# Verify factory class calls
mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
mock_auth_instance.validate_credentials.assert_called_once()
# Verify encryption calls
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, "test_secret_key_123")
# Verify database operations
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_validation_failed(self, mock_factory, mock_session):
"""Test create provider auth - validation failed"""
# Mock failed auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = False
mock_factory.return_value = mock_auth_instance
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
# Verify no database operations when validation fails
mock_session.add.assert_not_called()
mock_session.commit.assert_not_called()
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_encrypts_api_key(self, mock_encrypter, mock_factory, mock_session):
"""Test create provider auth - ensures API key is encrypted"""
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
# Mock encryption
encrypted_key = "encrypted_test_key_123"
mock_encrypter.encrypt_token.return_value = encrypted_key
# Mock database operations
mock_session.add = Mock()
mock_session.commit = Mock()
args_copy = self.mock_args.copy()
original_key = args_copy["credentials"]["config"]["api_key"]
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
# Verify original key is replaced with encrypted key
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
assert args_copy["credentials"]["config"]["api_key"] != original_key
# Verify encryption function is called correctly
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_success(self, mock_session):
"""Test get auth credentials - success scenario"""
# Mock database query result
mock_binding = Mock()
mock_binding.credentials = json.dumps(self.mock_credentials)
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
assert result == self.mock_credentials
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_not_found(self, mock_session):
"""Test get auth credentials - not found"""
mock_session.query.return_value.where.return_value.first.return_value = None
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
assert result is None
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_filters_correctly(self, mock_session):
"""Test get auth credentials - applies correct filters"""
mock_session.query.return_value.where.return_value.first.return_value = None
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
# Verify where conditions are correct
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 4 # tenant_id, category, provider, disabled
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_json_parsing(self, mock_session):
"""Test get auth credentials - JSON parsing"""
# Mock credentials with special characters
special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}}
mock_binding = Mock()
mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False)
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
assert result == special_credentials
assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
@patch("services.auth.api_key_auth_service.db.session")
def test_delete_provider_auth_success(self, mock_session):
"""Test delete provider auth - success scenario"""
# Mock database query result
mock_binding = Mock()
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
# Verify delete operations
mock_session.delete.assert_called_once_with(mock_binding)
mock_session.commit.assert_called_once()
@patch("services.auth.api_key_auth_service.db.session")
def test_delete_provider_auth_not_found(self, mock_session):
"""Test delete provider auth - not found"""
mock_session.query.return_value.where.return_value.first.return_value = None
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
# Verify no delete operations when not found
mock_session.delete.assert_not_called()
mock_session.commit.assert_not_called()
@patch("services.auth.api_key_auth_service.db.session")
def test_delete_provider_auth_filters_by_tenant(self, mock_session):
"""Test delete provider auth - filters by tenant"""
mock_session.query.return_value.where.return_value.first.return_value = None
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
# Verify where conditions include tenant_id and binding_id
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 2
def test_validate_api_key_auth_args_success(self):
"""Test API key auth args validation - success scenario"""
# Should not raise any exception
ApiKeyAuthService.validate_api_key_auth_args(self.mock_args)
def test_validate_api_key_auth_args_missing_category(self):
"""Test API key auth args validation - missing category"""
args = self.mock_args.copy()
del args["category"]
with pytest.raises(ValueError, match="category is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_category(self):
"""Test API key auth args validation - empty category"""
args = self.mock_args.copy()
args["category"] = ""
with pytest.raises(ValueError, match="category is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_missing_provider(self):
"""Test API key auth args validation - missing provider"""
args = self.mock_args.copy()
del args["provider"]
with pytest.raises(ValueError, match="provider is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_provider(self):
"""Test API key auth args validation - empty provider"""
args = self.mock_args.copy()
args["provider"] = ""
with pytest.raises(ValueError, match="provider is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_missing_credentials(self):
"""Test API key auth args validation - missing credentials"""
args = self.mock_args.copy()
del args["credentials"]
with pytest.raises(ValueError, match="credentials is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_credentials(self):
"""Test API key auth args validation - empty credentials"""
args = self.mock_args.copy()
args["credentials"] = None
with pytest.raises(ValueError, match="credentials is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_invalid_credentials_type(self):
"""Test API key auth args validation - invalid credentials type"""
args = self.mock_args.copy()
args["credentials"] = "not_a_dict"
with pytest.raises(ValueError, match="credentials must be a dictionary"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_missing_auth_type(self):
"""Test API key auth args validation - missing auth_type"""
args = self.mock_args.copy()
del args["credentials"]["auth_type"]
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
def test_validate_api_key_auth_args_empty_auth_type(self):
"""Test API key auth args validation - empty auth_type"""
args = self.mock_args.copy()
args["credentials"]["auth_type"] = ""
with pytest.raises(ValueError, match="auth_type is required"):
ApiKeyAuthService.validate_api_key_auth_args(args)
@pytest.mark.parametrize(
"malicious_input",
[
"<script>alert('xss')</script>",
"'; DROP TABLE users; --",
"../../../etc/passwd",
"\\x00\\x00", # null bytes
"A" * 10000, # very long input
],
)
def test_validate_api_key_auth_args_malicious_input(self, malicious_input):
"""Test API key auth args validation - malicious input"""
args = self.mock_args.copy()
args["category"] = malicious_input
# Verify parameter validator doesn't crash on malicious input
# Should validate normally rather than raising security-related exceptions
ApiKeyAuthService.validate_api_key_auth_args(args)
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_database_error_handling(self, mock_encrypter, mock_factory, mock_session):
"""Test create provider auth - database error handling"""
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
# Mock encryption
mock_encrypter.encrypt_token.return_value = "encrypted_key"
# Mock database error
mock_session.commit.side_effect = Exception("Database error")
with pytest.raises(Exception, match="Database error"):
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_invalid_json(self, mock_session):
"""Test get auth credentials - invalid JSON"""
# Mock database returning invalid JSON
mock_binding = Mock()
mock_binding.credentials = "invalid json content"
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
with pytest.raises(json.JSONDecodeError):
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_factory_exception(self, mock_factory, mock_session):
"""Test create provider auth - factory exception"""
# Mock factory raising exception
mock_factory.side_effect = Exception("Factory error")
with pytest.raises(Exception, match="Factory error"):
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session):
"""Test create provider auth - encryption exception"""
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
# Mock encryption exception
mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")
with pytest.raises(Exception, match="Encryption error"):
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
def test_validate_api_key_auth_args_none_input(self):
"""Test API key auth args validation - None input"""
with pytest.raises(TypeError):
ApiKeyAuthService.validate_api_key_auth_args(None)
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
"""Test API key auth args validation - dict credentials with list auth_type"""
args = self.mock_args.copy()
args["credentials"]["auth_type"] = ["api_key"]
# The current implementation only checks that auth_type exists and is truthy;
# a non-empty list like ["api_key"] is truthy, so no exception is raised.
ApiKeyAuthService.validate_api_key_auth_args(args)
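
The validation and creation behaviour asserted above is consistent with a service shaped like the sketch below. Import paths other than the ones the tests patch are assumptions, and the check ordering simply mirrors the tests (None input fails with TypeError on the first membership test):

# Sketch only: reconstructed from the tests; import paths are assumed.
import json

from extensions.ext_database import db  # assumed path
from libs import encrypter  # assumed path; tests patch the module attribute
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory


class ApiKeyAuthService:
    @staticmethod
    def validate_api_key_auth_args(args):
        if "category" not in args or not args["category"]:
            raise ValueError("category is required")
        if "provider" not in args or not args["provider"]:
            raise ValueError("provider is required")
        if "credentials" not in args or not args["credentials"]:
            raise ValueError("credentials is required")
        if not isinstance(args["credentials"], dict):
            raise ValueError("credentials must be a dictionary")
        if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
            raise ValueError("auth_type is required")

    @staticmethod
    def create_provider_auth(tenant_id, args):
        auth = ApiKeyAuthFactory(args["provider"], args["credentials"])
        if not auth.validate_credentials():
            return  # nothing is persisted when validation fails
        # The raw key is replaced in place with its encrypted form, which is
        # exactly what test_create_provider_auth_encrypts_api_key asserts.
        api_key = args["credentials"]["config"]["api_key"]
        args["credentials"]["config"]["api_key"] = encrypter.encrypt_token(tenant_id, api_key)
        binding = DataSourceApiKeyAuthBinding(
            tenant_id=tenant_id,
            category=args["category"],
            provider=args["provider"],
            credentials=json.dumps(args["credentials"], ensure_ascii=False),
        )
        db.session.add(binding)
        db.session.commit()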

View File

@@ -0,0 +1,231 @@
"""
API Key Authentication System Integration Tests
"""
import json
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, patch
import httpx
import pytest
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.api_key_auth_service import ApiKeyAuthService
from services.auth.auth_type import AuthType
class TestAuthIntegration:
def setup_method(self):
self.tenant_id_1 = "tenant_123"
self.tenant_id_2 = "tenant_456" # For multi-tenant isolation testing
self.category = "search"
# Realistic authentication configurations
self.firecrawl_credentials = {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}
self.jina_credentials = {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}}
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
"""Test complete authentication flow: request → validation → encryption → storage"""
mock_http.return_value = self._create_success_response()
mock_encrypt.return_value = "encrypted_fc_test_key_123"
mock_session.add = Mock()
mock_session.commit = Mock()
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
mock_http.assert_called_once()
call_args = mock_http.call_args
assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0]
assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123"
mock_encrypt.assert_called_once_with(self.tenant_id_1, "fc_test_key_123")
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_cross_component_integration(self, mock_http):
"""Test factory → provider → HTTP call integration"""
mock_http.return_value = self._create_success_response()
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
result = factory.validate_credentials()
assert result is True
mock_http.assert_called_once()
@patch("services.auth.api_key_auth_service.db.session")
def test_multi_tenant_isolation(self, mock_session):
"""Ensure complete tenant data isolation"""
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
mock_session.scalars.return_value.all.return_value = [tenant1_binding]
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
mock_session.scalars.return_value.all.return_value = [tenant2_binding]
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
assert len(result1) == 1
assert result1[0].tenant_id == self.tenant_id_1
assert len(result2) == 1
assert result2[0].tenant_id == self.tenant_id_2
@patch("services.auth.api_key_auth_service.db.session")
def test_cross_tenant_access_prevention(self, mock_session):
"""Test prevention of cross-tenant credential access"""
mock_session.query.return_value.where.return_value.first.return_value = None
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL)
assert result is None
def test_sensitive_data_protection(self):
"""Ensure API keys don't leak to logs"""
credentials_with_secrets = {
"auth_type": "bearer",
"config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"},
}
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets)
factory_str = str(factory)
assert "super_secret_key_do_not_log" not in factory_str
assert "another_secret" not in factory_str
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
"""Test concurrent authentication creation safety"""
mock_http.return_value = self._create_success_response()
mock_encrypt.return_value = "encrypted_key"
mock_session.add = Mock()
mock_session.commit = Mock()
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
results = []
exceptions = []
def create_auth():
try:
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
results.append("success")
except Exception as e:
exceptions.append(e)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(create_auth) for _ in range(5)]
for future in futures:
future.result()
assert len(results) == 5
assert len(exceptions) == 0
assert mock_session.add.call_count == 5
assert mock_session.commit.call_count == 5
@pytest.mark.parametrize(
"invalid_input",
[
None, # Null input
{}, # Empty dictionary - missing required fields
{"auth_type": "bearer"}, # Missing config section
{"auth_type": "bearer", "config": {}}, # Missing api_key
],
)
def test_invalid_input_boundary(self, invalid_input):
"""Test boundary handling for invalid inputs"""
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_http_error_handling(self, mock_http):
"""Test proper HTTP error handling"""
mock_response = Mock()
mock_response.status_code = 401
mock_response.text = '{"error": "Unauthorized"}'
mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
mock_http.return_value = mock_response
# PT012: keep a single statement inside the pytest.raises block, so construct the factory first
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
with pytest.raises((httpx.HTTPError, Exception)):
factory.validate_credentials()
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_network_failure_recovery(self, mock_http, mock_session):
"""Test system recovery from network failures"""
mock_http.side_effect = httpx.RequestError("Network timeout")
mock_session.add = Mock()
mock_session.commit = Mock()
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
with pytest.raises(httpx.RequestError):
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
mock_session.commit.assert_not_called()
@pytest.mark.parametrize(
("provider", "credentials"),
[
(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}),
(AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}),
(AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}),
],
)
def test_all_providers_factory_creation(self, provider, credentials):
"""Test factory creation for all supported providers"""
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
assert auth_class is not None
factory = ApiKeyAuthFactory(provider, credentials)
assert factory.auth is not None
def _create_success_response(self, status_code=200):
"""Create successful HTTP response mock"""
mock_response = Mock()
mock_response.status_code = status_code
mock_response.json.return_value = {"status": "success"}
mock_response.raise_for_status.return_value = None
return mock_response
def _create_mock_binding(self, tenant_id: str, provider: str, credentials: dict) -> Mock:
"""Create realistic database binding mock"""
mock_binding = Mock()
mock_binding.id = f"binding_{provider}_{tenant_id}"
mock_binding.tenant_id = tenant_id
mock_binding.category = self.category
mock_binding.provider = provider
mock_binding.credentials = json.dumps(credentials, ensure_ascii=False)
mock_binding.disabled = False
mock_binding.created_at = Mock()
mock_binding.created_at.timestamp.return_value = 1640995200
mock_binding.updated_at = Mock()
mock_binding.updated_at.timestamp.return_value = 1640995200
return mock_binding
def test_integration_coverage_validation(self):
"""Validate integration test coverage meets quality standards"""
core_scenarios = {
"business_logic": ["end_to_end_auth_flow", "cross_component_integration"],
"security": ["multi_tenant_isolation", "cross_tenant_access_prevention", "sensitive_data_protection"],
"reliability": ["concurrent_creation_safety", "network_failure_recovery"],
"compatibility": ["all_providers_factory_creation"],
"boundaries": ["invalid_input_boundary", "http_error_handling"],
}
total_scenarios = sum(len(scenarios) for scenarios in core_scenarios.values())
assert total_scenarios >= 10
security_tests = core_scenarios["security"]
assert "multi_tenant_isolation" in security_tests
assert "sensitive_data_protection" in security_tests
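
A minimal usage sketch of the pieces these integration tests exercise, with the outbound HTTP call mocked the same way the tests do (names as assumed above):

# Sketch only: demonstrates the factory -> provider -> HTTP probe flow.
from unittest.mock import Mock, patch

from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.auth_type import AuthType

creds = {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}
with patch("services.auth.firecrawl.firecrawl.httpx.post") as post:
    post.return_value = Mock(status_code=200)
    # The factory resolves the provider class, which then performs the
    # (mocked) HTTP probe against the Firecrawl crawl endpoint.
    assert ApiKeyAuthFactory(AuthType.FIRECRAWL, creds).validate_credentials() is True
    post.assert_called_once()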

View File

@@ -0,0 +1,150 @@
import pytest
from services.auth.auth_type import AuthType
class TestAuthType:
"""Test cases for AuthType enum"""
def test_auth_type_is_str_enum(self):
"""Test that AuthType is properly a StrEnum"""
assert issubclass(AuthType, str)
assert hasattr(AuthType, "__members__")
def test_auth_type_has_expected_values(self):
"""Test that all expected auth types exist with correct values"""
expected_values = {
"FIRECRAWL": "firecrawl",
"WATERCRAWL": "watercrawl",
"JINA": "jinareader",
}
# Verify all expected members exist
for member_name, expected_value in expected_values.items():
assert hasattr(AuthType, member_name)
assert getattr(AuthType, member_name).value == expected_value
# Verify no extra members exist
assert len(AuthType) == len(expected_values)
@pytest.mark.parametrize(
("auth_type", "expected_string"),
[
(AuthType.FIRECRAWL, "firecrawl"),
(AuthType.WATERCRAWL, "watercrawl"),
(AuthType.JINA, "jinareader"),
],
)
def test_auth_type_string_representation(self, auth_type, expected_string):
"""Test string representation of auth types"""
assert str(auth_type) == expected_string
assert auth_type.value == expected_string
@pytest.mark.parametrize(
("auth_type", "compare_value", "expected_result"),
[
(AuthType.FIRECRAWL, "firecrawl", True),
(AuthType.WATERCRAWL, "watercrawl", True),
(AuthType.JINA, "jinareader", True),
(AuthType.FIRECRAWL, "FIRECRAWL", False), # Case sensitive
(AuthType.FIRECRAWL, "watercrawl", False),
(AuthType.JINA, "jina", False), # Full value mismatch
],
)
def test_auth_type_comparison(self, auth_type, compare_value, expected_result):
"""Test auth type comparison with strings"""
assert (auth_type == compare_value) is expected_result
def test_auth_type_iteration(self):
"""Test that AuthType can be iterated over"""
auth_types = list(AuthType)
assert len(auth_types) == 3
assert AuthType.FIRECRAWL in auth_types
assert AuthType.WATERCRAWL in auth_types
assert AuthType.JINA in auth_types
def test_auth_type_membership(self):
"""Test membership checking for AuthType"""
assert "firecrawl" in [auth.value for auth in AuthType]
assert "watercrawl" in [auth.value for auth in AuthType]
assert "jinareader" in [auth.value for auth in AuthType]
assert "invalid" not in [auth.value for auth in AuthType]
def test_auth_type_invalid_attribute_access(self):
"""Test accessing non-existent auth type raises AttributeError"""
with pytest.raises(AttributeError):
_ = AuthType.INVALID_TYPE
def test_auth_type_immutability(self):
"""Test that enum values cannot be modified"""
# Enum members are read-only; reassignment raises AttributeError
with pytest.raises(AttributeError):
AuthType.FIRECRAWL = "modified"
def test_auth_type_from_value(self):
"""Test creating AuthType from string value"""
assert AuthType("firecrawl") == AuthType.FIRECRAWL
assert AuthType("watercrawl") == AuthType.WATERCRAWL
assert AuthType("jinareader") == AuthType.JINA
# Test invalid value
with pytest.raises(ValueError) as exc_info:
AuthType("invalid_auth_type")
assert "invalid_auth_type" in str(exc_info.value)
def test_auth_type_name_property(self):
"""Test the name property of enum members"""
assert AuthType.FIRECRAWL.name == "FIRECRAWL"
assert AuthType.WATERCRAWL.name == "WATERCRAWL"
assert AuthType.JINA.name == "JINA"
@pytest.mark.parametrize(
"auth_type",
[AuthType.FIRECRAWL, AuthType.WATERCRAWL, AuthType.JINA],
)
def test_auth_type_isinstance_checks(self, auth_type):
"""Test isinstance checks for auth types"""
assert isinstance(auth_type, AuthType)
assert isinstance(auth_type, str)
assert isinstance(auth_type.value, str)
def test_auth_type_hash(self):
"""Test that auth types are hashable and can be used in sets/dicts"""
auth_set = {AuthType.FIRECRAWL, AuthType.WATERCRAWL, AuthType.JINA}
assert len(auth_set) == 3
auth_dict = {
AuthType.FIRECRAWL: "firecrawl_handler",
AuthType.WATERCRAWL: "watercrawl_handler",
AuthType.JINA: "jina_handler",
}
assert auth_dict[AuthType.FIRECRAWL] == "firecrawl_handler"
def test_auth_type_json_serializable(self):
"""Test that auth types can be JSON serialized"""
import json
auth_data = {
"provider": AuthType.FIRECRAWL,
"enabled": True,
}
# Should serialize to string value
json_str = json.dumps(auth_data, default=str)
assert '"provider": "firecrawl"' in json_str
def test_auth_type_matches_factory_usage(self):
"""Test that all AuthType values are handled by ApiKeyAuthFactory"""
# This test verifies that the enum values match what's expected
# by the factory implementation
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
for auth_type in AuthType:
# Should not raise ValueError for valid auth types
try:
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(auth_type)
assert auth_class is not None
except ImportError:
# It's OK if the actual auth implementation doesn't exist
# We're just testing that the enum value is recognized
pass
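
The enum shape implied by these tests is small enough to restate in full; a sketch that should match the real services.auth.auth_type module value-for-value (StrEnum is stdlib in Python 3.11+):

# Sketch only: values taken directly from the assertions above.
from enum import StrEnum


class AuthType(StrEnum):
    FIRECRAWL = "firecrawl"
    WATERCRAWL = "watercrawl"
    JINA = "jinareader"  # note: member JINA maps to "jinareader", not "jina"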

View File

@@ -0,0 +1,191 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
from services.auth.firecrawl.firecrawl import FirecrawlAuth
class TestFirecrawlAuth:
@pytest.fixture
def valid_credentials(self):
"""Fixture for valid bearer credentials"""
return {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
@pytest.fixture
def auth_instance(self, valid_credentials):
"""Fixture for FirecrawlAuth instance with valid credentials"""
return FirecrawlAuth(valid_credentials)
def test_should_initialize_with_valid_bearer_credentials(self, valid_credentials):
"""Test successful initialization with valid bearer credentials"""
auth = FirecrawlAuth(valid_credentials)
assert auth.api_key == "test_api_key_123"
assert auth.base_url == "https://api.firecrawl.dev"
assert auth.credentials == valid_credentials
def test_should_initialize_with_custom_base_url(self):
"""Test initialization with custom base URL"""
credentials = {
"auth_type": "bearer",
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
}
auth = FirecrawlAuth(credentials)
assert auth.api_key == "test_api_key_123"
assert auth.base_url == "https://custom.firecrawl.dev"
@pytest.mark.parametrize(
("auth_type", "expected_error"),
[
("basic", "Invalid auth type, Firecrawl auth type must be Bearer"),
("x-api-key", "Invalid auth type, Firecrawl auth type must be Bearer"),
("", "Invalid auth type, Firecrawl auth type must be Bearer"),
],
)
def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error):
"""Test that non-bearer auth types raise ValueError"""
credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}}
with pytest.raises(ValueError) as exc_info:
FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@pytest.mark.parametrize(
("credentials", "expected_error"),
[
({"auth_type": "bearer", "config": {}}, "No API key provided"),
({"auth_type": "bearer"}, "No API key provided"),
({"auth_type": "bearer", "config": {"api_key": ""}}, "No API key provided"),
({"auth_type": "bearer", "config": {"api_key": None}}, "No API key provided"),
],
)
def test_should_raise_error_for_missing_api_key(self, credentials, expected_error):
"""Test that missing or empty API key raises ValueError"""
with pytest.raises(ValueError) as exc_info:
FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
result = auth_instance.validate_credentials()
assert result is True
expected_data = {
"url": "https://example.com",
"includePaths": [],
"excludePaths": [],
"limit": 1,
"scrapeOptions": {"onlyMainContent": True},
}
mock_post.assert_called_once_with(
"https://api.firecrawl.dev/v1/crawl",
headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"},
json=expected_data,
)
@pytest.mark.parametrize(
("status_code", "error_message"),
[
(402, "Payment required"),
(409, "Conflict error"),
(500, "Internal server error"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.json.return_value = {"error": error_message}
mock_post.return_value = mock_response
with pytest.raises(Exception) as exc_info:
auth_instance.validate_credentials()
assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}"
@pytest.mark.parametrize(
("status_code", "response_text", "has_json_error", "expected_error_contains"),
[
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
(401, "Not JSON", True, "Expecting value"), # JSON decode error
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_unexpected_errors(
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
"""Test handling of unexpected errors with various response formats"""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.text = response_text
if has_json_error:
mock_response.json.side_effect = Exception("Not JSON")
mock_post.return_value = mock_response
with pytest.raises(Exception) as exc_info:
auth_instance.validate_credentials()
assert expected_error_contains in str(exc_info.value)
@pytest.mark.parametrize(
("exception_type", "exception_message"),
[
(httpx.ConnectError, "Network error"),
(httpx.TimeoutException, "Request timeout"),
(httpx.ReadTimeout, "Read timeout"),
(httpx.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_post.side_effect = exception_type(exception_message)
with pytest.raises(exception_type) as exc_info:
auth_instance.validate_credentials()
assert exception_message in str(exc_info.value)
def test_should_not_expose_api_key_in_error_messages(self):
"""Test that API key is not exposed in error messages"""
credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}
auth = FirecrawlAuth(credentials)
# Verify API key is stored but not in any error message
assert auth.api_key == "super_secret_key_12345"
# Test various error scenarios don't expose the key
with pytest.raises(ValueError) as exc_info:
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
credentials = {
"auth_type": "bearer",
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
}
auth = FirecrawlAuth(credentials)
result = auth.validate_credentials()
assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message
assert "timed out" in str(exc_info.value)

View File

@@ -0,0 +1,155 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
from services.auth.jina.jina import JinaAuth
class TestJinaAuth:
def test_should_initialize_with_valid_bearer_credentials(self):
"""Test successful initialization with valid bearer credentials"""
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
assert auth.api_key == "test_api_key_123"
assert auth.credentials == credentials
def test_should_raise_error_for_invalid_auth_type(self):
"""Test that non-bearer auth type raises ValueError"""
credentials = {"auth_type": "basic", "config": {"api_key": "test_api_key_123"}}
with pytest.raises(ValueError) as exc_info:
JinaAuth(credentials)
assert str(exc_info.value) == "Invalid auth type, Jina Reader auth type must be Bearer"
def test_should_raise_error_for_missing_api_key(self):
"""Test that missing API key raises ValueError"""
credentials = {"auth_type": "bearer", "config": {}}
with pytest.raises(ValueError) as exc_info:
JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided"
def test_should_raise_error_for_missing_config(self):
"""Test that missing config section raises ValueError"""
credentials = {"auth_type": "bearer"}
with pytest.raises(ValueError) as exc_info:
JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided"
@patch("services.auth.jina.jina.httpx.post")
def test_should_validate_valid_credentials_successfully(self, mock_post):
"""Test successful credential validation"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
result = auth.validate_credentials()
assert result is True
mock_post.assert_called_once_with(
"https://r.jina.ai",
headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"},
json={"url": "https://example.com"},
)
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_402_error(self, mock_post):
"""Test handling of 402 Payment Required error"""
mock_response = MagicMock()
mock_response.status_code = 402
mock_response.json.return_value = {"error": "Payment required"}
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(Exception) as exc_info:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_409_error(self, mock_post):
"""Test handling of 409 Conflict error"""
mock_response = MagicMock()
mock_response.status_code = 409
mock_response.json.return_value = {"error": "Conflict error"}
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(Exception) as exc_info:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_500_error(self, mock_post):
"""Test handling of 500 Internal Server Error"""
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.json.return_value = {"error": "Internal server error"}
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(Exception) as exc_info:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
"""Test handling of unexpected errors with text response"""
mock_response = MagicMock()
mock_response.status_code = 403
mock_response.text = '{"error": "Forbidden"}'
mock_response.json.side_effect = Exception("Not JSON")
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(Exception) as exc_info:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_unexpected_error_without_text(self, mock_post):
"""Test handling of unexpected errors without text response"""
mock_response = MagicMock()
mock_response.status_code = 404
mock_response.text = ""
mock_response.json.side_effect = Exception("Not JSON")
mock_post.return_value = mock_response
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(Exception) as exc_info:
auth.validate_credentials()
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
@patch("services.auth.jina.jina.httpx.post")
def test_should_handle_network_errors(self, mock_post):
"""Test handling of network connection errors"""
mock_post.side_effect = httpx.ConnectError("Network error")
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials)
with pytest.raises(httpx.ConnectError):
auth.validate_credentials()
def test_should_not_expose_api_key_in_error_messages(self):
"""Test that API key is not exposed in error messages"""
credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}
auth = JinaAuth(credentials)
# Verify API key is stored but not in any error message
assert auth.api_key == "super_secret_key_12345"
# Test various error scenarios don't expose the key
with pytest.raises(ValueError) as exc_info:
JinaAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)

View File

@@ -0,0 +1,205 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
from services.auth.watercrawl.watercrawl import WatercrawlAuth
class TestWatercrawlAuth:
@pytest.fixture
def valid_credentials(self):
"""Fixture for valid x-api-key credentials"""
return {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123"}}
@pytest.fixture
def auth_instance(self, valid_credentials):
"""Fixture for WatercrawlAuth instance with valid credentials"""
return WatercrawlAuth(valid_credentials)
def test_should_initialize_with_valid_x_api_key_credentials(self, valid_credentials):
"""Test successful initialization with valid x-api-key credentials"""
auth = WatercrawlAuth(valid_credentials)
assert auth.api_key == "test_api_key_123"
assert auth.base_url == "https://app.watercrawl.dev"
assert auth.credentials == valid_credentials
def test_should_initialize_with_custom_base_url(self):
"""Test initialization with custom base URL"""
credentials = {
"auth_type": "x-api-key",
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"},
}
auth = WatercrawlAuth(credentials)
assert auth.api_key == "test_api_key_123"
assert auth.base_url == "https://custom.watercrawl.dev"
@pytest.mark.parametrize(
("auth_type", "expected_error"),
[
("bearer", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
("basic", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
("", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
],
)
def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error):
"""Test that non-x-api-key auth types raise ValueError"""
credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}}
with pytest.raises(ValueError) as exc_info:
WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@pytest.mark.parametrize(
("credentials", "expected_error"),
[
({"auth_type": "x-api-key", "config": {}}, "No API key provided"),
({"auth_type": "x-api-key"}, "No API key provided"),
({"auth_type": "x-api-key", "config": {"api_key": ""}}, "No API key provided"),
({"auth_type": "x-api-key", "config": {"api_key": None}}, "No API key provided"),
],
)
def test_should_raise_error_for_missing_api_key(self, credentials, expected_error):
"""Test that missing or empty API key raises ValueError"""
with pytest.raises(ValueError) as exc_info:
WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_get.return_value = mock_response
result = auth_instance.validate_credentials()
assert result is True
mock_get.assert_called_once_with(
"https://app.watercrawl.dev/api/v1/core/crawl-requests/",
headers={"Content-Type": "application/json", "X-API-KEY": "test_api_key_123"},
)
@pytest.mark.parametrize(
("status_code", "error_message"),
[
(402, "Payment required"),
(409, "Conflict error"),
(500, "Internal server error"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.json.return_value = {"error": error_message}
mock_get.return_value = mock_response
with pytest.raises(Exception) as exc_info:
auth_instance.validate_credentials()
assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}"
@pytest.mark.parametrize(
("status_code", "response_text", "has_json_error", "expected_error_contains"),
[
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
(401, "Not JSON", True, "Expecting value"), # JSON decode error
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_unexpected_errors(
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
"""Test handling of unexpected errors with various response formats"""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.text = response_text
if has_json_error:
mock_response.json.side_effect = Exception("Not JSON")
mock_get.return_value = mock_response
with pytest.raises(Exception) as exc_info:
auth_instance.validate_credentials()
assert expected_error_contains in str(exc_info.value)
@pytest.mark.parametrize(
("exception_type", "exception_message"),
[
(httpx.ConnectError, "Network error"),
(httpx.TimeoutException, "Request timeout"),
(httpx.ReadTimeout, "Read timeout"),
(httpx.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_get.side_effect = exception_type(exception_message)
with pytest.raises(exception_type) as exc_info:
auth_instance.validate_credentials()
assert exception_message in str(exc_info.value)
def test_should_not_expose_api_key_in_error_messages(self):
"""Test that API key is not exposed in error messages"""
credentials = {"auth_type": "x-api-key", "config": {"api_key": "super_secret_key_12345"}}
auth = WatercrawlAuth(credentials)
# Verify API key is stored but not in any error message
assert auth.api_key == "super_secret_key_12345"
# Test various error scenarios don't expose the key
with pytest.raises(ValueError) as exc_info:
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_use_custom_base_url_in_validation(self, mock_get):
"""Test that custom base URL is used in validation"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_get.return_value = mock_response
credentials = {
"auth_type": "x-api-key",
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"},
}
auth = WatercrawlAuth(credentials)
result = auth.validate_credentials()
assert result is True
assert mock_get.call_args[0][0] == "https://custom.watercrawl.dev/api/v1/core/crawl-requests/"
@pytest.mark.parametrize(
("base_url", "expected_url"),
[
("https://app.watercrawl.dev", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
("https://app.watercrawl.dev/", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
"""Test that urljoin is used correctly for URL construction with various base URLs"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_get.return_value = mock_response
credentials = {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123", "base_url": base_url}}
auth = WatercrawlAuth(credentials)
auth.validate_credentials()
# Verify the correct URL was called
assert mock_get.call_args[0][0] == expected_url
@patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message
assert "timed out" in str(exc_info.value)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,59 @@
from unittest.mock import MagicMock
class ServiceDbTestHelper:
"""
Helper class for service database query tests.
"""
@staticmethod
def setup_db_query_filter_by_mock(mock_db, query_results):
"""
Smart database query mock that responds based on model type and query parameters.
Args:
mock_db: Mock database session
query_results: Dict mapping (model_name, filter_key, filter_value) to return value
Example: {('Account', 'email', 'test@example.com'): mock_account}
"""
def query_side_effect(model):
mock_query = MagicMock()
def filter_by_side_effect(**kwargs):
mock_filter_result = MagicMock()
def first_side_effect():
# Find matching result based on model and filter parameters
for (model_name, filter_key, filter_value), result in query_results.items():
if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value:
return result
return None
mock_filter_result.first.side_effect = first_side_effect
# Handle order_by calls for complex queries
def order_by_side_effect(*args, **kwargs):
mock_order_result = MagicMock()
def order_first_side_effect():
# Look for order_by results in the same query_results dict
for (model_name, filter_key, filter_value), result in query_results.items():
if (
model.__name__ == model_name
and filter_key == "order_by"
and filter_value == "first_available"
):
return result
return None
mock_order_result.first.side_effect = order_first_side_effect
return mock_order_result
mock_filter_result.order_by.side_effect = order_by_side_effect
return mock_filter_result
mock_query.filter_by.side_effect = filter_by_side_effect
return mock_query
mock_db.session.query.side_effect = query_side_effect
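
A hypothetical usage example for the helper; the Account class here is a stand-in, since the mock only consults model.__name__:

# Sketch only: demonstrates how the keyed mock routes lookups.
from unittest.mock import MagicMock


class Account:
    """Stand-in model: the helper only inspects model.__name__."""


mock_db = MagicMock()
mock_account = MagicMock()
ServiceDbTestHelper.setup_db_query_filter_by_mock(
    mock_db,
    {("Account", "email", "test@example.com"): mock_account},
)
# A matching (model, filter key, filter value) triple returns the canned
# object; any other query returns None.
assert mock_db.session.query(Account).filter_by(email="test@example.com").first() is mock_account
assert mock_db.session.query(Account).filter_by(email="other@example.com").first() is None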

File diff suppressed because it is too large.

View File

@@ -0,0 +1,106 @@
from unittest.mock import patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from models.model import AppMode
from services.app_task_service import AppTaskService
class TestAppTaskService:
"""Test suite for AppTaskService.stop_task method."""
@pytest.mark.parametrize(
("app_mode", "should_call_graph_engine"),
[
(AppMode.CHAT, False),
(AppMode.COMPLETION, False),
(AppMode.AGENT_CHAT, False),
(AppMode.CHANNEL, False),
(AppMode.RAG_PIPELINE, False),
(AppMode.ADVANCED_CHAT, True),
(AppMode.WORKFLOW, True),
],
)
@patch("services.app_task_service.AppQueueManager")
@patch("services.app_task_service.GraphEngineManager")
def test_stop_task_with_different_app_modes(
self, mock_graph_engine_manager, mock_app_queue_manager, app_mode, should_call_graph_engine
):
"""Test stop_task behavior with different app modes.
Verifies that:
- Legacy Redis flag is always set via AppQueueManager
- GraphEngine stop command is only sent for ADVANCED_CHAT and WORKFLOW modes
"""
# Arrange
task_id = "task-123"
invoke_from = InvokeFrom.WEB_APP
user_id = "user-456"
# Act
AppTaskService.stop_task(task_id, invoke_from, user_id, app_mode)
# Assert
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
if should_call_graph_engine:
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
else:
mock_graph_engine_manager.send_stop_command.assert_not_called()
@pytest.mark.parametrize(
"invoke_from",
[
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
InvokeFrom.DEBUGGER,
InvokeFrom.EXPLORE,
],
)
@patch("services.app_task_service.AppQueueManager")
@patch("services.app_task_service.GraphEngineManager")
def test_stop_task_with_different_invoke_sources(
self, mock_graph_engine_manager, mock_app_queue_manager, invoke_from
):
"""Test stop_task behavior with different invoke sources.
Verifies that the method works correctly regardless of the invoke source.
"""
# Arrange
task_id = "task-789"
user_id = "user-999"
app_mode = AppMode.ADVANCED_CHAT
# Act
AppTaskService.stop_task(task_id, invoke_from, user_id, app_mode)
# Assert
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
@patch("services.app_task_service.GraphEngineManager")
@patch("services.app_task_service.AppQueueManager")
def test_stop_task_legacy_mechanism_called_even_if_graph_engine_fails(
self, mock_app_queue_manager, mock_graph_engine_manager
):
"""Test that legacy Redis flag is set even if GraphEngine fails.
This ensures backward compatibility: the legacy mechanism should complete
before attempting the GraphEngine command, so the stop flag is set
regardless of GraphEngine success.
"""
# Arrange
task_id = "task-123"
invoke_from = InvokeFrom.WEB_APP
user_id = "user-456"
app_mode = AppMode.ADVANCED_CHAT
# Simulate GraphEngine failure
mock_graph_engine_manager.send_stop_command.side_effect = Exception("GraphEngine error")
# Act & Assert - should raise the exception since it's not caught
with pytest.raises(Exception, match="GraphEngine error"):
AppTaskService.stop_task(task_id, invoke_from, user_id, app_mode)
# Verify legacy mechanism was still called before the exception
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
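# A minimal sketch of the contract these tests pin down, with the managers injected as
# parameters; an assumption-level distillation, not the actual service code.
def _stop_task_contract_sketch(task_id, invoke_from, user_id, app_mode, queue_manager, graph_engine_manager):
    queue_manager.set_stop_flag(task_id, invoke_from, user_id)  # legacy Redis flag, always set first
    if app_mode in (AppMode.ADVANCED_CHAT, AppMode.WORKFLOW):
        graph_engine_manager.send_stop_command(task_id)  # graph-based modes only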

View File

@@ -0,0 +1,236 @@
import json
from unittest.mock import MagicMock, patch
import httpx
import pytest
from werkzeug.exceptions import InternalServerError
from services.billing_service import BillingService
class TestBillingServiceSendRequest:
"""Unit tests for BillingService._send_request method."""
@pytest.fixture
def mock_httpx_request(self):
"""Mock httpx.request for testing."""
with patch("services.billing_service.httpx.request") as mock_request:
yield mock_request
@pytest.fixture
def mock_billing_config(self):
"""Mock BillingService configuration."""
with (
patch.object(BillingService, "base_url", "https://billing-api.example.com"),
patch.object(BillingService, "secret_key", "test-secret-key"),
):
yield
def test_get_request_success(self, mock_httpx_request, mock_billing_config):
"""Test successful GET request."""
# Arrange
expected_response = {"result": "success", "data": {"info": "test"}}
mock_response = MagicMock()
mock_response.status_code = httpx.codes.OK
mock_response.json.return_value = expected_response
mock_httpx_request.return_value = mock_response
# Act
result = BillingService._send_request("GET", "/test", params={"key": "value"})
# Assert
assert result == expected_response
mock_httpx_request.assert_called_once()
call_args = mock_httpx_request.call_args
assert call_args[0][0] == "GET"
assert call_args[0][1] == "https://billing-api.example.com/test"
assert call_args[1]["params"] == {"key": "value"}
assert call_args[1]["headers"]["Billing-Api-Secret-Key"] == "test-secret-key"
assert call_args[1]["headers"]["Content-Type"] == "application/json"
@pytest.mark.parametrize(
"status_code", [httpx.codes.NOT_FOUND, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.BAD_REQUEST]
)
def test_get_request_non_200_status_code(self, mock_httpx_request, mock_billing_config, status_code):
"""Test GET request with non-200 status code raises ValueError."""
# Arrange
mock_response = MagicMock()
mock_response.status_code = status_code
mock_httpx_request.return_value = mock_response
# Act & Assert
with pytest.raises(ValueError) as exc_info:
BillingService._send_request("GET", "/test")
assert "Unable to retrieve billing information" in str(exc_info.value)
def test_put_request_success(self, mock_httpx_request, mock_billing_config):
"""Test successful PUT request."""
# Arrange
expected_response = {"result": "success"}
mock_response = MagicMock()
mock_response.status_code = httpx.codes.OK
mock_response.json.return_value = expected_response
mock_httpx_request.return_value = mock_response
# Act
result = BillingService._send_request("PUT", "/test", json={"key": "value"})
# Assert
assert result == expected_response
call_args = mock_httpx_request.call_args
assert call_args[0][0] == "PUT"
def test_put_request_internal_server_error(self, mock_httpx_request, mock_billing_config):
"""Test PUT request with INTERNAL_SERVER_ERROR raises InternalServerError."""
# Arrange
mock_response = MagicMock()
mock_response.status_code = httpx.codes.INTERNAL_SERVER_ERROR
mock_httpx_request.return_value = mock_response
# Act & Assert
with pytest.raises(InternalServerError) as exc_info:
BillingService._send_request("PUT", "/test", json={"key": "value"})
assert exc_info.value.code == 500
assert "Unable to process billing request" in str(exc_info.value.description)
@pytest.mark.parametrize(
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.NOT_FOUND, httpx.codes.UNAUTHORIZED, httpx.codes.FORBIDDEN]
)
def test_put_request_non_200_non_500(self, mock_httpx_request, mock_billing_config, status_code):
"""Test PUT request with non-200 and non-500 status code raises ValueError."""
# Arrange
mock_response = MagicMock()
mock_response.status_code = status_code
mock_httpx_request.return_value = mock_response
# Act & Assert
with pytest.raises(ValueError) as exc_info:
BillingService._send_request("PUT", "/test", json={"key": "value"})
assert "Invalid arguments." in str(exc_info.value)
@pytest.mark.parametrize("method", ["POST", "DELETE"])
def test_non_get_non_put_request_success(self, mock_httpx_request, mock_billing_config, method):
"""Test successful POST/DELETE request."""
# Arrange
expected_response = {"result": "success"}
mock_response = MagicMock()
mock_response.status_code = httpx.codes.OK
mock_response.json.return_value = expected_response
mock_httpx_request.return_value = mock_response
# Act
result = BillingService._send_request(method, "/test", json={"key": "value"})
# Assert
assert result == expected_response
call_args = mock_httpx_request.call_args
assert call_args[0][0] == method
@pytest.mark.parametrize(
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
)
def test_post_request_non_200_with_valid_json(self, mock_httpx_request, mock_billing_config, status_code):
"""Test POST request with non-200 status code raises ValueError."""
# Arrange
error_response = {"detail": "Error message"}
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.json.return_value = error_response
mock_httpx_request.return_value = mock_response
# Act & Assert
with pytest.raises(ValueError) as exc_info:
BillingService._send_request("POST", "/test", json={"key": "value"})
assert "Unable to send request to" in str(exc_info.value)
@pytest.mark.parametrize(
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
)
def test_delete_request_non_200_with_valid_json(self, mock_httpx_request, mock_billing_config, status_code):
"""Test DELETE request with non-200 status code but valid JSON response.
DELETE doesn't check status code, so it returns the error JSON.
"""
# Arrange
error_response = {"detail": "Error message"}
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.json.return_value = error_response
mock_httpx_request.return_value = mock_response
# Act
result = BillingService._send_request("DELETE", "/test", json={"key": "value"})
# Assert
assert result == error_response
@pytest.mark.parametrize(
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
)
def test_post_request_non_200_with_invalid_json(self, mock_httpx_request, mock_billing_config, status_code):
"""Test POST request with non-200 status code raises ValueError before JSON parsing."""
# Arrange
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.text = ""
mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_httpx_request.return_value = mock_response
# Act & Assert
# POST checks status code before calling response.json(), so ValueError is raised
with pytest.raises(ValueError) as exc_info:
BillingService._send_request("POST", "/test", json={"key": "value"})
assert "Unable to send request to" in str(exc_info.value)
@pytest.mark.parametrize(
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
)
def test_delete_request_non_200_with_invalid_json(self, mock_httpx_request, mock_billing_config, status_code):
"""Test DELETE request with non-200 status code and invalid JSON response raises exception.
DELETE doesn't check status code, so it calls response.json() which raises JSONDecodeError
when the response cannot be parsed as JSON (e.g., empty response).
"""
# Arrange
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.text = ""
mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_httpx_request.return_value = mock_response
# Act & Assert
with pytest.raises(json.JSONDecodeError):
BillingService._send_request("DELETE", "/test", json={"key": "value"})
def test_retry_on_request_error(self, mock_httpx_request, mock_billing_config):
"""Test that _send_request retries on httpx.RequestError."""
# Arrange
expected_response = {"result": "success"}
mock_response = MagicMock()
mock_response.status_code = httpx.codes.OK
mock_response.json.return_value = expected_response
# First call raises RequestError, second succeeds
mock_httpx_request.side_effect = [
httpx.RequestError("Network error"),
mock_response,
]
# Act
result = BillingService._send_request("GET", "/test")
# Assert
assert result == expected_response
assert mock_httpx_request.call_count == 2
def test_retry_exhausted_raises_exception(self, mock_httpx_request, mock_billing_config):
"""Test that _send_request raises exception after retries are exhausted."""
# Arrange
mock_httpx_request.side_effect = httpx.RequestError("Network error")
# Act & Assert
with pytest.raises(httpx.RequestError):
BillingService._send_request("GET", "/test")
# Should retry multiple times (wait=2, stop_before_delay=10 means ~5 attempts)
assert mock_httpx_request.call_count > 1
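# A distilled sketch of the status-code contract the tests above pin down, assuming
# base_url/secret_key are passed in; the real method also retries on httpx.RequestError
# (via tenacity) and reads its configuration from class attributes, omitted here.
def _send_request_contract_sketch(method, endpoint, base_url, secret_key, **kwargs):
    headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": secret_key}
    response = httpx.request(method, base_url + endpoint, headers=headers, **kwargs)
    if method == "GET" and response.status_code != httpx.codes.OK:
        raise ValueError("Unable to retrieve billing information. Please try again later.")
    if method == "PUT" and response.status_code != httpx.codes.OK:
        if response.status_code == httpx.codes.INTERNAL_SERVER_ERROR:
            raise InternalServerError("Unable to process billing request. Please try again later.")
        raise ValueError("Invalid arguments.")
    if method == "POST" and response.status_code != httpx.codes.OK:
        raise ValueError(f"Unable to send request to {endpoint}")
    return response.json()  # DELETE performs no status check and may raise JSONDecodeError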

View File

@@ -0,0 +1,168 @@
import datetime
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.orm import Session
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
class TestClearFreePlanTenantExpiredLogs:
"""Unit tests for ClearFreePlanTenantExpiredLogs._clear_message_related_tables method."""
@pytest.fixture
def mock_session(self):
"""Create a mock database session."""
session = Mock(spec=Session)
        session.query.return_value.where.return_value.all.return_value = []
        session.query.return_value.where.return_value.delete.return_value = 0
return session
@pytest.fixture
def mock_storage(self):
"""Create a mock storage object."""
storage = Mock()
storage.save.return_value = None
return storage
@pytest.fixture
def sample_message_ids(self):
"""Sample message IDs for testing."""
return ["msg-1", "msg-2", "msg-3"]
@pytest.fixture
def sample_records(self):
"""Sample records for testing."""
records = []
for i in range(3):
record = Mock()
record.id = f"record-{i}"
record.to_dict.return_value = {
"id": f"record-{i}",
"message_id": f"msg-{i}",
"created_at": datetime.datetime.now().isoformat(),
}
records.append(record)
return records
def test_clear_message_related_tables_empty_message_ids(self, mock_session):
"""Test that method returns early when message_ids is empty."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", [])
# Should not call any database operations
mock_session.query.assert_not_called()
mock_storage.save.assert_not_called()
def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids):
"""Test when no related records are found."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = []
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should call query for each related table but find no records
assert mock_session.query.call_count > 0
mock_storage.save.assert_not_called()
def test_clear_message_related_tables_with_records_and_to_dict(
self, mock_session, sample_message_ids, sample_records
):
"""Test when records are found and have to_dict method."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = sample_records
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should call to_dict on each record (called once per table, so 7 times total)
for record in sample_records:
assert record.to_dict.call_count == 7
# Should save backup data
assert mock_storage.save.call_count > 0
def test_clear_message_related_tables_with_records_no_to_dict(self, mock_session, sample_message_ids):
"""Test when records are found but don't have to_dict method."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
# Create records without to_dict method
records = []
for i in range(2):
record = Mock()
mock_table = Mock()
mock_id_column = Mock()
mock_id_column.name = "id"
mock_message_id_column = Mock()
mock_message_id_column.name = "message_id"
mock_table.columns = [mock_id_column, mock_message_id_column]
record.__table__ = mock_table
record.id = f"record-{i}"
record.message_id = f"msg-{i}"
del record.to_dict
records.append(record)
# Mock records for first table only, empty for others
mock_session.query.return_value.where.return_value.all.side_effect = [
records,
[],
[],
[],
[],
[],
[],
]
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should save backup data even without to_dict
assert mock_storage.save.call_count > 0
def test_clear_message_related_tables_storage_error_continues(
self, mock_session, sample_message_ids, sample_records
):
"""Test that method continues even when storage.save fails."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_storage.save.side_effect = Exception("Storage error")
mock_session.query.return_value.where.return_value.all.return_value = sample_records
# Should not raise exception
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should still delete records even if backup fails
assert mock_session.query.return_value.where.return_value.delete.called
def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids):
"""Test that method continues even when record serialization fails."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
record = Mock()
record.id = "record-1"
record.to_dict.side_effect = Exception("Serialization error")
mock_session.query.return_value.where.return_value.all.return_value = [record]
# Should not raise exception
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should still delete records even if serialization fails
assert mock_session.query.return_value.where.return_value.delete.called
def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records):
"""Test that deletion is called for found records."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.where.return_value.all.return_value = sample_records
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should call delete for each table that has records
assert mock_session.query.return_value.where.return_value.delete.called
    def test_clear_message_related_tables_logging_output(
        self, mock_session, sample_message_ids, sample_records, capsys
    ):
        """Smoke test: the method runs end to end and any console output can be captured."""
        with patch("services.clear_free_plan_tenant_expired_logs.storage"):
            mock_session.query.return_value.where.return_value.all.return_value = sample_records
            ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
            # The output format is an implementation detail; verify only that capture works.
            captured = capsys.readouterr()
            assert captured.out is not None
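# A minimal sketch of the backup-then-delete pass the tests above exercise, one call per
# related table (seven in total, per the to_dict call count); the storage key format and
# serialization fallback are assumptions, not the service's actual behavior.
def _backup_then_delete_sketch(session, storage_client, tenant_id, model, message_ids):
    records = session.query(model).where(model.message_id.in_(message_ids)).all()
    if not records:
        return
    serialized = []
    for record in records:
        try:
            serialized.append(record.to_dict() if hasattr(record, "to_dict") else {"id": record.id})
        except Exception:
            continue  # serialization failures must not block deletion
    try:
        storage_client.save(f"expired_logs/{tenant_id}/{model.__name__}.json", str(serialized).encode())
    except Exception:
        pass  # backup failures must not block deletion either
    session.query(model).where(model.message_id.in_(message_ids)).delete()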

View File

@@ -0,0 +1,127 @@
import uuid
from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from services.conversation_service import ConversationService
class TestConversationService:
def test_pagination_with_empty_include_ids(self):
"""Test that empty include_ids returns empty result"""
mock_session = MagicMock()
mock_app_model = MagicMock(id=str(uuid.uuid4()))
mock_user = MagicMock(id=str(uuid.uuid4()))
result = ConversationService.pagination_by_last_id(
session=mock_session,
app_model=mock_app_model,
user=mock_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
include_ids=[], # Empty include_ids should return empty result
exclude_ids=None,
)
assert result.data == []
assert result.has_more is False
assert result.limit == 20
def test_pagination_with_non_empty_include_ids(self):
"""Test that non-empty include_ids filters properly"""
mock_session = MagicMock()
mock_app_model = MagicMock(id=str(uuid.uuid4()))
mock_user = MagicMock(id=str(uuid.uuid4()))
# Mock the query results
mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
mock_session.scalars.return_value.all.return_value = mock_conversations
mock_session.scalar.return_value = 0
with patch("services.conversation_service.select") as mock_select:
mock_stmt = MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
mock_stmt.order_by.return_value = mock_stmt
mock_stmt.limit.return_value = mock_stmt
mock_stmt.subquery.return_value = MagicMock()
result = ConversationService.pagination_by_last_id(
session=mock_session,
app_model=mock_app_model,
user=mock_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
include_ids=["conv1", "conv2"], # Non-empty include_ids
exclude_ids=None,
)
# Verify the where clause was called with id.in_
assert mock_stmt.where.called
def test_pagination_with_empty_exclude_ids(self):
"""Test that empty exclude_ids doesn't filter"""
mock_session = MagicMock()
mock_app_model = MagicMock(id=str(uuid.uuid4()))
mock_user = MagicMock(id=str(uuid.uuid4()))
# Mock the query results
mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(5)]
mock_session.scalars.return_value.all.return_value = mock_conversations
mock_session.scalar.return_value = 0
with patch("services.conversation_service.select") as mock_select:
mock_stmt = MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
mock_stmt.order_by.return_value = mock_stmt
mock_stmt.limit.return_value = mock_stmt
mock_stmt.subquery.return_value = MagicMock()
result = ConversationService.pagination_by_last_id(
session=mock_session,
app_model=mock_app_model,
user=mock_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
include_ids=None,
exclude_ids=[], # Empty exclude_ids should not filter
)
# Result should contain the mocked conversations
assert len(result.data) == 5
def test_pagination_with_non_empty_exclude_ids(self):
"""Test that non-empty exclude_ids filters properly"""
mock_session = MagicMock()
mock_app_model = MagicMock(id=str(uuid.uuid4()))
mock_user = MagicMock(id=str(uuid.uuid4()))
# Mock the query results
mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
mock_session.scalars.return_value.all.return_value = mock_conversations
mock_session.scalar.return_value = 0
with patch("services.conversation_service.select") as mock_select:
mock_stmt = MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
mock_stmt.order_by.return_value = mock_stmt
mock_stmt.limit.return_value = mock_stmt
mock_stmt.subquery.return_value = MagicMock()
result = ConversationService.pagination_by_last_id(
session=mock_session,
app_model=mock_app_model,
user=mock_user,
last_id=None,
limit=20,
invoke_from=InvokeFrom.WEB_APP,
include_ids=None,
exclude_ids=["conv1", "conv2"], # Non-empty exclude_ids
)
# Verify the where clause was called for exclusion
assert mock_stmt.where.called
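# A distilled sketch of the include/exclude filtering rules the tests above pin down;
# the id column parameter and the None-means-empty-page signal are assumptions, not
# the service's actual API.
def _id_filter_semantics_sketch(stmt, id_column, include_ids, exclude_ids):
    if include_ids is not None:
        if not include_ids:
            return None  # caller short-circuits to an empty InfiniteScrollPagination
        stmt = stmt.where(id_column.in_(include_ids))
    if exclude_ids:  # None or [] adds no exclusion filter
        stmt = stmt.where(~id_column.in_(exclude_ids))
    return stmt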

View File

@@ -0,0 +1,305 @@
from unittest.mock import Mock, patch
import pytest
from models.account import Account, TenantAccountRole
from models.dataset import Dataset, DatasetPermission, DatasetPermissionEnum
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
class DatasetPermissionTestDataFactory:
"""Factory class for creating test data and mock objects for dataset permission tests."""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
tenant_id: str = "test-tenant-123",
created_by: str = "creator-456",
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
**kwargs,
) -> Mock:
"""Create a mock dataset with specified attributes."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.created_by = created_by
dataset.permission = permission
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_user_mock(
user_id: str = "user-789",
tenant_id: str = "test-tenant-123",
role: TenantAccountRole = TenantAccountRole.NORMAL,
**kwargs,
) -> Mock:
"""Create a mock user with specified attributes."""
user = Mock(spec=Account)
user.id = user_id
user.current_tenant_id = tenant_id
user.current_role = role
for key, value in kwargs.items():
setattr(user, key, value)
return user
@staticmethod
def create_dataset_permission_mock(
dataset_id: str = "dataset-123",
account_id: str = "user-789",
**kwargs,
) -> Mock:
"""Create a mock dataset permission record."""
permission = Mock(spec=DatasetPermission)
permission.dataset_id = dataset_id
permission.account_id = account_id
for key, value in kwargs.items():
setattr(permission, key, value)
return permission
class TestDatasetPermissionService:
"""
Comprehensive unit tests for DatasetService.check_dataset_permission method.
This test suite covers all permission scenarios including:
- Cross-tenant access restrictions
- Owner privilege checks
- Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
- Explicit permission checks for PARTIAL_TEAM
- Error conditions and logging
"""
@pytest.fixture
def mock_dataset_service_dependencies(self):
"""Common mock setup for dataset service dependencies."""
with patch("services.dataset_service.db.session") as mock_session:
yield {
"db_session": mock_session,
}
@pytest.fixture
def mock_logging_dependencies(self):
"""Mock setup for logging tests."""
with patch("services.dataset_service.logger") as mock_logging:
yield {
"logging": mock_logging,
}
def _assert_permission_check_passes(self, dataset: Mock, user: Mock):
"""Helper method to verify that permission check passes without raising exceptions."""
# Should not raise any exception
DatasetService.check_dataset_permission(dataset, user)
def _assert_permission_check_fails(
self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset."
):
"""Helper method to verify that permission check fails with expected error."""
with pytest.raises(NoPermissionError, match=expected_message):
DatasetService.check_dataset_permission(dataset, user)
def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str):
"""Helper method to verify database query calls for permission checks."""
mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id)
def _assert_database_query_not_called(self, mock_session: Mock):
"""Helper method to verify that database query was not called."""
mock_session.query.assert_not_called()
# ==================== Cross-Tenant Access Tests ====================
def test_permission_check_different_tenant_should_fail(self):
"""Test that users from different tenants cannot access dataset regardless of other permissions."""
# Create dataset and user from different tenants
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM
)
user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR
)
# Should fail due to different tenant
self._assert_permission_check_fails(dataset, user)
# ==================== Owner Privilege Tests ====================
def test_owner_can_access_any_dataset(self):
"""Test that tenant owners can access any dataset regardless of permission level."""
# Create dataset with restrictive permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
# Create owner user
owner_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="owner-999", role=TenantAccountRole.OWNER
)
# Owner should have access regardless of dataset permission
self._assert_permission_check_passes(dataset, owner_user)
# ==================== ONLY_ME Permission Tests ====================
def test_only_me_permission_creator_can_access(self):
"""Test ONLY_ME permission allows only the dataset creator to access."""
# Create dataset with ONLY_ME permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
)
# Create creator user
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="creator-456", role=TenantAccountRole.EDITOR
)
# Creator should be able to access
self._assert_permission_check_passes(dataset, creator_user)
def test_only_me_permission_others_cannot_access(self):
"""Test ONLY_ME permission denies access to non-creators."""
# Create dataset with ONLY_ME permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
)
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
# Non-creator should be denied access
self._assert_permission_check_fails(dataset, normal_user)
# ==================== ALL_TEAM Permission Tests ====================
def test_all_team_permission_allows_access(self):
"""Test ALL_TEAM permission allows any team member to access the dataset."""
# Create dataset with ALL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM)
# Create different types of team members
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
editor_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="editor-456", role=TenantAccountRole.EDITOR
)
# All team members should have access
self._assert_permission_check_passes(dataset, normal_user)
self._assert_permission_check_passes(dataset, editor_user)
# ==================== PARTIAL_TEAM Permission Tests ====================
def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies):
"""Test PARTIAL_TEAM permission allows creator to access without database query."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
)
# Create creator user
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="creator-456", role=TenantAccountRole.EDITOR
)
# Creator should have access without database query
self._assert_permission_check_passes(dataset, creator_user)
self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"])
def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies):
"""Test PARTIAL_TEAM permission allows users with explicit permission records."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
# Mock database query to return a permission record
mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock(
dataset_id=dataset.id, account_id=normal_user.id
)
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission
# User with explicit permission should have access
self._assert_permission_check_passes(dataset, normal_user)
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies):
"""Test PARTIAL_TEAM permission denies users without explicit permission records."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
# Mock database query to return None (no permission record)
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
# User without explicit permission should be denied access
self._assert_permission_check_fails(dataset, normal_user)
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies):
"""Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
)
# Create a different user (not the creator)
other_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="other-user-123", role=TenantAccountRole.NORMAL
)
# Mock database query to return None (no permission record)
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
# Non-creator without explicit permission should be denied access
self._assert_permission_check_fails(dataset, other_user)
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id)
# ==================== Enum Usage Tests ====================
def test_partial_team_permission_uses_correct_enum(self):
"""Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals."""
# Create dataset with PARTIAL_TEAM permission using enum
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
)
# Create creator user
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="creator-456", role=TenantAccountRole.EDITOR
)
# Creator should always have access regardless of permission level
self._assert_permission_check_passes(dataset, creator_user)
# ==================== Logging Tests ====================
def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies):
"""Test that permission denied events are properly logged for debugging purposes."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
# Mock database query to return None (no permission record)
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
# Attempt permission check (should fail)
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, normal_user)
# Verify debug message was logged with correct user and dataset information
mock_logging_dependencies["logging"].debug.assert_called_with(
"User %s does not have permission to access dataset %s", normal_user.id, dataset.id
)
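# A minimal sketch of the decision order these tests pin down, with the explicit
# DatasetPermission lookup injected as a callable; an assumption-level distillation,
# not the actual DatasetService code.
def _check_dataset_permission_sketch(dataset, user, has_explicit_permission):
    if dataset.tenant_id != user.current_tenant_id:
        raise NoPermissionError("You do not have permission to access this dataset.")
    if user.current_role == TenantAccountRole.OWNER:
        return  # owners bypass all dataset-level permission checks
    if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
        raise NoPermissionError("You do not have permission to access this dataset.")
    if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
        if dataset.created_by == user.id or has_explicit_permission(dataset.id, user.id):
            return  # creators skip the lookup; others need an explicit record
        raise NoPermissionError("You do not have permission to access this dataset.")
    # ALL_TEAM (and creators under ONLY_ME) fall through: access granted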

View File

@@ -0,0 +1,800 @@
import datetime
# Mock redis_client before importing dataset_service
from unittest.mock import Mock, call, patch
import pytest
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
from tests.unit_tests.conftest import redis_mock
class DocumentBatchUpdateTestDataFactory:
"""Factory class for creating test data and mock objects for document batch update tests."""
@staticmethod
def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-456") -> Mock:
"""Create a mock dataset with specified attributes."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
return dataset
@staticmethod
def create_user_mock(user_id: str = "user-789") -> Mock:
"""Create a mock user."""
user = Mock()
user.id = user_id
return user
@staticmethod
def create_document_mock(
document_id: str = "doc-1",
name: str = "test_document.pdf",
enabled: bool = True,
archived: bool = False,
indexing_status: str = "completed",
completed_at: datetime.datetime | None = None,
**kwargs,
) -> Mock:
"""Create a mock document with specified attributes."""
document = Mock(spec=Document)
document.id = document_id
document.name = name
document.enabled = enabled
document.archived = archived
document.indexing_status = indexing_status
document.completed_at = completed_at or datetime.datetime.now()
# Set default values for optional fields
document.disabled_at = None
document.disabled_by = None
document.archived_at = None
document.archived_by = None
document.updated_at = None
for key, value in kwargs.items():
setattr(document, key, value)
return document
@staticmethod
def create_multiple_documents(
document_ids: list[str], enabled: bool = True, archived: bool = False, indexing_status: str = "completed"
) -> list[Mock]:
"""Create multiple mock documents with specified attributes."""
documents = []
for doc_id in document_ids:
doc = DocumentBatchUpdateTestDataFactory.create_document_mock(
document_id=doc_id,
name=f"document_{doc_id}.pdf",
enabled=enabled,
archived=archived,
indexing_status=indexing_status,
)
documents.append(doc)
return documents
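# Illustrative only: what the factory above yields for a typical batch scenario.
def _factory_demo():
    docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=False)
    assert [d.name for d in docs] == ["document_doc-1.pdf", "document_doc-2.pdf"]
    assert all(not d.enabled and not d.archived for d in docs)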
class TestDatasetServiceBatchUpdateDocumentStatus:
"""
Comprehensive unit tests for DocumentService.batch_update_document_status method.
This test suite covers all supported actions (enable, disable, archive, un_archive),
error conditions, edge cases, and validates proper interaction with Redis cache,
database operations, and async task triggers.
"""
@pytest.fixture
def mock_document_service_dependencies(self):
"""Common mock setup for document service dependencies."""
with (
patch("services.dataset_service.DocumentService.get_document") as mock_get_doc,
patch("extensions.ext_database.db.session") as mock_db,
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
):
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_naive_utc_now.return_value = current_time
yield {
"get_document": mock_get_doc,
"db_session": mock_db,
"naive_utc_now": mock_naive_utc_now,
"current_time": current_time,
}
@pytest.fixture
def mock_async_task_dependencies(self):
"""Mock setup for async task dependencies."""
with (
patch("services.dataset_service.add_document_to_index_task") as mock_add_task,
patch("services.dataset_service.remove_document_from_index_task") as mock_remove_task,
):
yield {"add_task": mock_add_task, "remove_task": mock_remove_task}
def _assert_document_enabled(self, document: Mock, user_id: str, current_time: datetime.datetime):
"""Helper method to verify document was enabled correctly."""
        assert document.enabled is True
assert document.disabled_at is None
assert document.disabled_by is None
assert document.updated_at == current_time
def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime):
"""Helper method to verify document was disabled correctly."""
        assert document.enabled is False
assert document.disabled_at == current_time
assert document.disabled_by == user_id
assert document.updated_at == current_time
def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime):
"""Helper method to verify document was archived correctly."""
        assert document.archived is True
assert document.archived_at == current_time
assert document.archived_by == user_id
assert document.updated_at == current_time
def _assert_document_unarchived(self, document: Mock):
"""Helper method to verify document was unarchived correctly."""
        assert document.archived is False
assert document.archived_at is None
assert document.archived_by is None
def _assert_redis_cache_operations(self, document_ids: list[str], action: str = "setex"):
"""Helper method to verify Redis cache operations."""
if action == "setex":
expected_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids]
redis_mock.setex.assert_has_calls(expected_calls)
elif action == "get":
expected_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids]
redis_mock.get.assert_has_calls(expected_calls)
def _assert_async_task_calls(self, mock_task, document_ids: list[str], task_type: str):
"""Helper method to verify async task calls."""
expected_calls = [call(doc_id) for doc_id in document_ids]
if task_type in {"add", "remove"}:
mock_task.delay.assert_has_calls(expected_calls)
# ==================== Enable Document Tests ====================
def test_batch_update_enable_documents_success(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test successful enabling of disabled documents."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create disabled documents
disabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=False)
mock_document_service_dependencies["get_document"].side_effect = disabled_docs
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Call the method to enable documents
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1", "doc-2"], action="enable", user=user
)
# Verify document attributes were updated correctly
for doc in disabled_docs:
self._assert_document_enabled(doc, user.id, mock_document_service_dependencies["current_time"])
# Verify Redis cache operations
self._assert_redis_cache_operations(["doc-1", "doc-2"], "get")
self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex")
# Verify async tasks were triggered for indexing
self._assert_async_task_calls(mock_async_task_dependencies["add_task"], ["doc-1", "doc-2"], "add")
# Verify database operations
mock_db = mock_document_service_dependencies["db_session"]
assert mock_db.add.call_count == 2
assert mock_db.commit.call_count == 1
def test_batch_update_enable_already_enabled_document_skipped(self, mock_document_service_dependencies):
"""Test enabling documents that are already enabled."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create already enabled document
enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True)
mock_document_service_dependencies["get_document"].return_value = enabled_doc
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Attempt to enable already enabled document
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action="enable", user=user
)
# Verify no database operations occurred (document was skipped)
mock_db = mock_document_service_dependencies["db_session"]
mock_db.commit.assert_not_called()
# Verify no Redis setex operations occurred (document was skipped)
redis_mock.setex.assert_not_called()
# ==================== Disable Document Tests ====================
def test_batch_update_disable_documents_success(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test successful disabling of enabled and completed documents."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create enabled documents
enabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=True)
mock_document_service_dependencies["get_document"].side_effect = enabled_docs
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Call the method to disable documents
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1", "doc-2"], action="disable", user=user
)
# Verify document attributes were updated correctly
for doc in enabled_docs:
self._assert_document_disabled(doc, user.id, mock_document_service_dependencies["current_time"])
# Verify Redis cache operations for indexing prevention
self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex")
# Verify async tasks were triggered to remove from index
self._assert_async_task_calls(mock_async_task_dependencies["remove_task"], ["doc-1", "doc-2"], "remove")
# Verify database operations
mock_db = mock_document_service_dependencies["db_session"]
assert mock_db.add.call_count == 2
assert mock_db.commit.call_count == 1
def test_batch_update_disable_already_disabled_document_skipped(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test disabling documents that are already disabled."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create already disabled document
disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False)
mock_document_service_dependencies["get_document"].return_value = disabled_doc
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Attempt to disable already disabled document
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action="disable", user=user
)
# Verify no database operations occurred (document was skipped)
mock_db = mock_document_service_dependencies["db_session"]
mock_db.commit.assert_not_called()
# Verify no Redis setex operations occurred (document was skipped)
redis_mock.setex.assert_not_called()
# Verify no async tasks were triggered (document was skipped)
mock_async_task_dependencies["add_task"].delay.assert_not_called()
def test_batch_update_disable_non_completed_document_error(self, mock_document_service_dependencies):
"""Test that DocumentIndexingError is raised when trying to disable non-completed documents."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create a document that's not completed
non_completed_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(
enabled=True,
indexing_status="indexing", # Not completed
completed_at=None, # Not completed
)
mock_document_service_dependencies["get_document"].return_value = non_completed_doc
# Verify that DocumentIndexingError is raised
with pytest.raises(DocumentIndexingError) as exc_info:
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action="disable", user=user
)
# Verify error message indicates document is not completed
assert "is not completed" in str(exc_info.value)
# ==================== Archive Document Tests ====================
def test_batch_update_archive_documents_success(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test successful archiving of unarchived documents."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create unarchived enabled document
unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False)
mock_document_service_dependencies["get_document"].return_value = unarchived_doc
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Call the method to archive documents
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action="archive", user=user
)
# Verify document attributes were updated correctly
self._assert_document_archived(unarchived_doc, user.id, mock_document_service_dependencies["current_time"])
# Verify Redis cache was set (because document was enabled)
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
# Verify async task was triggered to remove from index (because enabled)
mock_async_task_dependencies["remove_task"].delay.assert_called_once_with("doc-1")
# Verify database operations
mock_db = mock_document_service_dependencies["db_session"]
mock_db.add.assert_called_once()
mock_db.commit.assert_called_once()
def test_batch_update_archive_already_archived_document_skipped(self, mock_document_service_dependencies):
"""Test archiving documents that are already archived."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create already archived document
archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True)
mock_document_service_dependencies["get_document"].return_value = archived_doc
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Attempt to archive already archived document
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-3"], action="archive", user=user
)
# Verify no database operations occurred (document was skipped)
mock_db = mock_document_service_dependencies["db_session"]
mock_db.commit.assert_not_called()
# Verify no Redis setex operations occurred (document was skipped)
redis_mock.setex.assert_not_called()
def test_batch_update_archive_disabled_document_no_index_removal(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test archiving disabled documents (should not trigger index removal)."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Set up disabled, unarchived document
disabled_unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=False)
mock_document_service_dependencies["get_document"].return_value = disabled_unarchived_doc
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Archive the disabled document
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action="archive", user=user
)
# Verify document was archived
self._assert_document_archived(
disabled_unarchived_doc, user.id, mock_document_service_dependencies["current_time"]
)
# Verify no Redis cache was set (document is disabled)
redis_mock.setex.assert_not_called()
# Verify no index removal task was triggered (document is disabled)
mock_async_task_dependencies["remove_task"].delay.assert_not_called()
# Verify database operations still occurred
mock_db = mock_document_service_dependencies["db_session"]
mock_db.add.assert_called_once()
mock_db.commit.assert_called_once()
# ==================== Unarchive Document Tests ====================
def test_batch_update_unarchive_documents_success(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test successful unarchiving of archived documents."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create mock archived document
archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True)
mock_document_service_dependencies["get_document"].return_value = archived_doc
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Call the method to unarchive documents
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user
)
# Verify document attributes were updated correctly
self._assert_document_unarchived(archived_doc)
assert archived_doc.updated_at == mock_document_service_dependencies["current_time"]
# Verify Redis cache was set (because document is enabled)
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
# Verify async task was triggered to add back to index (because enabled)
mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1")
# Verify database operations
mock_db = mock_document_service_dependencies["db_session"]
mock_db.add.assert_called_once()
mock_db.commit.assert_called_once()
def test_batch_update_unarchive_already_unarchived_document_skipped(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test unarchiving documents that are already unarchived."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create already unarchived document
unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False)
mock_document_service_dependencies["get_document"].return_value = unarchived_doc
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Attempt to unarchive already unarchived document
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user
)
# Verify no database operations occurred (document was skipped)
mock_db = mock_document_service_dependencies["db_session"]
mock_db.commit.assert_not_called()
# Verify no Redis setex operations occurred (document was skipped)
redis_mock.setex.assert_not_called()
# Verify no async tasks were triggered (document was skipped)
mock_async_task_dependencies["add_task"].delay.assert_not_called()
def test_batch_update_unarchive_disabled_document_no_index_addition(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test unarchiving disabled documents (should not trigger index addition)."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create mock archived but disabled document
archived_disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=True)
mock_document_service_dependencies["get_document"].return_value = archived_disabled_doc
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Unarchive the disabled document
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user
)
# Verify document was unarchived
self._assert_document_unarchived(archived_disabled_doc)
assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"]
# Verify no Redis cache was set (document is disabled)
redis_mock.setex.assert_not_called()
# Verify no index addition task was triggered (document is disabled)
mock_async_task_dependencies["add_task"].delay.assert_not_called()
# Verify database operations still occurred
mock_db = mock_document_service_dependencies["db_session"]
mock_db.add.assert_called_once()
mock_db.commit.assert_called_once()
# ==================== Error Handling Tests ====================
def test_batch_update_document_indexing_error_redis_cache_hit(self, mock_document_service_dependencies):
"""Test that DocumentIndexingError is raised when documents are currently being indexed."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create mock enabled document
enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True)
mock_document_service_dependencies["get_document"].return_value = enabled_doc
# Set up mock to indicate document is being indexed
redis_mock.reset_mock()
redis_mock.get.return_value = "indexing"
# Verify that DocumentIndexingError is raised
with pytest.raises(DocumentIndexingError) as exc_info:
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action="enable", user=user
)
# Verify error message contains document name
assert "test_document.pdf" in str(exc_info.value)
assert "is being indexed" in str(exc_info.value)
# Verify Redis cache was checked
redis_mock.get.assert_called_once_with("document_doc-1_indexing")
def test_batch_update_invalid_action_error(self, mock_document_service_dependencies):
"""Test that ValueError is raised when an invalid action is provided."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create mock document
doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True)
mock_document_service_dependencies["get_document"].return_value = doc
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Test with invalid action
invalid_action = "invalid_action"
with pytest.raises(ValueError) as exc_info:
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user
)
# Verify error message contains the invalid action
assert invalid_action in str(exc_info.value)
assert "Invalid action" in str(exc_info.value)
# Verify no Redis operations occurred
redis_mock.setex.assert_not_called()
def test_batch_update_async_task_error_handling(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test handling of async task errors during batch operations."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create mock disabled document
disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False)
mock_document_service_dependencies["get_document"].return_value = disabled_doc
# Mock async task to raise an exception
mock_async_task_dependencies["add_task"].delay.side_effect = Exception("Celery task error")
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Verify that async task error is propagated
with pytest.raises(Exception) as exc_info:
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1"], action="enable", user=user
)
# Verify error message
assert "Celery task error" in str(exc_info.value)
# Verify database operations completed successfully
mock_db = mock_document_service_dependencies["db_session"]
mock_db.add.assert_called_once()
mock_db.commit.assert_called_once()
# Verify Redis cache was set successfully
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
# Verify document was updated
self._assert_document_enabled(disabled_doc, user.id, mock_document_service_dependencies["current_time"])
# ==================== Edge Case Tests ====================
def test_batch_update_empty_document_list(self, mock_document_service_dependencies):
"""Test batch operations with an empty document ID list."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Call method with empty document list
result = DocumentService.batch_update_document_status(
dataset=dataset, document_ids=[], action="enable", user=user
)
# Verify no document lookups were performed
mock_document_service_dependencies["get_document"].assert_not_called()
# Verify method returns None (early return)
assert result is None
def test_batch_update_document_not_found_skipped(self, mock_document_service_dependencies):
"""Test behavior when some documents don't exist in the database."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Mock document service to return None (document not found)
mock_document_service_dependencies["get_document"].return_value = None
# Call method with non-existent document ID
        # This should not raise an error, just skip the missing document
        DocumentService.batch_update_document_status(
            dataset=dataset, document_ids=["non-existent-doc"], action="enable", user=user
        )
# Verify document lookup was attempted
mock_document_service_dependencies["get_document"].assert_called_once_with(dataset.id, "non-existent-doc")
def test_batch_update_mixed_document_states_and_actions(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test batch operations on documents with mixed states and various scenarios."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create documents in various states
disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False)
enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-2", enabled=True)
archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-3", enabled=True, archived=True)
# Mix of different document states
documents = [disabled_doc, enabled_doc, archived_doc]
mock_document_service_dependencies["get_document"].side_effect = documents
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Perform enable operation on mixed state documents
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=["doc-1", "doc-2", "doc-3"], action="enable", user=user
)
# Verify only the disabled document was processed
# (enabled and archived documents should be skipped for enable action)
# Only one add should occur (for the disabled document that was enabled)
mock_db = mock_document_service_dependencies["db_session"]
mock_db.add.assert_called_once()
# Only one commit should occur
mock_db.commit.assert_called_once()
# Only one Redis setex should occur (for the document that was enabled)
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
# Only one async task should be triggered (for the document that was enabled)
mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1")
# ==================== Performance Tests ====================
def test_batch_update_large_document_list_performance(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test batch operations with a large number of documents."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create large list of document IDs
document_ids = [f"doc-{i}" for i in range(1, 101)] # 100 documents
# Create mock documents
mock_documents = DocumentBatchUpdateTestDataFactory.create_multiple_documents(
document_ids,
enabled=False, # All disabled, will be enabled
)
mock_document_service_dependencies["get_document"].side_effect = mock_documents
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Perform batch enable operation
DocumentService.batch_update_document_status(
dataset=dataset, document_ids=document_ids, action="enable", user=user
)
# Verify all documents were processed
assert mock_document_service_dependencies["get_document"].call_count == 100
# Verify all documents were updated
for mock_doc in mock_documents:
self._assert_document_enabled(mock_doc, user.id, mock_document_service_dependencies["current_time"])
# Verify database operations
mock_db = mock_document_service_dependencies["db_session"]
assert mock_db.add.call_count == 100
assert mock_db.commit.call_count == 1
# Verify Redis cache operations occurred for each document
assert redis_mock.setex.call_count == 100
# Verify async tasks were triggered for each document
assert mock_async_task_dependencies["add_task"].delay.call_count == 100
# Verify correct Redis cache keys were set
expected_redis_calls = [call(f"document_doc-{i}_indexing", 600, 1) for i in range(1, 101)]
redis_mock.setex.assert_has_calls(expected_redis_calls)
# Verify correct async task calls
expected_task_calls = [call(f"doc-{i}") for i in range(1, 101)]
mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls)
def test_batch_update_mixed_document_states_complex_scenario(
self, mock_document_service_dependencies, mock_async_task_dependencies
):
"""Test complex batch operations with documents in various states."""
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
# Create documents in various states
doc1 = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False) # Will be enabled
doc2 = DocumentBatchUpdateTestDataFactory.create_document_mock(
"doc-2", enabled=True
) # Already enabled, will be skipped
doc3 = DocumentBatchUpdateTestDataFactory.create_document_mock(
"doc-3", enabled=True
) # Already enabled, will be skipped
        doc4 = DocumentBatchUpdateTestDataFactory.create_document_mock(
            "doc-4", enabled=True
        )  # Already enabled, will be skipped
        doc5 = DocumentBatchUpdateTestDataFactory.create_document_mock(
            "doc-5", enabled=True, archived=True
        )  # Archived, so not affected by the enable action
doc6 = None # Non-existent, will be skipped
mock_document_service_dependencies["get_document"].side_effect = [doc1, doc2, doc3, doc4, doc5, doc6]
# Reset module-level Redis mock
redis_mock.reset_mock()
redis_mock.get.return_value = None
# Perform mixed batch operations
DocumentService.batch_update_document_status(
dataset=dataset,
document_ids=["doc-1", "doc-2", "doc-3", "doc-4", "doc-5", "doc-6"],
action="enable", # This will only affect doc1
user=user,
)
# Verify document 1 was enabled
self._assert_document_enabled(doc1, user.id, mock_document_service_dependencies["current_time"])
# Verify other documents were skipped appropriately
        assert doc2.enabled is True  # No change
        assert doc3.enabled is True  # No change
        assert doc4.enabled is True  # No change
        assert doc5.enabled is True  # No change
# Verify database commits occurred for processed documents
# Only doc1 should be added (others were skipped, doc6 doesn't exist)
mock_db = mock_document_service_dependencies["db_session"]
assert mock_db.add.call_count == 1
assert mock_db.commit.call_count == 1
# Verify Redis cache operations occurred for processed documents
# Only doc1 should have Redis operations
assert redis_mock.setex.call_count == 1
# Verify async tasks were triggered for processed documents
# Only doc1 should trigger tasks
assert mock_async_task_dependencies["add_task"].delay.call_count == 1
# Verify correct Redis cache keys were set
expected_redis_calls = [call("document_doc-1_indexing", 600, 1)]
redis_mock.setex.assert_has_calls(expected_redis_calls)
# Verify correct async task calls
expected_task_calls = [call("doc-1")]
mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls)

View File

@@ -0,0 +1,819 @@
"""
Comprehensive unit tests for DatasetService creation methods.
This test suite covers:
- create_empty_dataset for internal datasets
- create_empty_dataset for external datasets
- create_empty_rag_pipeline_dataset
- Error conditions and edge cases
"""
from unittest.mock import Mock, create_autospec, patch
from uuid import uuid4
import pytest
from core.model_runtime.entities.model_entities import ModelType
from models.account import Account
from models.dataset import Dataset, Pipeline
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.entities.knowledge_entities.rag_pipeline_entities import (
IconInfo,
RagPipelineDatasetCreateEntity,
)
from services.errors.dataset import DatasetNameDuplicateError
class DatasetCreateTestDataFactory:
"""Factory class for creating test data and mock objects for dataset creation tests."""
@staticmethod
def create_account_mock(
account_id: str = "account-123",
tenant_id: str = "tenant-123",
**kwargs,
) -> Mock:
"""Create a mock account."""
account = create_autospec(Account, instance=True)
account.id = account_id
account.current_tenant_id = tenant_id
for key, value in kwargs.items():
setattr(account, key, value)
return account
@staticmethod
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
"""Create a mock embedding model."""
embedding_model = Mock()
embedding_model.model = model
embedding_model.provider = provider
return embedding_model
@staticmethod
def create_retrieval_model_mock() -> Mock:
"""Create a mock retrieval model."""
retrieval_model = Mock(spec=RetrievalModel)
retrieval_model.model_dump.return_value = {
"search_method": "semantic_search",
"top_k": 2,
"score_threshold": 0.0,
}
retrieval_model.reranking_model = None
return retrieval_model
@staticmethod
def create_external_knowledge_api_mock(api_id: str = "api-123", **kwargs) -> Mock:
"""Create a mock external knowledge API."""
api = Mock()
api.id = api_id
for key, value in kwargs.items():
setattr(api, key, value)
return api
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
name: str = "Test Dataset",
tenant_id: str = "tenant-123",
**kwargs,
) -> Mock:
"""Create a mock dataset."""
dataset = create_autospec(Dataset, instance=True)
dataset.id = dataset_id
dataset.name = name
dataset.tenant_id = tenant_id
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_pipeline_mock(
pipeline_id: str = "pipeline-123",
name: str = "Test Pipeline",
**kwargs,
) -> Mock:
"""Create a mock pipeline."""
pipeline = Mock(spec=Pipeline)
pipeline.id = pipeline_id
pipeline.name = name
for key, value in kwargs.items():
setattr(pipeline, key, value)
return pipeline
class TestDatasetServiceCreateEmptyDataset:
"""
Comprehensive unit tests for DatasetService.create_empty_dataset method.
This test suite covers:
- Internal dataset creation (vendor provider)
- External dataset creation
- High quality indexing technique with embedding models
- Economy indexing technique
- Retrieval model configuration
- Error conditions (duplicate names, missing external knowledge IDs)
"""
@pytest.fixture
def mock_dataset_service_dependencies(self):
"""Common mock setup for dataset service dependencies."""
with (
patch("services.dataset_service.db.session") as mock_db,
patch("services.dataset_service.ModelManager") as mock_model_manager,
patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding,
patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking,
patch("services.dataset_service.ExternalDatasetService") as mock_external_service,
):
yield {
"db_session": mock_db,
"model_manager": mock_model_manager,
"check_embedding": mock_check_embedding,
"check_reranking": mock_check_reranking,
"external_service": mock_external_service,
}
# ==================== Internal Dataset Creation Tests ====================
def test_create_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
"""Test successful creation of basic internal dataset."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "Test Dataset"
description = "Test description"
# Mock database query to return None (no duplicate name)
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
# Mock database session operations
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Act
result = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=description,
indexing_technique=None,
account=account,
)
# Assert
assert result is not None
assert result.name == name
assert result.description == description
assert result.tenant_id == tenant_id
assert result.created_by == account.id
assert result.updated_by == account.id
assert result.provider == "vendor"
assert result.permission == "only_me"
mock_db.add.assert_called_once()
mock_db.commit.assert_called_once()
def test_create_internal_dataset_with_economy_indexing(self, mock_dataset_service_dependencies):
"""Test successful creation of internal dataset with economy indexing."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "Economy Dataset"
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Act
result = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=None,
indexing_technique="economy",
account=account,
)
# Assert
assert result.indexing_technique == "economy"
assert result.embedding_model_provider is None
assert result.embedding_model is None
mock_db.commit.assert_called_once()
def test_create_internal_dataset_with_high_quality_indexing_default_embedding(
self, mock_dataset_service_dependencies
):
"""Test creation with high_quality indexing using default embedding model."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "High Quality Dataset"
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
# Mock model manager
embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock()
mock_model_manager_instance = Mock()
mock_model_manager_instance.get_default_model_instance.return_value = embedding_model
mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Act
result = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=None,
indexing_technique="high_quality",
account=account,
)
# Assert
assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
mock_db.commit.assert_called_once()
def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(
self, mock_dataset_service_dependencies
):
"""Test creation with high_quality indexing using custom embedding model."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "Custom Embedding Dataset"
embedding_provider = "openai"
embedding_model_name = "text-embedding-3-small"
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
# Mock model manager
embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock(
model=embedding_model_name, provider=embedding_provider
)
mock_model_manager_instance = Mock()
mock_model_manager_instance.get_model_instance.return_value = embedding_model
mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Act
result = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=None,
indexing_technique="high_quality",
account=account,
embedding_model_provider=embedding_provider,
embedding_model_name=embedding_model_name,
)
# Assert
assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_provider
assert result.embedding_model == embedding_model_name
mock_dataset_service_dependencies["check_embedding"].assert_called_once_with(
tenant_id, embedding_provider, embedding_model_name
)
mock_model_manager_instance.get_model_instance.assert_called_once_with(
tenant_id=tenant_id,
provider=embedding_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=embedding_model_name,
)
mock_db.commit.assert_called_once()
def test_create_internal_dataset_with_retrieval_model(self, mock_dataset_service_dependencies):
"""Test creation with retrieval model configuration."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "Retrieval Model Dataset"
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
# Mock retrieval model
retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock()
retrieval_model_dict = {"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Act
result = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=None,
indexing_technique=None,
account=account,
retrieval_model=retrieval_model,
)
# Assert
assert result.retrieval_model == retrieval_model_dict
retrieval_model.model_dump.assert_called_once()
mock_db.commit.assert_called_once()
def test_create_internal_dataset_with_retrieval_model_reranking(self, mock_dataset_service_dependencies):
"""Test creation with retrieval model that includes reranking."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "Reranking Dataset"
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
# Mock model manager
embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock()
mock_model_manager_instance = Mock()
mock_model_manager_instance.get_default_model_instance.return_value = embedding_model
mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance
# Mock retrieval model with reranking
reranking_model = Mock()
reranking_model.reranking_provider_name = "cohere"
reranking_model.reranking_model_name = "rerank-english-v3.0"
retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock()
retrieval_model.reranking_model = reranking_model
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Act
result = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=None,
indexing_technique="high_quality",
account=account,
retrieval_model=retrieval_model,
)
# Assert
mock_dataset_service_dependencies["check_reranking"].assert_called_once_with(
tenant_id, "cohere", "rerank-english-v3.0"
)
mock_db.commit.assert_called_once()
def test_create_internal_dataset_with_custom_permission(self, mock_dataset_service_dependencies):
"""Test creation with custom permission setting."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "Custom Permission Dataset"
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Act
result = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=None,
indexing_technique=None,
account=account,
permission="all_team_members",
)
# Assert
assert result.permission == "all_team_members"
mock_db.commit.assert_called_once()
# ==================== External Dataset Creation Tests ====================
def test_create_external_dataset_success(self, mock_dataset_service_dependencies):
"""Test successful creation of external dataset."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "External Dataset"
external_api_id = "external-api-123"
external_knowledge_id = "external-knowledge-456"
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
# Mock external knowledge API
external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id)
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Act
result = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=None,
indexing_technique=None,
account=account,
provider="external",
external_knowledge_api_id=external_api_id,
external_knowledge_id=external_knowledge_id,
)
# Assert
assert result.provider == "external"
assert mock_db.add.call_count == 2 # Dataset + ExternalKnowledgeBindings
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.assert_called_once_with(
external_api_id
)
mock_db.commit.assert_called_once()
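    # The two add() calls asserted above imply that the external path persists
    # a Dataset row plus an ExternalKnowledgeBindings row linking
    # external_knowledge_api_id to external_knowledge_id (assumed shape; the
    # binding row itself is not inspected by this test).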
def test_create_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies):
"""Test error when external knowledge API is not found."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "External Dataset"
external_api_id = "non-existent-api"
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
# Mock external knowledge API not found
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = None
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
# Act & Assert
with pytest.raises(ValueError, match="External API template not found"):
DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=None,
indexing_technique=None,
account=account,
provider="external",
external_knowledge_api_id=external_api_id,
external_knowledge_id="knowledge-123",
)
def test_create_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
"""Test error when external knowledge ID is missing."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "External Dataset"
external_api_id = "external-api-123"
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
# Mock external knowledge API
external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id)
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
# Act & Assert
with pytest.raises(ValueError, match="external_knowledge_id is required"):
DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=None,
indexing_technique=None,
account=account,
provider="external",
external_knowledge_api_id=external_api_id,
external_knowledge_id=None,
)
# ==================== Error Handling Tests ====================
def test_create_dataset_duplicate_name_error(self, mock_dataset_service_dependencies):
"""Test error when dataset name already exists."""
# Arrange
tenant_id = str(uuid4())
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
name = "Duplicate Dataset"
# Mock database query to return existing dataset
existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name)
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = existing_dataset
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
# Act & Assert
with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"):
DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=name,
description=None,
indexing_technique=None,
account=account,
)
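# The duplicate-name error exercised above presumably comes from a simple
# pre-insert lookup, matching the filter_by(...).first() chain mocked in these
# tests; a sketch under that assumption, not the actual query:
#
#     if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
#         raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")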
class TestDatasetServiceCreateEmptyRagPipelineDataset:
"""
Comprehensive unit tests for DatasetService.create_empty_rag_pipeline_dataset method.
This test suite covers:
- RAG pipeline dataset creation with provided name
- RAG pipeline dataset creation with auto-generated name
- Pipeline creation
- Error conditions (duplicate names, missing current user)
"""
@pytest.fixture
def mock_rag_pipeline_dependencies(self):
"""Common mock setup for RAG pipeline dataset creation."""
with (
patch("services.dataset_service.db.session") as mock_db,
patch("services.dataset_service.current_user") as mock_current_user,
patch("services.dataset_service.generate_incremental_name") as mock_generate_name,
):
# Configure mock_current_user to behave like a Flask-Login proxy
            # Default: no user id set (id is None); tests set it explicitly
mock_current_user.id = None
yield {
"db_session": mock_db,
"current_user_mock": mock_current_user,
"generate_name": mock_generate_name,
}
def test_create_rag_pipeline_dataset_with_name_success(self, mock_rag_pipeline_dependencies):
"""Test successful creation of RAG pipeline dataset with provided name."""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
name = "RAG Pipeline Dataset"
description = "RAG Pipeline Description"
# Mock current user - set up the mock to have id attribute accessible directly
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
# Mock database query (no duplicate name)
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
# Mock database operations
mock_db = mock_rag_pipeline_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Create entity
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity(
name=name,
description=description,
icon_info=icon_info,
permission="only_me",
)
# Act
result = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
)
# Assert
assert result is not None
assert result.name == name
assert result.description == description
assert result.tenant_id == tenant_id
assert result.created_by == user_id
assert result.provider == "vendor"
assert result.runtime_mode == "rag_pipeline"
assert result.permission == "only_me"
assert mock_db.add.call_count == 2 # Pipeline + Dataset
mock_db.commit.assert_called_once()
def test_create_rag_pipeline_dataset_with_auto_generated_name(self, mock_rag_pipeline_dependencies):
"""Test creation of RAG pipeline dataset with auto-generated name."""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
auto_name = "Untitled 1"
# Mock current user - set up the mock to have id attribute accessible directly
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
# Mock database query (empty name, need to generate)
mock_query = Mock()
mock_query.filter_by.return_value.all.return_value = []
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
# Mock name generation
mock_rag_pipeline_dependencies["generate_name"].return_value = auto_name
# Mock database operations
mock_db = mock_rag_pipeline_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Create entity with empty name
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity(
name="",
description="",
icon_info=icon_info,
permission="only_me",
)
# Act
result = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
)
# Assert
assert result.name == auto_name
mock_rag_pipeline_dependencies["generate_name"].assert_called_once()
mock_db.commit.assert_called_once()
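    # The auto-naming path above feeds the tenant's existing dataset names to
    # generate_incremental_name; a sketch under that assumption (the exact
    # call signature is not asserted here, only that it was called once):
    #
    #     names = [d.name for d in db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()]
    #     name = generate_incremental_name(names, "Untitled")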
def test_create_rag_pipeline_dataset_duplicate_name_error(self, mock_rag_pipeline_dependencies):
"""Test error when RAG pipeline dataset name already exists."""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
name = "Duplicate RAG Dataset"
# Mock current user - set up the mock to have id attribute accessible directly
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
# Mock database query to return existing dataset
existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name)
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = existing_dataset
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
# Create entity
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity(
name=name,
description="",
icon_info=icon_info,
permission="only_me",
)
# Act & Assert
with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"):
DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
)
def test_create_rag_pipeline_dataset_missing_current_user_error(self, mock_rag_pipeline_dependencies):
"""Test error when current user is not available."""
# Arrange
tenant_id = str(uuid4())
        # Simulate a missing user: id is None, so the current-user check fails
mock_rag_pipeline_dependencies["current_user_mock"].id = None
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
# Create entity
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity(
name="Test Dataset",
description="",
icon_info=icon_info,
permission="only_me",
)
# Act & Assert
with pytest.raises(ValueError, match="Current user or current user id not found"):
DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
)
def test_create_rag_pipeline_dataset_with_custom_permission(self, mock_rag_pipeline_dependencies):
"""Test creation with custom permission setting."""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
name = "Custom Permission RAG Dataset"
# Mock current user - set up the mock to have id attribute accessible directly
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
# Mock database operations
mock_db = mock_rag_pipeline_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Create entity
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity(
name=name,
description="",
icon_info=icon_info,
permission="all_team",
)
# Act
result = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
)
# Assert
assert result.permission == "all_team"
mock_db.commit.assert_called_once()
def test_create_rag_pipeline_dataset_with_icon_info(self, mock_rag_pipeline_dependencies):
"""Test creation with icon info configuration."""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
name = "Icon Info RAG Dataset"
# Mock current user - set up the mock to have id attribute accessible directly
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
# Mock database operations
mock_db = mock_rag_pipeline_dependencies["db_session"]
mock_db.add = Mock()
mock_db.flush = Mock()
mock_db.commit = Mock()
# Create entity with icon info
icon_info = IconInfo(
icon="📚",
icon_background="#E8F5E9",
icon_type="emoji",
icon_url="https://example.com/icon.png",
)
entity = RagPipelineDatasetCreateEntity(
name=name,
description="",
icon_info=icon_info,
permission="only_me",
)
# Act
result = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
)
# Assert
assert result.icon_info == icon_info.model_dump()
mock_db.commit.assert_called_once()

View File

@@ -0,0 +1,216 @@
from unittest.mock import Mock, patch
import pytest
from models.account import Account, TenantAccountRole
from models.dataset import Dataset
from services.dataset_service import DatasetService
class DatasetDeleteTestDataFactory:
"""Factory class for creating test data and mock objects for dataset delete tests."""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
tenant_id: str = "test-tenant-123",
created_by: str = "creator-456",
doc_form: str | None = None,
indexing_technique: str | None = "high_quality",
**kwargs,
) -> Mock:
"""Create a mock dataset with specified attributes."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.created_by = created_by
dataset.doc_form = doc_form
dataset.indexing_technique = indexing_technique
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_user_mock(
user_id: str = "user-789",
tenant_id: str = "test-tenant-123",
role: TenantAccountRole = TenantAccountRole.ADMIN,
**kwargs,
) -> Mock:
"""Create a mock user with specified attributes."""
user = Mock(spec=Account)
user.id = user_id
user.current_tenant_id = tenant_id
user.current_role = role
for key, value in kwargs.items():
setattr(user, key, value)
return user
class TestDatasetServiceDeleteDataset:
"""
Comprehensive unit tests for DatasetService.delete_dataset method.
This test suite covers all deletion scenarios including:
- Normal dataset deletion with documents
- Empty dataset deletion (no documents, doc_form is None)
- Dataset deletion with missing indexing_technique
- Permission checks
- Event handling
This test suite provides regression protection for issue #27073.
"""
@pytest.fixture
def mock_dataset_service_dependencies(self):
"""Common mock setup for dataset service dependencies."""
with (
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
patch("extensions.ext_database.db.session") as mock_db,
patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted,
):
yield {
"get_dataset": mock_get_dataset,
"check_permission": mock_check_perm,
"db_session": mock_db,
"dataset_was_deleted": mock_dataset_was_deleted,
}
def test_delete_dataset_with_documents_success(self, mock_dataset_service_dependencies):
"""
Test successful deletion of a dataset with documents.
This test verifies:
- Dataset is retrieved correctly
- Permission check is performed
- dataset_was_deleted event is sent
- Dataset is deleted from database
- Method returns True
"""
# Arrange
dataset = DatasetDeleteTestDataFactory.create_dataset_mock(
doc_form="text_model", indexing_technique="high_quality"
)
user = DatasetDeleteTestDataFactory.create_user_mock()
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
# Act
result = DatasetService.delete_dataset(dataset.id, user)
# Assert
assert result is True
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
def test_delete_empty_dataset_success(self, mock_dataset_service_dependencies):
"""
Test successful deletion of an empty dataset (no documents, doc_form is None).
This test verifies that:
- Empty datasets can be deleted without errors
- dataset_was_deleted event is sent (event handler will skip cleanup if doc_form is None)
- Dataset is deleted from database
- Method returns True
This is the primary test for issue #27073 where deleting an empty dataset
caused internal server error due to assertion failure in event handlers.
"""
# Arrange
dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique=None)
user = DatasetDeleteTestDataFactory.create_user_mock()
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
# Act
result = DatasetService.delete_dataset(dataset.id, user)
# Assert - Verify complete deletion flow
assert result is True
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
def test_delete_dataset_with_partial_none_values(self, mock_dataset_service_dependencies):
"""
Test deletion of dataset with partial None values.
This test verifies that datasets with partial None values (e.g., doc_form exists
but indexing_technique is None) can be deleted successfully. The event handler
will skip cleanup if any required field is None.
        Comprehensive assertions verify that all core deletion operations are
        performed, not just event sending.
"""
# Arrange
dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form="text_model", indexing_technique=None)
user = DatasetDeleteTestDataFactory.create_user_mock()
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
# Act
result = DatasetService.delete_dataset(dataset.id, user)
        # Assert - Verify complete deletion flow
assert result is True
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, mock_dataset_service_dependencies):
"""
Test deletion of dataset where doc_form is None but indexing_technique exists.
This edge case can occur in certain dataset configurations and should be handled
gracefully by the event handler's conditional check.
"""
# Arrange
dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique="high_quality")
user = DatasetDeleteTestDataFactory.create_user_mock()
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
# Act
result = DatasetService.delete_dataset(dataset.id, user)
# Assert - Verify complete deletion flow
assert result is True
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
def test_delete_dataset_not_found(self, mock_dataset_service_dependencies):
"""
Test deletion attempt when dataset doesn't exist.
This test verifies that:
- Method returns False when dataset is not found
- No deletion operations are performed
- No events are sent
"""
# Arrange
dataset_id = "non-existent-dataset"
user = DatasetDeleteTestDataFactory.create_user_mock()
mock_dataset_service_dependencies["get_dataset"].return_value = None
# Act
result = DatasetService.delete_dataset(dataset_id, user)
# Assert
assert result is False
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id)
mock_dataset_service_dependencies["check_permission"].assert_not_called()
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called()
mock_dataset_service_dependencies["db_session"].delete.assert_not_called()
mock_dataset_service_dependencies["db_session"].commit.assert_not_called()

View File

@@ -0,0 +1,746 @@
"""
Comprehensive unit tests for DatasetService retrieval/list methods.
This test suite covers:
- get_datasets - pagination, search, filtering, permissions
- get_dataset - single dataset retrieval
- get_datasets_by_ids - bulk retrieval
- get_process_rules - dataset processing rules
- get_dataset_queries - dataset query history
- get_related_apps - apps using the dataset
"""
from unittest.mock import Mock, create_autospec, patch
from uuid import uuid4
import pytest
from models.account import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
Dataset,
DatasetPermission,
DatasetPermissionEnum,
DatasetProcessRule,
DatasetQuery,
)
from services.dataset_service import DatasetService, DocumentService
class DatasetRetrievalTestDataFactory:
"""Factory class for creating test data and mock objects for dataset retrieval tests."""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
name: str = "Test Dataset",
tenant_id: str = "tenant-123",
created_by: str = "user-123",
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
**kwargs,
) -> Mock:
"""Create a mock dataset with specified attributes."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.name = name
dataset.tenant_id = tenant_id
dataset.created_by = created_by
dataset.permission = permission
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_account_mock(
account_id: str = "account-123",
tenant_id: str = "tenant-123",
role: TenantAccountRole = TenantAccountRole.NORMAL,
**kwargs,
) -> Mock:
"""Create a mock account."""
account = create_autospec(Account, instance=True)
account.id = account_id
account.current_tenant_id = tenant_id
account.current_role = role
for key, value in kwargs.items():
setattr(account, key, value)
return account
@staticmethod
def create_dataset_permission_mock(
dataset_id: str = "dataset-123",
account_id: str = "account-123",
**kwargs,
) -> Mock:
"""Create a mock dataset permission."""
permission = Mock(spec=DatasetPermission)
permission.dataset_id = dataset_id
permission.account_id = account_id
for key, value in kwargs.items():
setattr(permission, key, value)
return permission
@staticmethod
def create_process_rule_mock(
dataset_id: str = "dataset-123",
mode: str = "automatic",
rules: dict | None = None,
**kwargs,
) -> Mock:
"""Create a mock dataset process rule."""
process_rule = Mock(spec=DatasetProcessRule)
process_rule.dataset_id = dataset_id
process_rule.mode = mode
process_rule.rules_dict = rules or {}
for key, value in kwargs.items():
setattr(process_rule, key, value)
return process_rule
@staticmethod
def create_dataset_query_mock(
dataset_id: str = "dataset-123",
query_id: str = "query-123",
**kwargs,
) -> Mock:
"""Create a mock dataset query."""
dataset_query = Mock(spec=DatasetQuery)
dataset_query.id = query_id
dataset_query.dataset_id = dataset_id
for key, value in kwargs.items():
setattr(dataset_query, key, value)
return dataset_query
@staticmethod
def create_app_dataset_join_mock(
app_id: str = "app-123",
dataset_id: str = "dataset-123",
**kwargs,
) -> Mock:
"""Create a mock app-dataset join."""
join = Mock(spec=AppDatasetJoin)
join.app_id = app_id
join.dataset_id = dataset_id
for key, value in kwargs.items():
setattr(join, key, value)
return join
class TestDatasetServiceGetDatasets:
"""
Comprehensive unit tests for DatasetService.get_datasets method.
This test suite covers:
- Pagination
- Search functionality
- Tag filtering
- Permission-based filtering (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
- Role-based filtering (OWNER, DATASET_OPERATOR, NORMAL)
- include_all flag
"""
@pytest.fixture
def mock_dependencies(self):
"""Common mock setup for get_datasets tests."""
with (
patch("services.dataset_service.db.session") as mock_db,
patch("services.dataset_service.db.paginate") as mock_paginate,
patch("services.dataset_service.TagService") as mock_tag_service,
):
yield {
"db_session": mock_db,
"paginate": mock_paginate,
"tag_service": mock_tag_service,
}
# ==================== Basic Retrieval Tests ====================
def test_get_datasets_basic_pagination(self, mock_dependencies):
"""Test basic pagination without user or filters."""
# Arrange
tenant_id = str(uuid4())
page = 1
per_page = 20
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(
dataset_id=f"dataset-{i}", name=f"Dataset {i}", tenant_id=tenant_id
)
for i in range(5)
]
mock_paginate_result.total = 5
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id)
# Assert
assert len(datasets) == 5
assert total == 5
mock_dependencies["paginate"].assert_called_once()
def test_get_datasets_with_search(self, mock_dependencies):
"""Test get_datasets with search keyword."""
# Arrange
tenant_id = str(uuid4())
page = 1
per_page = 20
search = "test"
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(
dataset_id="dataset-1", name="Test Dataset", tenant_id=tenant_id
)
]
mock_paginate_result.total = 1
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, search=search)
# Assert
assert len(datasets) == 1
assert total == 1
mock_dependencies["paginate"].assert_called_once()
def test_get_datasets_with_tag_filtering(self, mock_dependencies):
"""Test get_datasets with tag_ids filtering."""
# Arrange
tenant_id = str(uuid4())
page = 1
per_page = 20
tag_ids = ["tag-1", "tag-2"]
# Mock tag service
target_ids = ["dataset-1", "dataset-2"]
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.return_value = target_ids
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
for dataset_id in target_ids
]
mock_paginate_result.total = 2
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids)
# Assert
assert len(datasets) == 2
assert total == 2
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_called_once_with(
"knowledge", tenant_id, tag_ids
)
def test_get_datasets_with_empty_tag_ids(self, mock_dependencies):
"""Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets."""
# Arrange
tenant_id = str(uuid4())
page = 1
per_page = 20
tag_ids = []
# Mock pagination result - when tag_ids is empty, tag filtering is skipped
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id)
for i in range(3)
]
mock_paginate_result.total = 3
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids)
# Assert
# When tag_ids is empty, tag filtering is skipped, so normal query results are returned
assert len(datasets) == 3
assert total == 3
# Tag service should not be called when tag_ids is empty
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_not_called()
mock_dependencies["paginate"].assert_called_once()
# ==================== Permission-Based Filtering Tests ====================
def test_get_datasets_without_user_shows_only_all_team(self, mock_dependencies):
"""Test that without user, only ALL_TEAM datasets are shown."""
# Arrange
tenant_id = str(uuid4())
page = 1
per_page = 20
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(
dataset_id="dataset-1",
tenant_id=tenant_id,
permission=DatasetPermissionEnum.ALL_TEAM,
)
]
mock_paginate_result.total = 1
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, user=None)
# Assert
assert len(datasets) == 1
mock_dependencies["paginate"].assert_called_once()
def test_get_datasets_owner_with_include_all(self, mock_dependencies):
"""Test that OWNER with include_all=True sees all datasets."""
# Arrange
tenant_id = str(uuid4())
user = DatasetRetrievalTestDataFactory.create_account_mock(
account_id="owner-123", tenant_id=tenant_id, role=TenantAccountRole.OWNER
)
# Mock dataset permissions query (empty - owner doesn't need explicit permissions)
mock_query = Mock()
mock_query.filter_by.return_value.all.return_value = []
mock_dependencies["db_session"].query.return_value = mock_query
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id)
for i in range(3)
]
mock_paginate_result.total = 3
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets(
page=1, per_page=20, tenant_id=tenant_id, user=user, include_all=True
)
# Assert
assert len(datasets) == 3
assert total == 3
def test_get_datasets_normal_user_only_me_permission(self, mock_dependencies):
"""Test that normal user sees ONLY_ME datasets they created."""
# Arrange
tenant_id = str(uuid4())
user_id = "user-123"
user = DatasetRetrievalTestDataFactory.create_account_mock(
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL
)
# Mock dataset permissions query (no explicit permissions)
mock_query = Mock()
mock_query.filter_by.return_value.all.return_value = []
mock_dependencies["db_session"].query.return_value = mock_query
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(
dataset_id="dataset-1",
tenant_id=tenant_id,
created_by=user_id,
permission=DatasetPermissionEnum.ONLY_ME,
)
]
mock_paginate_result.total = 1
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
# Assert
assert len(datasets) == 1
assert total == 1
def test_get_datasets_normal_user_all_team_permission(self, mock_dependencies):
"""Test that normal user sees ALL_TEAM datasets."""
# Arrange
tenant_id = str(uuid4())
user = DatasetRetrievalTestDataFactory.create_account_mock(
account_id="user-123", tenant_id=tenant_id, role=TenantAccountRole.NORMAL
)
# Mock dataset permissions query (no explicit permissions)
mock_query = Mock()
mock_query.filter_by.return_value.all.return_value = []
mock_dependencies["db_session"].query.return_value = mock_query
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(
dataset_id="dataset-1",
tenant_id=tenant_id,
permission=DatasetPermissionEnum.ALL_TEAM,
)
]
mock_paginate_result.total = 1
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
# Assert
assert len(datasets) == 1
assert total == 1
def test_get_datasets_normal_user_partial_team_with_permission(self, mock_dependencies):
"""Test that normal user sees PARTIAL_TEAM datasets they have permission for."""
# Arrange
tenant_id = str(uuid4())
user_id = "user-123"
dataset_id = "dataset-1"
user = DatasetRetrievalTestDataFactory.create_account_mock(
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL
)
# Mock dataset permissions query - user has permission
permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock(
dataset_id=dataset_id, account_id=user_id
)
mock_query = Mock()
mock_query.filter_by.return_value.all.return_value = [permission]
mock_dependencies["db_session"].query.return_value = mock_query
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(
dataset_id=dataset_id,
tenant_id=tenant_id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
]
mock_paginate_result.total = 1
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
# Assert
assert len(datasets) == 1
assert total == 1
def test_get_datasets_dataset_operator_with_permissions(self, mock_dependencies):
"""Test that DATASET_OPERATOR only sees datasets they have explicit permission for."""
# Arrange
tenant_id = str(uuid4())
user_id = "operator-123"
dataset_id = "dataset-1"
user = DatasetRetrievalTestDataFactory.create_account_mock(
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR
)
# Mock dataset permissions query - operator has permission
permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock(
dataset_id=dataset_id, account_id=user_id
)
mock_query = Mock()
mock_query.filter_by.return_value.all.return_value = [permission]
mock_dependencies["db_session"].query.return_value = mock_query
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
]
mock_paginate_result.total = 1
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
# Assert
assert len(datasets) == 1
assert total == 1
def test_get_datasets_dataset_operator_without_permissions(self, mock_dependencies):
"""Test that DATASET_OPERATOR without permissions returns empty result."""
# Arrange
tenant_id = str(uuid4())
user_id = "operator-123"
user = DatasetRetrievalTestDataFactory.create_account_mock(
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR
)
# Mock dataset permissions query - no permissions
mock_query = Mock()
mock_query.filter_by.return_value.all.return_value = []
mock_dependencies["db_session"].query.return_value = mock_query
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
# Assert
assert datasets == []
assert total == 0
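# Taken together, the cases above encode the permission matrix this suite
# assumes for get_datasets (asserted via results, not the generated query):
#   - no user: only ALL_TEAM datasets are visible
#   - OWNER with include_all=True: no permission filter at all
#   - NORMAL: own ONLY_ME datasets, ALL_TEAM datasets, and PARTIAL_TEAM
#     datasets backed by a matching DatasetPermission row
#   - DATASET_OPERATOR: only datasets with an explicit DatasetPermission row;
#     with no rows, the method short-circuits to ([], 0)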
class TestDatasetServiceGetDataset:
"""Comprehensive unit tests for DatasetService.get_dataset method."""
@pytest.fixture
def mock_dependencies(self):
"""Common mock setup for get_dataset tests."""
with patch("services.dataset_service.db.session") as mock_db:
yield {"db_session": mock_db}
def test_get_dataset_success(self, mock_dependencies):
"""Test successful retrieval of a single dataset."""
# Arrange
dataset_id = str(uuid4())
dataset = DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
# Mock database query
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = dataset
mock_dependencies["db_session"].query.return_value = mock_query
# Act
result = DatasetService.get_dataset(dataset_id)
# Assert
assert result is not None
assert result.id == dataset_id
mock_query.filter_by.assert_called_once_with(id=dataset_id)
def test_get_dataset_not_found(self, mock_dependencies):
"""Test retrieval when dataset doesn't exist."""
# Arrange
dataset_id = str(uuid4())
# Mock database query returning None
mock_query = Mock()
mock_query.filter_by.return_value.first.return_value = None
mock_dependencies["db_session"].query.return_value = mock_query
# Act
result = DatasetService.get_dataset(dataset_id)
# Assert
assert result is None
class TestDatasetServiceGetDatasetsByIds:
"""Comprehensive unit tests for DatasetService.get_datasets_by_ids method."""
@pytest.fixture
def mock_dependencies(self):
"""Common mock setup for get_datasets_by_ids tests."""
with patch("services.dataset_service.db.paginate") as mock_paginate:
yield {"paginate": mock_paginate}
def test_get_datasets_by_ids_success(self, mock_dependencies):
"""Test successful bulk retrieval of datasets by IDs."""
# Arrange
tenant_id = str(uuid4())
dataset_ids = [str(uuid4()), str(uuid4()), str(uuid4())]
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
for dataset_id in dataset_ids
]
mock_paginate_result.total = len(dataset_ids)
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id)
# Assert
assert len(datasets) == 3
assert total == 3
assert all(dataset.id in dataset_ids for dataset in datasets)
mock_dependencies["paginate"].assert_called_once()
def test_get_datasets_by_ids_empty_list(self, mock_dependencies):
"""Test get_datasets_by_ids with empty list returns empty result."""
# Arrange
tenant_id = str(uuid4())
dataset_ids = []
# Act
datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id)
# Assert
assert datasets == []
assert total == 0
mock_dependencies["paginate"].assert_not_called()
def test_get_datasets_by_ids_none_list(self, mock_dependencies):
"""Test get_datasets_by_ids with None returns empty result."""
# Arrange
tenant_id = str(uuid4())
# Act
datasets, total = DatasetService.get_datasets_by_ids(None, tenant_id)
# Assert
assert datasets == []
assert total == 0
mock_dependencies["paginate"].assert_not_called()
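# Both empty-input tests rely on an early return that never touches db.paginate;
# a minimal sketch of that guard (assumed shape, not the service source):
#
#   def get_datasets_by_ids(ids, tenant_id):
#       if not ids:  # covers both [] and None
#           return [], 0
#       page = db.paginate(select(Dataset).where(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id))
#       return page.items, page.total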
class TestDatasetServiceGetProcessRules:
"""Comprehensive unit tests for DatasetService.get_process_rules method."""
@pytest.fixture
def mock_dependencies(self):
"""Common mock setup for get_process_rules tests."""
with patch("services.dataset_service.db.session") as mock_db:
yield {"db_session": mock_db}
def test_get_process_rules_with_existing_rule(self, mock_dependencies):
"""Test retrieval of process rules when rule exists."""
# Arrange
dataset_id = str(uuid4())
rules_data = {
"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
"segmentation": {"delimiter": "\n", "max_tokens": 500},
}
process_rule = DatasetRetrievalTestDataFactory.create_process_rule_mock(
dataset_id=dataset_id, mode="custom", rules=rules_data
)
# Mock database query
mock_query = Mock()
mock_query.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value = process_rule
mock_dependencies["db_session"].query.return_value = mock_query
# Act
result = DatasetService.get_process_rules(dataset_id)
# Assert
assert result["mode"] == "custom"
assert result["rules"] == rules_data
def test_get_process_rules_without_existing_rule(self, mock_dependencies):
"""Test retrieval of process rules when no rule exists (returns defaults)."""
# Arrange
dataset_id = str(uuid4())
# Mock database query returning None
mock_query = Mock()
mock_query.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value = None
mock_dependencies["db_session"].query.return_value = mock_query
# Act
result = DatasetService.get_process_rules(dataset_id)
# Assert
assert result["mode"] == DocumentService.DEFAULT_RULES["mode"]
assert "rules" in result
assert result["rules"] == DocumentService.DEFAULT_RULES["rules"]
class TestDatasetServiceGetDatasetQueries:
"""Comprehensive unit tests for DatasetService.get_dataset_queries method."""
@pytest.fixture
def mock_dependencies(self):
"""Common mock setup for get_dataset_queries tests."""
with patch("services.dataset_service.db.paginate") as mock_paginate:
yield {"paginate": mock_paginate}
def test_get_dataset_queries_success(self, mock_dependencies):
"""Test successful retrieval of dataset queries."""
# Arrange
dataset_id = str(uuid4())
page = 1
per_page = 20
# Mock pagination result
mock_paginate_result = Mock()
mock_paginate_result.items = [
DatasetRetrievalTestDataFactory.create_dataset_query_mock(dataset_id=dataset_id, query_id=f"query-{i}")
for i in range(3)
]
mock_paginate_result.total = 3
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
queries, total = DatasetService.get_dataset_queries(dataset_id, page, per_page)
# Assert
assert len(queries) == 3
assert total == 3
assert all(query.dataset_id == dataset_id for query in queries)
mock_dependencies["paginate"].assert_called_once()
def test_get_dataset_queries_empty_result(self, mock_dependencies):
"""Test retrieval when no queries exist."""
# Arrange
dataset_id = str(uuid4())
page = 1
per_page = 20
# Mock pagination result (empty)
mock_paginate_result = Mock()
mock_paginate_result.items = []
mock_paginate_result.total = 0
mock_dependencies["paginate"].return_value = mock_paginate_result
# Act
queries, total = DatasetService.get_dataset_queries(dataset_id, page, per_page)
# Assert
assert queries == []
assert total == 0
class TestDatasetServiceGetRelatedApps:
"""Comprehensive unit tests for DatasetService.get_related_apps method."""
@pytest.fixture
def mock_dependencies(self):
"""Common mock setup for get_related_apps tests."""
with patch("services.dataset_service.db.session") as mock_db:
yield {"db_session": mock_db}
def test_get_related_apps_success(self, mock_dependencies):
"""Test successful retrieval of related apps."""
# Arrange
dataset_id = str(uuid4())
# Mock app-dataset joins
app_joins = [
DatasetRetrievalTestDataFactory.create_app_dataset_join_mock(app_id=f"app-{i}", dataset_id=dataset_id)
for i in range(2)
]
# Mock database query
mock_query = Mock()
mock_query.where.return_value.order_by.return_value.all.return_value = app_joins
mock_dependencies["db_session"].query.return_value = mock_query
# Act
result = DatasetService.get_related_apps(dataset_id)
# Assert
assert len(result) == 2
assert all(join.dataset_id == dataset_id for join in result)
mock_query.where.assert_called_once()
mock_query.where.return_value.order_by.assert_called_once()
def test_get_related_apps_empty_result(self, mock_dependencies):
"""Test retrieval when no related apps exist."""
# Arrange
dataset_id = str(uuid4())
# Mock database query returning empty list
mock_query = Mock()
mock_query.where.return_value.order_by.return_value.all.return_value = []
mock_dependencies["db_session"].query.return_value = mock_query
# Act
result = DatasetService.get_related_apps(dataset_id)
# Assert
assert result == []

View File

@@ -0,0 +1,652 @@
import datetime
from typing import Any
from unittest.mock import Mock, create_autospec, patch
import pytest
from core.model_runtime.entities.model_entities import ModelType
from models.account import Account
from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
class DatasetUpdateTestDataFactory:
"""Factory class for creating test data and mock objects for dataset update tests."""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
provider: str = "vendor",
name: str = "old_name",
description: str = "old_description",
indexing_technique: str = "high_quality",
retrieval_model: str = "old_model",
embedding_model_provider: str | None = None,
embedding_model: str | None = None,
collection_binding_id: str | None = None,
**kwargs,
) -> Mock:
"""Create a mock dataset with specified attributes."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.provider = provider
dataset.name = name
dataset.description = description
dataset.indexing_technique = indexing_technique
dataset.retrieval_model = retrieval_model
dataset.embedding_model_provider = embedding_model_provider
dataset.embedding_model = embedding_model
dataset.collection_binding_id = collection_binding_id
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_user_mock(user_id: str = "user-789") -> Mock:
"""Create a mock user."""
user = Mock()
user.id = user_id
return user
@staticmethod
def create_external_binding_mock(
external_knowledge_id: str = "old_knowledge_id", external_knowledge_api_id: str = "old_api_id"
) -> Mock:
"""Create a mock external knowledge binding."""
binding = Mock(spec=ExternalKnowledgeBindings)
binding.external_knowledge_id = external_knowledge_id
binding.external_knowledge_api_id = external_knowledge_api_id
return binding
@staticmethod
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
"""Create a mock embedding model."""
embedding_model = Mock()
embedding_model.model = model
embedding_model.provider = provider
return embedding_model
@staticmethod
def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock:
"""Create a mock collection binding."""
binding = Mock()
binding.id = binding_id
return binding
@staticmethod
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
"""Create a mock current user."""
current_user = create_autospec(Account, instance=True)
current_user.current_tenant_id = tenant_id
return current_user
class TestDatasetServiceUpdateDataset:
"""
Comprehensive unit tests for DatasetService.update_dataset method.
This test suite covers all supported scenarios including:
- External dataset updates
- Internal dataset updates with different indexing techniques
- Embedding model updates
- Permission checks
- Error conditions and edge cases
"""
@pytest.fixture
def mock_dataset_service_dependencies(self):
"""Common mock setup for dataset service dependencies."""
with (
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
patch("extensions.ext_database.db.session") as mock_db,
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
):
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_naive_utc_now.return_value = current_time
yield {
"get_dataset": mock_get_dataset,
"check_permission": mock_check_perm,
"db_session": mock_db,
"naive_utc_now": mock_naive_utc_now,
"current_time": current_time,
"has_dataset_same_name": has_dataset_same_name,
}
@pytest.fixture
def mock_external_provider_dependencies(self):
"""Mock setup for external provider tests."""
with patch("services.dataset_service.Session") as mock_session:
from extensions.ext_database import db
with patch.object(db.__class__, "engine", new_callable=Mock):
session_mock = Mock()
mock_session.return_value.__enter__.return_value = session_mock
yield session_mock
@pytest.fixture
def mock_internal_provider_dependencies(self):
"""Mock setup for internal provider tests."""
with (
patch("services.dataset_service.ModelManager") as mock_model_manager,
patch(
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
) as mock_get_binding,
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
patch(
"services.dataset_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
):
mock_current_user.current_tenant_id = "tenant-123"
yield {
"model_manager": mock_model_manager,
"get_binding": mock_get_binding,
"task": mock_task,
"current_user": mock_current_user,
}
def _assert_database_update_called(self, mock_db, dataset_id: str, expected_updates: dict[str, Any]):
"""Helper method to verify database update calls."""
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_updates)
mock_db.commit.assert_called_once()
def _assert_external_dataset_update(self, mock_dataset, mock_binding, update_data: dict[str, Any]):
"""Helper method to verify external dataset updates."""
assert mock_dataset.name == update_data.get("name", mock_dataset.name)
assert mock_dataset.description == update_data.get("description", mock_dataset.description)
assert mock_dataset.retrieval_model == update_data.get("external_retrieval_model", mock_dataset.retrieval_model)
if "external_knowledge_id" in update_data:
assert mock_binding.external_knowledge_id == update_data["external_knowledge_id"]
if "external_knowledge_api_id" in update_data:
assert mock_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
# ==================== External Dataset Tests ====================
def test_update_external_dataset_success(
self, mock_dataset_service_dependencies, mock_external_provider_dependencies
):
"""Test successful update of external dataset."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="external", name="old_name", description="old_description", retrieval_model="old_model"
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
binding = DatasetUpdateTestDataFactory.create_external_binding_mock()
# Mock external knowledge binding query
mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = binding
update_data = {
"name": "new_name",
"description": "new_description",
"external_retrieval_model": "new_model",
"permission": "only_me",
"external_knowledge_id": "new_knowledge_id",
"external_knowledge_api_id": "new_api_id",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
# Verify dataset and binding updates
self._assert_external_dataset_update(dataset, binding, update_data)
# Verify database operations
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add.assert_any_call(dataset)
mock_db.add.assert_any_call(binding)
mock_db.commit.assert_called_once()
# Verify return value
assert result == dataset
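    # External-provider path implied by the assertions above (sketch; the field
    # assignments are read off the test, not the service):
    #
    #   binding = session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
    #   if binding is None:
    #       raise ValueError("External knowledge binding not found")
    #   dataset.name = data["name"]
    #   dataset.retrieval_model = data["external_retrieval_model"]
    #   binding.external_knowledge_id = data["external_knowledge_id"]
    #   binding.external_knowledge_api_id = data["external_knowledge_api_id"]
    #   db.session.add(dataset); db.session.add(binding); db.session.commit()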
def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
"""Test error when external knowledge id is missing."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
assert "External knowledge id is required" in str(context.value)
def test_update_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies):
"""Test error when external knowledge api id is missing."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
assert "External knowledge api id is required" in str(context.value)
def test_update_external_dataset_binding_not_found_error(
self, mock_dataset_service_dependencies, mock_external_provider_dependencies
):
"""Test error when external knowledge binding is not found."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# Mock external knowledge binding query returning None
mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = None
update_data = {
"name": "new_name",
"external_knowledge_id": "knowledge_id",
"external_knowledge_api_id": "api_id",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
assert "External knowledge binding not found" in str(context.value)
# ==================== Internal Dataset Basic Tests ====================
def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
"""Test successful update of internal dataset with basic fields."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id="binding-123",
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify permission check was called
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
# Verify database update was called with correct filtered data
expected_filtered_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify return value
assert result == dataset
def test_update_internal_dataset_filter_none_values(self, mock_dataset_service_dependencies):
"""Test that None values are filtered out except for description field."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {
"name": "new_name",
"description": None, # Should be included
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": None, # Should be filtered out
"embedding_model": None, # Should be filtered out
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify database update was called with filtered data
expected_filtered_data = {
"name": "new_name",
"description": None, # Description should be included even if None
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
actual_call_args = mock_dataset_service_dependencies[
"db_session"
].query.return_value.filter_by.return_value.update.call_args[0][0]
# Remove timestamp for comparison as it's dynamic
del actual_call_args["updated_at"]
del expected_filtered_data["updated_at"]
assert actual_call_args == expected_filtered_data
# Verify return value
assert result == dataset
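    # The filtering rule this test encodes, in one line (assumed shape):
    #   filtered = {k: v for k, v in update_data.items() if v is not None or k == "description"}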
# ==================== Indexing Technique Switch Tests ====================
def test_update_internal_dataset_indexing_technique_to_economy(
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
"""Test updating internal dataset indexing technique to economy."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify database update was called with embedding model fields cleared
expected_filtered_data = {
"indexing_technique": "economy",
"embedding_model": None,
"embedding_model_provider": None,
"collection_binding_id": None,
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify return value
assert result == dataset
def test_update_internal_dataset_indexing_technique_to_high_quality(
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
"""Test updating internal dataset indexing technique to high_quality."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# Mock embedding model
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock()
mock_internal_provider_dependencies[
"model_manager"
].return_value.get_model_instance.return_value = embedding_model
# Mock collection binding
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock()
mock_internal_provider_dependencies["get_binding"].return_value = binding
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify embedding model was validated
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="text-embedding-ada-002",
)
# Verify collection binding was retrieved
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-ada-002")
# Verify database update was called with correct data
expected_filtered_data = {
"indexing_technique": "high_quality",
"embedding_model": "text-embedding-ada-002",
"embedding_model_provider": "openai",
"collection_binding_id": "binding-456",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify vector index task was triggered
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "add")
# Verify return value
assert result == dataset
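    # The two switch directions asserted in this section suggest this branching
    # (a sketch under assumptions; the task action strings follow the delay()
    # assertions in this class):
    #
    #   if data["indexing_technique"] != dataset.indexing_technique:
    #       if data["indexing_technique"] == "economy":
    #           filtered.update(embedding_model=None, embedding_model_provider=None, collection_binding_id=None)
    #       elif data["indexing_technique"] == "high_quality":
    #           model = ModelManager().get_model_instance(..., model_type=ModelType.TEXT_EMBEDDING, ...)
    #           binding = DatasetCollectionBindingService.get_dataset_collection_binding(provider, model_name)
    #           filtered.update(embedding_model=model.model, embedding_model_provider=model.provider,
    #                           collection_binding_id=binding.id)
    #       deal_dataset_vector_index_task.delay(dataset_id, action)  # "add" here, "update" on model swaps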
# ==================== Embedding Model Update Tests ====================
def test_update_internal_dataset_keep_existing_embedding_model(self, mock_dataset_service_dependencies):
"""Test updating internal dataset without changing embedding model."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id="binding-123",
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify database update was called with existing embedding model preserved
expected_filtered_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"collection_binding_id": "binding-123",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify return value
assert result == dataset
def test_update_internal_dataset_embedding_model_update(
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
"""Test updating internal dataset with new embedding model."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# Mock embedding model
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock("text-embedding-3-small")
mock_internal_provider_dependencies[
"model_manager"
].return_value.get_model_instance.return_value = embedding_model
# Mock collection binding
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock("binding-789")
mock_internal_provider_dependencies["get_binding"].return_value = binding
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-3-small",
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify embedding model was validated
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="text-embedding-3-small",
)
# Verify collection binding was retrieved
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-3-small")
# Verify database update was called with correct data
expected_filtered_data = {
"indexing_technique": "high_quality",
"embedding_model": "text-embedding-3-small",
"embedding_model_provider": "openai",
"collection_binding_id": "binding-789",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify vector index task was triggered
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update")
# Verify return value
assert result == dataset
def test_update_internal_dataset_no_indexing_technique_change(self, mock_dataset_service_dependencies):
"""Test updating internal dataset without changing indexing technique."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id="binding-123",
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {
"name": "new_name",
"indexing_technique": "high_quality", # Same as current
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify database update was called with correct data
expected_filtered_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"collection_binding_id": "binding-123",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify return value
assert result == dataset
# ==================== Error Handling Tests ====================
def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies):
"""Test error when dataset is not found."""
mock_dataset_service_dependencies["get_dataset"].return_value = None
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
assert "Dataset not found" in str(context.value)
def test_update_dataset_permission_error(self, mock_dataset_service_dependencies):
"""Test error when user doesn't have permission."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock()
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission")
update_data = {"name": "new_name"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(NoPermissionError):
DatasetService.update_dataset("dataset-123", update_data, user)
def test_update_internal_dataset_embedding_model_error(
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
"""Test error when embedding model is not available."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# Mock model manager to raise error
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.side_effect = Exception(
"No Embedding Model available"
)
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "invalid_provider",
"embedding_model": "invalid_model",
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(Exception) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
assert "No Embedding Model available".lower() in str(context.value).lower()

View File

@@ -0,0 +1,317 @@
from unittest.mock import Mock, patch
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
class DocumentIndexingTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
@staticmethod
    def create_mock_features(billing_enabled: bool = False, plan: CloudPlan | str | None = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantIsolatedTaskQueue."""
queue = Mock(spec=TenantIsolatedTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_document_task_proxy(
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
) -> DocumentIndexingTaskProxy:
"""Create DocumentIndexingTaskProxy instance for testing."""
if document_ids is None:
document_ids = ["doc-1", "doc-2", "doc-3"]
return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
class TestDocumentIndexingTaskProxy:
"""Test cases for DocumentIndexingTaskProxy class."""
def test_initialization(self):
"""Test DocumentIndexingTaskProxy initialization."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2", "doc-3"]
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
@patch("services.document_indexing_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(mock_task)
# Assert
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
assert len(pushed_tasks) == 1
        # the pushed payload must round-trip cleanly through the DocumentTask dataclass
        assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
mock_task.delay.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
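    # The two tenant-queue tests above pin down this gate (hedged sketch; the
    # dict conversion of DocumentTask is inferred from the push_tasks assertion):
    #
    #   if self._tenant_isolated_task_queue.get_task_key():
    #       # a worker is already draining this tenant: just enqueue the payload
    #       self._tenant_isolated_task_queue.push_tasks([asdict(DocumentTask(...))])
    #   else:
    #       # no active worker: mark the queue busy, then kick off Celery directly
    #       self._tenant_isolated_task_queue.set_task_waiting_time()
    #       task.delay(tenant_id=..., dataset_id=..., document_ids=...)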
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_default_tenant_queue(self, mock_task):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_default_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_tenant_queue(self, mock_task):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_priority_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_direct_queue = Mock()
# Act
proxy._send_to_priority_direct_queue()
# Assert
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.TEAM
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
        # Assert
        # Billing enabled with a non-sandbox plan should route to the priority tenant queue
        proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_disabled(self, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
# Act
proxy._dispatch()
        # Assert
        # With billing disabled (e.g. self-hosted or enterprise), tasks go to the priority direct queue
        proxy._send_to_priority_direct_queue.assert_called_once()
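    # Routing table implied by the three dispatch tests (reconstructed from the
    # assertions, not the proxy source):
    #
    #   billing enabled + SANDBOX plan       -> _send_to_default_tenant_queue()
    #   billing enabled + any other plan     -> _send_to_priority_tenant_queue()
    #   billing disabled (self-hosted, etc.) -> _send_to_priority_direct_queue()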
@patch("services.document_indexing_task_proxy.FeatureService")
def test_delay_method(self, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy.delay()
# Assert
        # Billing enabled with the sandbox plan should route to the default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once()
def test_document_task_dataclass(self):
"""Test DocumentTask dataclass."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2"]
# Act
task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
# Assert
assert task.tenant_id == tenant_id
assert task.dataset_id == dataset_id
assert task.document_ids == document_ids
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
def test_initialization_with_empty_document_ids(self):
"""Test initialization with empty document_ids list."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = []
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
def test_initialization_with_single_document_id(self):
"""Test initialization with single document_id."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1"]
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids

View File

@@ -0,0 +1,33 @@
import sqlalchemy as sa
from models.dataset import Document
from services.dataset_service import DocumentService
def test_normalize_display_status_alias_mapping():
assert DocumentService.normalize_display_status("ACTIVE") == "available"
assert DocumentService.normalize_display_status("enabled") == "available"
assert DocumentService.normalize_display_status("archived") == "archived"
assert DocumentService.normalize_display_status("unknown") is None
def test_build_display_status_filters_available():
filters = DocumentService.build_display_status_filters("available")
assert len(filters) == 3
for condition in filters:
assert condition is not None
def test_apply_display_status_filter_applies_when_status_present():
query = sa.select(Document)
filtered = DocumentService.apply_display_status_filter(query, "queuing")
compiled = str(filtered.compile(compile_kwargs={"literal_binds": True}))
assert "WHERE" in compiled
assert "documents.indexing_status = 'waiting'" in compiled
def test_apply_display_status_filter_returns_same_when_invalid():
query = sa.select(Document)
filtered = DocumentService.apply_display_status_filter(query, "invalid")
compiled = str(filtered.compile(compile_kwargs={"literal_binds": True}))
assert "WHERE" not in compiled

View File

@@ -0,0 +1,203 @@
from pathlib import Path
from unittest.mock import Mock, create_autospec, patch
import pytest
from flask_restx import reqparse
from werkzeug.exceptions import BadRequest
from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
class TestMetadataBugCompleteValidation:
"""Complete test suite to verify the metadata nullable bug and its fix."""
def test_1_pydantic_layer_validation(self):
"""Test Layer 1: Pydantic model validation correctly rejects None values."""
# Pydantic should reject None values for required fields
with pytest.raises((ValueError, TypeError)):
MetadataArgs(type=None, name=None)
with pytest.raises((ValueError, TypeError)):
MetadataArgs(type="string", name=None)
with pytest.raises((ValueError, TypeError)):
MetadataArgs(type=None, name="test")
# Valid values should work
valid_args = MetadataArgs(type="string", name="test_name")
assert valid_args.type == "string"
assert valid_args.name == "test_name"
def test_2_business_logic_layer_crashes_on_none(self):
"""Test Layer 2: Business logic crashes when None values slip through."""
# Create mock that bypasses Pydantic validation
mock_metadata_args = Mock()
mock_metadata_args.name = None
mock_metadata_args.type = "string"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# Should crash with TypeError
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
# Test update method as well
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
def test_3_database_constraints_verification(self):
"""Test Layer 3: Verify database model has nullable=False constraints."""
from sqlalchemy import inspect
from models.dataset import DatasetMetadata
# Get table info
mapper = inspect(DatasetMetadata)
# Check that type and name columns are not nullable
type_column = mapper.columns["type"]
name_column = mapper.columns["name"]
assert type_column.nullable is False, "type column should be nullable=False"
assert name_column.nullable is False, "name column should be nullable=False"
def test_4_fixed_api_layer_rejects_null(self, app):
"""Test Layer 4: Fixed API configuration properly rejects null values."""
# Test Console API create endpoint (fixed)
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
with pytest.raises(BadRequest):
parser.parse_args()
# Test with just name being null
with app.test_request_context(json={"type": "string", "name": None}, content_type="application/json"):
with pytest.raises(BadRequest):
parser.parse_args()
# Test with just type being null
with app.test_request_context(json={"type": None, "name": "test"}, content_type="application/json"):
with pytest.raises(BadRequest):
parser.parse_args()
def test_5_fixed_api_accepts_valid_values(self, app):
"""Test that fixed API still accepts valid non-null values."""
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"):
args = parser.parse_args()
assert args["type"] == "string"
assert args["name"] == "valid_name"
def test_6_simulated_buggy_behavior(self, app):
"""Test simulating the original buggy behavior with nullable=True."""
# Simulate the old buggy configuration
buggy_parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=True, location="json")
.add_argument("name", type=str, required=True, nullable=True, location="json")
)
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
# This would pass in the buggy version
args = buggy_parser.parse_args()
assert args["type"] is None
assert args["name"] is None
# But would crash when trying to create MetadataArgs
with pytest.raises((ValueError, TypeError)):
MetadataArgs.model_validate(args)
def test_7_end_to_end_validation_layers(self):
"""Test all validation layers work together correctly."""
# Layer 1: API should reject null at parameter level (with fix)
# Layer 2: Pydantic should reject null at model level
# Layer 3: Business logic expects non-null
# Layer 4: Database enforces non-null
# Test that valid data flows through all layers
valid_data = {"type": "string", "name": "test_metadata"}
# Should create valid Pydantic object
metadata_args = MetadataArgs.model_validate(valid_data)
assert metadata_args.type == "string"
assert metadata_args.name == "test_metadata"
# Should not crash in business logic length check
assert len(metadata_args.name) <= 255 # This should not crash
assert len(metadata_args.type) > 0 # This should not crash
def test_8_verify_specific_fix_locations(self):
"""Verify that the specific locations mentioned in bug report are fixed."""
# Read the actual files to verify fixes
import os
# Console API create
console_create_file = "api/controllers/console/datasets/metadata.py"
if os.path.exists(console_create_file):
content = Path(console_create_file).read_text()
# Should contain nullable=False, not nullable=True
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
# Service API create
service_create_file = "api/controllers/service_api/dataset/metadata.py"
if os.path.exists(service_create_file):
content = Path(service_create_file).read_text()
# Should contain nullable=False, not nullable=True
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
assert "nullable=True" not in create_api_section
class TestMetadataValidationSummary:
"""Summary tests that demonstrate the complete validation architecture."""
def test_validation_layer_architecture(self):
"""Document and test the 4-layer validation architecture."""
# Layer 1: API Parameter Validation (Flask-RESTful reqparse)
# - Role: First line of defense, validates HTTP request parameters
# - Fixed: nullable=False ensures null values are rejected at API boundary
# Layer 2: Pydantic Model Validation
# - Role: Validates data structure and types before business logic
# - Working: Required fields without Optional[] reject None values
# Layer 3: Business Logic Validation
# - Role: Domain-specific validation (length checks, uniqueness, etc.)
# - Vulnerable: Direct len() calls crash on None values
# Layer 4: Database Constraints
# - Role: Final data integrity enforcement
# - Working: nullable=False prevents None values in database
# The bug was: Layer 1 allowed None, but Layers 2-4 expected non-None
# The fix: Make Layer 1 consistent with Layers 2-4
assert True # This test documents the architecture
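# For reference, the one-line fix this suite verifies, shown as a reconstructed
# (hedged) diff against the reqparse configuration:
#
#   - .add_argument("name", type=str, required=True, nullable=True, location="json")
#   + .add_argument("name", type=str, required=True, nullable=False, location="json")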
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,127 @@
from unittest.mock import Mock, create_autospec, patch
import pytest
from flask_restx import reqparse
from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
class TestMetadataNullableBug:
"""Test case to reproduce the metadata nullable validation bug."""
def test_metadata_args_with_none_values_should_fail(self):
"""Test that MetadataArgs validation should reject None values."""
# This test demonstrates the expected behavior - should fail validation
with pytest.raises((ValueError, TypeError)):
# This should fail because Pydantic expects non-None values
MetadataArgs(type=None, name=None)
def test_metadata_service_create_with_none_name_crashes(self):
"""Test that MetadataService.create_metadata crashes when name is None."""
# Mock the MetadataArgs to bypass Pydantic validation
mock_metadata_args = Mock()
mock_metadata_args.name = None # This will cause len() to crash
mock_metadata_args.type = "string"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
def test_metadata_service_update_with_none_name_crashes(self):
"""Test that MetadataService.update_metadata_name crashes when name is None."""
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
def test_api_parser_accepts_null_values(self, app):
"""Test that API parser configuration incorrectly accepts null values."""
# Simulate the current API parser configuration
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=True, location="json")
.add_argument("name", type=str, required=True, nullable=True, location="json")
)
# Simulate request data with null values
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
# This should parse successfully due to nullable=True
args = parser.parse_args()
# Verify that null values are accepted
assert args["type"] is None
assert args["name"] is None
# This demonstrates the bug: API accepts None but business logic will crash
def test_integration_bug_scenario(self, app):
"""Test the complete bug scenario from API to service layer."""
# Step 1: API parser accepts null values (current buggy behavior)
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=True, location="json")
.add_argument("name", type=str, required=True, nullable=True, location="json")
)
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
args = parser.parse_args()
# Step 2: Try to create MetadataArgs with None values
# This should fail at Pydantic validation level
with pytest.raises((ValueError, TypeError)):
metadata_args = MetadataArgs.model_validate(args)
# Step 3: If we bypass Pydantic (simulating the bug scenario)
# Move this outside the request context to avoid Flask-Login issues
mock_metadata_args = Mock()
mock_metadata_args.name = None # From args["name"]
mock_metadata_args.type = None # From args["type"]
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# Step 4: Service layer crashes on len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
def test_correct_nullable_false_configuration_works(self, app):
"""Test that the correct nullable=False configuration works as expected."""
# This tests the FIXED configuration
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
# This should fail with BadRequest due to nullable=False
from werkzeug.exceptions import BadRequest
with pytest.raises(BadRequest):
parser.parse_args()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,153 @@
import unittest
from unittest.mock import MagicMock, patch
from models.dataset import Dataset, Document
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
class TestMetadataPartialUpdate(unittest.TestCase):
def setUp(self):
self.dataset = MagicMock(spec=Dataset)
self.dataset.id = "dataset_id"
self.dataset.built_in_field_enabled = False
self.document = MagicMock(spec=Document)
self.document.id = "doc_id"
self.document.doc_metadata = {"existing_key": "existing_value"}
self.document.data_source_type = "upload_file"
@patch("services.metadata_service.db")
@patch("services.metadata_service.DocumentService")
@patch("services.metadata_service.current_account_with_tenant")
@patch("services.metadata_service.redis_client")
def test_partial_update_merges_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db):
# Setup mocks
mock_redis.get.return_value = None
mock_document_service.get_document.return_value = self.document
mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id")
# Mock DB query for existing bindings
# No existing binding for new key
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
# Input data
operation = DocumentMetadataOperation(
document_id="doc_id",
metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")],
partial_update=True,
)
metadata_args = MetadataOperationData(operation_data=[operation])
# Execute
MetadataService.update_documents_metadata(self.dataset, metadata_args)
# Verify
# 1. Check that doc_metadata contains BOTH existing and new keys
expected_metadata = {"existing_key": "existing_value", "new_key": "new_value"}
assert self.document.doc_metadata == expected_metadata
# 2. Check that existing bindings were NOT deleted
# The delete call in the original code: db.session.query(...).filter_by(...).delete()
# In partial update, this should NOT be called.
mock_db.session.query.return_value.filter_by.return_value.delete.assert_not_called()
@patch("services.metadata_service.db")
@patch("services.metadata_service.DocumentService")
@patch("services.metadata_service.current_account_with_tenant")
@patch("services.metadata_service.redis_client")
def test_full_update_replaces_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db):
# Setup mocks
mock_redis.get.return_value = None
mock_document_service.get_document.return_value = self.document
mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id")
# Input data (partial_update=False by default)
operation = DocumentMetadataOperation(
document_id="doc_id",
metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")],
partial_update=False,
)
metadata_args = MetadataOperationData(operation_data=[operation])
# Execute
MetadataService.update_documents_metadata(self.dataset, metadata_args)
# Verify
# 1. Check that doc_metadata contains ONLY the new key
expected_metadata = {"new_key": "new_value"}
assert self.document.doc_metadata == expected_metadata
# 2. Check that existing bindings WERE deleted
# In full update (default), we expect the existing bindings to be cleared.
mock_db.session.query.return_value.filter_by.return_value.delete.assert_called()
@patch("services.metadata_service.db")
@patch("services.metadata_service.DocumentService")
@patch("services.metadata_service.current_account_with_tenant")
@patch("services.metadata_service.redis_client")
def test_partial_update_skips_existing_binding(
self, mock_redis, mock_current_account, mock_document_service, mock_db
):
# Setup mocks
mock_redis.get.return_value = None
mock_document_service.get_document.return_value = self.document
mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id")
# Mock DB query to return an existing binding
# This simulates that the document ALREADY has the metadata we are trying to add
mock_existing_binding = MagicMock()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_existing_binding
# Input data
operation = DocumentMetadataOperation(
document_id="doc_id",
metadata_list=[MetadataDetail(id="existing_meta_id", name="existing_key", value="existing_value")],
partial_update=True,
)
metadata_args = MetadataOperationData(operation_data=[operation])
# Execute
MetadataService.update_documents_metadata(self.dataset, metadata_args)
# Verify
        # Verify that db.session.add was NOT called for a DatasetMetadataBinding.
        # Checking "not called with a specific type" on the generic add method is
        # awkward, so we also assert on the total call count: 1 call (only the
        # document update) instead of 2 (document + binding).
        # Expected calls:
        # 1. db.session.add(document)
        # 2. no db.session.add(binding), because the binding already exists
        # In the service, db.session.add(document) runs first, then the loop over
        # metadata_list finds the existing binding and continues, so the binding
        # add is skipped.
# Let's filter the calls to add to see what was added
add_calls = mock_db.session.add.call_args_list
added_objects = [call.args[0] for call in add_calls]
# Check that no DatasetMetadataBinding was added
from models.dataset import DatasetMetadataBinding
has_binding_add = any(
isinstance(obj, DatasetMetadataBinding)
or (isinstance(obj, MagicMock) and getattr(obj, "__class__", None) == DatasetMetadataBinding)
for obj in added_objects
)
        # With everything mocked, the isinstance check above is best-effort (the
        # spec class may not match exactly), so the call count is the reliable
        # signal: 2 calls if the binding were added, 1 if it was skipped.
        assert not has_binding_add
        assert mock_db.session.add.call_count == 1
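# A self-contained sketch of the merge semantics verified above: a partial
# update merges into the existing mapping, a full update replaces it. The
# helper name and signature are illustrative, not the service's actual API.
def apply_metadata_update(existing: dict[str, str], updates: dict[str, str], partial: bool) -> dict[str, str]:
    if partial:
        merged = dict(existing)  # keep keys that are not being touched
        merged.update(updates)   # add or overwrite the submitted keys
        return merged
    return dict(updates)         # full update: prior keys are discarded

# apply_metadata_update({"existing_key": "existing_value"}, {"new_key": "new_value"}, partial=True)
# -> {"existing_key": "existing_value", "new_key": "new_value"}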
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,483 @@
import json
from unittest.mock import Mock, patch
import pytest
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
class RagPipelineTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for RagPipelineTaskProxy tests."""
@staticmethod
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantIsolatedTaskQueue."""
queue = Mock(spec=TenantIsolatedTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_rag_pipeline_invoke_entity(
pipeline_id: str = "pipeline-123",
user_id: str = "user-456",
tenant_id: str = "tenant-789",
workflow_id: str = "workflow-101",
streaming: bool = True,
workflow_execution_id: str | None = None,
workflow_thread_pool_id: str | None = None,
) -> RagPipelineInvokeEntity:
"""Create RagPipelineInvokeEntity instance for testing."""
return RagPipelineInvokeEntity(
pipeline_id=pipeline_id,
application_generate_entity={"key": "value"},
user_id=user_id,
tenant_id=tenant_id,
workflow_id=workflow_id,
streaming=streaming,
workflow_execution_id=workflow_execution_id,
workflow_thread_pool_id=workflow_thread_pool_id,
)
@staticmethod
def create_rag_pipeline_task_proxy(
dataset_tenant_id: str = "tenant-123",
user_id: str = "user-456",
rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None,
) -> RagPipelineTaskProxy:
"""Create RagPipelineTaskProxy instance for testing."""
if rag_pipeline_invoke_entities is None:
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
@staticmethod
def create_mock_upload_file(file_id: str = "file-123") -> Mock:
"""Create mock upload file."""
upload_file = Mock()
upload_file.id = file_id
return upload_file
class TestRagPipelineTaskProxy:
"""Test cases for RagPipelineTaskProxy class."""
def test_initialization(self):
"""Test RagPipelineTaskProxy initialization."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert proxy._dataset_tenant_id == dataset_tenant_id
assert proxy._user_id == user_id
assert proxy._rag_pipeline_invoke_entities == rag_pipeline_invoke_entities
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
assert proxy._tenant_isolated_task_queue._tenant_id == dataset_tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "pipeline"
def test_initialization_with_empty_entities(self):
"""Test initialization with empty rag_pipeline_invoke_entities."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = []
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert proxy._dataset_tenant_id == dataset_tenant_id
assert proxy._user_id == user_id
assert proxy._rag_pipeline_invoke_entities == []
def test_initialization_with_multiple_entities(self):
"""Test initialization with multiple rag_pipeline_invoke_entities."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3"),
]
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert len(proxy._rag_pipeline_invoke_entities) == 3
assert proxy._rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1"
assert proxy._rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2"
assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
"""Test _upload_invoke_entities method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
result = proxy._upload_invoke_entities()
# Assert
assert result == "file-123"
mock_file_service_class.assert_called_once_with(mock_db.engine)
# Verify upload_text was called with correct parameters
mock_file_service.upload_text.assert_called_once()
call_args = mock_file_service.upload_text.call_args
json_text, name, user_id, tenant_id = call_args[0]
assert name == "rag_pipeline_invoke_entities.json"
assert user_id == "user-456"
assert tenant_id == "tenant-123"
# Verify JSON content
parsed_json = json.loads(json_text)
assert len(parsed_json) == 1
assert parsed_json[0]["pipeline_id"] == "pipeline-123"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
"""Test _upload_invoke_entities method with multiple entities."""
# Arrange
entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
]
proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities)
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
result = proxy._upload_invoke_entities()
# Assert
assert result == "file-456"
# Verify JSON content contains both entities
call_args = mock_file_service.upload_text.call_args
json_text = call_args[0][0]
parsed_json = json.loads(json_text)
assert len(parsed_json) == 2
assert parsed_json[0]["pipeline_id"] == "pipeline-1"
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue()
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(upload_file_id, mock_task)
# If sent to direct queue, tenant_isolated_task_queue should not be called
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
# Celery should be called directly
mock_task.delay.assert_called_once_with(
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(upload_file_id, mock_task)
# If task key exists, should push tasks to the queue
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with([upload_file_id])
# Celery should not be called directly
mock_task.delay.assert_not_called()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(upload_file_id, mock_task)
# If no task key, should set task waiting time key first
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
)
# The first task should be sent to celery directly, so push tasks should not be called
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_default_tenant_queue(self, mock_task):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_tenant_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_default_tenant_queue(upload_file_id)
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
def test_send_to_priority_tenant_queue(self, mock_task):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_tenant_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_tenant_queue(upload_file_id)
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_direct_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_direct_queue(upload_file_id)
# Assert
proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_enabled_non_sandbox_plan(
self, mock_db, mock_file_service_class, mock_feature_service
):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.TEAM
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with non-sandbox plan, should send to priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
        # If billing is disabled (e.g. self-hosted or enterprise), tasks go straight to the priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
"""Test _dispatch method when upload_file_id is empty."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = Mock()
mock_upload_file.id = "" # Empty file ID
mock_file_service.upload_text.return_value = mock_upload_file
# Act & Assert
with pytest.raises(ValueError, match="upload_file_id is empty"):
proxy._dispatch()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._dispatch = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy.delay()
# Assert
proxy._dispatch.assert_called_once()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.logger")
def test_delay_method_with_empty_entities(self, mock_logger):
"""Test delay method with empty rag_pipeline_invoke_entities."""
# Arrange
proxy = RagPipelineTaskProxy("tenant-123", "user-456", [])
# Act
proxy.delay()
# Assert
mock_logger.warning.assert_called_once_with(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s", "tenant-123", "user-456"
)
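# A plausible reconstruction of the routing pinned down by the dispatch and
# queue tests above, written as standalone helpers; the names, parameters,
# and structure are assumptions inferred from the assertions, not the actual
# RagPipelineTaskProxy source.
def _route_dispatch(features, upload_file_id: str, proxy) -> None:
    if not upload_file_id:
        raise ValueError("upload_file_id is empty")
    if features.billing.enabled:
        if features.billing.subscription.plan == CloudPlan.SANDBOX:
            proxy._send_to_default_tenant_queue(upload_file_id)
        else:  # any non-sandbox plan, including edge cases like "" or None
            proxy._send_to_priority_tenant_queue(upload_file_id)
    else:
        # billing disabled (self-hosted / enterprise): skip tenant isolation
        proxy._send_to_priority_direct_queue(upload_file_id)

def _route_tenant_queue(queue, task, upload_file_id: str, tenant_id: str) -> None:
    if queue.get_task_key():
        queue.push_tasks([upload_file_id])  # a worker is already draining this tenant
    else:
        queue.set_task_waiting_time()  # first task for the tenant: record the wait start
        task.delay(rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=tenant_id)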

View File

@@ -0,0 +1,779 @@
import unittest
from datetime import UTC, datetime
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError
from events.event_handlers.sync_workflow_schedule_when_app_published import (
sync_schedule_from_workflow,
)
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
from models.account import Account, TenantAccountJoin
from models.trigger import WorkflowSchedulePlan
from models.workflow import Workflow
from services.trigger.schedule_service import ScheduleService
class TestScheduleService(unittest.TestCase):
"""Test cases for ScheduleService class."""
def test_calculate_next_run_at_valid_cron(self):
"""Test calculating next run time with valid cron expression."""
# Test daily cron at 10:30 AM
cron_expr = "30 10 * * *"
timezone = "UTC"
base_time = datetime(2025, 8, 29, 9, 0, 0, tzinfo=UTC)
next_run = calculate_next_run_at(cron_expr, timezone, base_time)
assert next_run is not None
assert next_run.hour == 10
assert next_run.minute == 30
assert next_run.day == 29
def test_calculate_next_run_at_with_timezone(self):
"""Test calculating next run time with different timezone."""
cron_expr = "0 9 * * *" # 9:00 AM
timezone = "America/New_York"
base_time = datetime(2025, 8, 29, 12, 0, 0, tzinfo=UTC) # 8:00 AM EDT
next_run = calculate_next_run_at(cron_expr, timezone, base_time)
assert next_run is not None
# 9:00 AM EDT = 13:00 UTC (during EDT)
expected_utc_hour = 13
assert next_run.hour == expected_utc_hour
def test_calculate_next_run_at_with_last_day_of_month(self):
"""Test calculating next run time with 'L' (last day) syntax."""
cron_expr = "0 10 L * *" # 10:00 AM on last day of month
timezone = "UTC"
base_time = datetime(2025, 2, 15, 9, 0, 0, tzinfo=UTC)
next_run = calculate_next_run_at(cron_expr, timezone, base_time)
assert next_run is not None
# February 2025 has 28 days
assert next_run.day == 28
assert next_run.month == 2
def test_calculate_next_run_at_invalid_cron(self):
"""Test calculating next run time with invalid cron expression."""
cron_expr = "invalid cron"
timezone = "UTC"
with pytest.raises(ValueError):
calculate_next_run_at(cron_expr, timezone)
def test_calculate_next_run_at_invalid_timezone(self):
"""Test calculating next run time with invalid timezone."""
from pytz import UnknownTimeZoneError
cron_expr = "30 10 * * *"
timezone = "Invalid/Timezone"
with pytest.raises(UnknownTimeZoneError):
calculate_next_run_at(cron_expr, timezone)
@patch("libs.schedule_utils.calculate_next_run_at")
def test_create_schedule(self, mock_calculate_next_run):
"""Test creating a new schedule."""
mock_session = MagicMock(spec=Session)
mock_calculate_next_run.return_value = datetime(2025, 8, 30, 10, 30, 0, tzinfo=UTC)
config = ScheduleConfig(
node_id="start",
cron_expression="30 10 * * *",
timezone="UTC",
)
schedule = ScheduleService.create_schedule(
session=mock_session,
tenant_id="test-tenant",
app_id="test-app",
config=config,
)
assert schedule is not None
assert schedule.tenant_id == "test-tenant"
assert schedule.app_id == "test-app"
assert schedule.node_id == "start"
assert schedule.cron_expression == "30 10 * * *"
assert schedule.timezone == "UTC"
assert schedule.next_run_at is not None
mock_session.add.assert_called_once()
mock_session.flush.assert_called_once()
@patch("services.trigger.schedule_service.calculate_next_run_at")
def test_update_schedule(self, mock_calculate_next_run):
"""Test updating an existing schedule."""
mock_session = MagicMock(spec=Session)
mock_schedule = Mock(spec=WorkflowSchedulePlan)
mock_schedule.cron_expression = "0 12 * * *"
mock_schedule.timezone = "America/New_York"
mock_session.get.return_value = mock_schedule
mock_calculate_next_run.return_value = datetime(2025, 8, 30, 12, 0, 0, tzinfo=UTC)
updates = SchedulePlanUpdate(
cron_expression="0 12 * * *",
timezone="America/New_York",
)
result = ScheduleService.update_schedule(
session=mock_session,
schedule_id="test-schedule-id",
updates=updates,
)
assert result is not None
assert result.cron_expression == "0 12 * * *"
assert result.timezone == "America/New_York"
mock_calculate_next_run.assert_called_once()
mock_session.flush.assert_called_once()
def test_update_schedule_not_found(self):
"""Test updating a non-existent schedule raises exception."""
from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError
mock_session = MagicMock(spec=Session)
mock_session.get.return_value = None
updates = SchedulePlanUpdate(
cron_expression="0 12 * * *",
)
with pytest.raises(ScheduleNotFoundError) as context:
ScheduleService.update_schedule(
session=mock_session,
schedule_id="non-existent-id",
updates=updates,
)
assert "Schedule not found: non-existent-id" in str(context.value)
mock_session.flush.assert_not_called()
def test_delete_schedule(self):
"""Test deleting a schedule."""
mock_session = MagicMock(spec=Session)
mock_schedule = Mock(spec=WorkflowSchedulePlan)
mock_session.get.return_value = mock_schedule
# Should not raise exception and complete successfully
ScheduleService.delete_schedule(
session=mock_session,
schedule_id="test-schedule-id",
)
mock_session.delete.assert_called_once_with(mock_schedule)
mock_session.flush.assert_called_once()
def test_delete_schedule_not_found(self):
"""Test deleting a non-existent schedule raises exception."""
from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError
mock_session = MagicMock(spec=Session)
mock_session.get.return_value = None
# Should raise ScheduleNotFoundError
with pytest.raises(ScheduleNotFoundError) as context:
ScheduleService.delete_schedule(
session=mock_session,
schedule_id="non-existent-id",
)
assert "Schedule not found: non-existent-id" in str(context.value)
mock_session.delete.assert_not_called()
@patch("services.trigger.schedule_service.select")
def test_get_tenant_owner(self, mock_select):
"""Test getting tenant owner account."""
mock_session = MagicMock(spec=Session)
mock_account = Mock(spec=Account)
mock_account.id = "owner-account-id"
# Mock owner query
mock_owner_result = Mock(spec=TenantAccountJoin)
mock_owner_result.account_id = "owner-account-id"
mock_session.execute.return_value.scalar_one_or_none.return_value = mock_owner_result
mock_session.get.return_value = mock_account
result = ScheduleService.get_tenant_owner(
session=mock_session,
tenant_id="test-tenant",
)
assert result is not None
assert result.id == "owner-account-id"
@patch("services.trigger.schedule_service.select")
def test_get_tenant_owner_fallback_to_admin(self, mock_select):
"""Test getting tenant owner falls back to admin if no owner."""
mock_session = MagicMock(spec=Session)
mock_account = Mock(spec=Account)
mock_account.id = "admin-account-id"
# Mock admin query (owner returns None)
mock_admin_result = Mock(spec=TenantAccountJoin)
mock_admin_result.account_id = "admin-account-id"
mock_session.execute.return_value.scalar_one_or_none.side_effect = [None, mock_admin_result]
mock_session.get.return_value = mock_account
result = ScheduleService.get_tenant_owner(
session=mock_session,
tenant_id="test-tenant",
)
assert result is not None
assert result.id == "admin-account-id"
@patch("services.trigger.schedule_service.calculate_next_run_at")
def test_update_next_run_at(self, mock_calculate_next_run):
"""Test updating next run time after schedule triggered."""
mock_session = MagicMock(spec=Session)
mock_schedule = Mock(spec=WorkflowSchedulePlan)
mock_schedule.cron_expression = "30 10 * * *"
mock_schedule.timezone = "UTC"
mock_session.get.return_value = mock_schedule
next_time = datetime(2025, 8, 31, 10, 30, 0, tzinfo=UTC)
mock_calculate_next_run.return_value = next_time
result = ScheduleService.update_next_run_at(
session=mock_session,
schedule_id="test-schedule-id",
)
assert result == next_time
assert mock_schedule.next_run_at == next_time
mock_session.flush.assert_called_once()
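# A minimal sketch of calculate_next_run_at, assuming a croniter/pytz based
# implementation (both libraries are assumptions here; support for the "L"
# last-day syntax depends on the croniter version installed).
import pytz
from croniter import croniter

def _next_run_at(cron_expr: str, timezone: str, base_time: datetime | None = None) -> datetime:
    tz = pytz.timezone(timezone)  # raises UnknownTimeZoneError for bad names
    base = base_time or datetime.now(UTC)
    local_base = base.astimezone(tz)
    itr = croniter(cron_expr, local_base)  # raises ValueError for bad expressions
    return itr.get_next(datetime).astimezone(UTC)  # schedule times stored in UTC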
class TestVisualToCron(unittest.TestCase):
"""Test cases for visual configuration to cron conversion."""
def test_visual_to_cron_hourly(self):
"""Test converting hourly visual config to cron."""
visual_config = VisualConfig(on_minute=15)
result = ScheduleService.visual_to_cron("hourly", visual_config)
assert result == "15 * * * *"
def test_visual_to_cron_daily(self):
"""Test converting daily visual config to cron."""
visual_config = VisualConfig(time="2:30 PM")
result = ScheduleService.visual_to_cron("daily", visual_config)
assert result == "30 14 * * *"
def test_visual_to_cron_weekly(self):
"""Test converting weekly visual config to cron."""
visual_config = VisualConfig(
time="10:00 AM",
weekdays=["mon", "wed", "fri"],
)
result = ScheduleService.visual_to_cron("weekly", visual_config)
assert result == "0 10 * * 1,3,5"
def test_visual_to_cron_monthly_with_specific_days(self):
"""Test converting monthly visual config with specific days."""
visual_config = VisualConfig(
time="11:30 AM",
monthly_days=[1, 15],
)
result = ScheduleService.visual_to_cron("monthly", visual_config)
assert result == "30 11 1,15 * *"
def test_visual_to_cron_monthly_with_last_day(self):
"""Test converting monthly visual config with last day using 'L' syntax."""
visual_config = VisualConfig(
time="11:30 AM",
monthly_days=[1, "last"],
)
result = ScheduleService.visual_to_cron("monthly", visual_config)
assert result == "30 11 1,L * *"
def test_visual_to_cron_monthly_only_last_day(self):
"""Test converting monthly visual config with only last day."""
visual_config = VisualConfig(
time="9:00 PM",
monthly_days=["last"],
)
result = ScheduleService.visual_to_cron("monthly", visual_config)
assert result == "0 21 L * *"
def test_visual_to_cron_monthly_with_end_days_and_last(self):
"""Test converting monthly visual config with days 29, 30, 31 and 'last'."""
visual_config = VisualConfig(
time="3:45 PM",
monthly_days=[29, 30, 31, "last"],
)
result = ScheduleService.visual_to_cron("monthly", visual_config)
# Should have 29,30,31,L - the L handles all possible last days
assert result == "45 15 29,30,31,L * *"
def test_visual_to_cron_invalid_frequency(self):
"""Test converting with invalid frequency."""
with pytest.raises(ScheduleConfigError, match="Unsupported frequency: invalid"):
ScheduleService.visual_to_cron("invalid", VisualConfig())
def test_visual_to_cron_weekly_no_weekdays(self):
"""Test converting weekly with no weekdays specified."""
visual_config = VisualConfig(time="10:00 AM")
with pytest.raises(ScheduleConfigError, match="Weekdays are required for weekly schedules"):
ScheduleService.visual_to_cron("weekly", visual_config)
def test_visual_to_cron_hourly_no_minute(self):
"""Test converting hourly with no on_minute specified."""
visual_config = VisualConfig() # on_minute defaults to 0
result = ScheduleService.visual_to_cron("hourly", visual_config)
assert result == "0 * * * *" # Should use default value 0
def test_visual_to_cron_daily_no_time(self):
"""Test converting daily with no time specified."""
visual_config = VisualConfig(time=None)
with pytest.raises(ScheduleConfigError, match="time is required for daily schedules"):
ScheduleService.visual_to_cron("daily", visual_config)
def test_visual_to_cron_weekly_no_time(self):
"""Test converting weekly with no time specified."""
visual_config = VisualConfig(weekdays=["mon"])
visual_config.time = None # Override default
with pytest.raises(ScheduleConfigError, match="time is required for weekly schedules"):
ScheduleService.visual_to_cron("weekly", visual_config)
def test_visual_to_cron_monthly_no_time(self):
"""Test converting monthly with no time specified."""
visual_config = VisualConfig(monthly_days=[1])
visual_config.time = None # Override default
with pytest.raises(ScheduleConfigError, match="time is required for monthly schedules"):
ScheduleService.visual_to_cron("monthly", visual_config)
def test_visual_to_cron_monthly_duplicate_days(self):
"""Test monthly with duplicate days should be deduplicated."""
visual_config = VisualConfig(
time="10:00 AM",
monthly_days=[1, 15, 1, 15, 31], # Duplicates
)
result = ScheduleService.visual_to_cron("monthly", visual_config)
assert result == "0 10 1,15,31 * *" # Should be deduplicated
def test_visual_to_cron_monthly_unsorted_days(self):
"""Test monthly with unsorted days should be sorted."""
visual_config = VisualConfig(
time="2:30 PM",
monthly_days=[20, 5, 15, 1, 10], # Unsorted
)
result = ScheduleService.visual_to_cron("monthly", visual_config)
assert result == "30 14 1,5,10,15,20 * *" # Should be sorted
def test_visual_to_cron_weekly_all_weekdays(self):
"""Test weekly with all weekdays."""
visual_config = VisualConfig(
time="8:00 AM",
weekdays=["sun", "mon", "tue", "wed", "thu", "fri", "sat"],
)
result = ScheduleService.visual_to_cron("weekly", visual_config)
assert result == "0 8 * * 0,1,2,3,4,5,6"
def test_visual_to_cron_hourly_boundary_values(self):
"""Test hourly with boundary minute values."""
# Minimum value
visual_config = VisualConfig(on_minute=0)
result = ScheduleService.visual_to_cron("hourly", visual_config)
assert result == "0 * * * *"
# Maximum value
visual_config = VisualConfig(on_minute=59)
result = ScheduleService.visual_to_cron("hourly", visual_config)
assert result == "59 * * * *"
def test_visual_to_cron_daily_midnight_noon(self):
"""Test daily at special times (midnight and noon)."""
# Midnight
visual_config = VisualConfig(time="12:00 AM")
result = ScheduleService.visual_to_cron("daily", visual_config)
assert result == "0 0 * * *"
# Noon
visual_config = VisualConfig(time="12:00 PM")
result = ScheduleService.visual_to_cron("daily", visual_config)
assert result == "0 12 * * *"
def test_visual_to_cron_monthly_mixed_with_last_and_duplicates(self):
"""Test monthly with mixed days, 'last', and duplicates."""
visual_config = VisualConfig(
time="11:45 PM",
monthly_days=[15, 1, "last", 15, 30, 1, "last"], # Mixed with duplicates
)
result = ScheduleService.visual_to_cron("monthly", visual_config)
assert result == "45 23 1,15,30,L * *" # Deduplicated and sorted with L at end
def test_visual_to_cron_weekly_single_day(self):
"""Test weekly with single weekday."""
visual_config = VisualConfig(
time="6:30 PM",
weekdays=["sun"],
)
result = ScheduleService.visual_to_cron("weekly", visual_config)
assert result == "30 18 * * 0"
def test_visual_to_cron_monthly_all_possible_days(self):
"""Test monthly with all 31 days plus 'last'."""
all_days = list(range(1, 32)) + ["last"]
visual_config = VisualConfig(
time="12:01 AM",
monthly_days=all_days,
)
result = ScheduleService.visual_to_cron("monthly", visual_config)
expected_days = ",".join([str(i) for i in range(1, 32)]) + ",L"
assert result == f"1 0 {expected_days} * *"
def test_visual_to_cron_monthly_no_days(self):
"""Test monthly without any days specified should raise error."""
visual_config = VisualConfig(time="10:00 AM", monthly_days=[])
with pytest.raises(ScheduleConfigError, match="Monthly days are required for monthly schedules"):
ScheduleService.visual_to_cron("monthly", visual_config)
def test_visual_to_cron_weekly_empty_weekdays_list(self):
"""Test weekly with empty weekdays list should raise error."""
visual_config = VisualConfig(time="10:00 AM", weekdays=[])
with pytest.raises(ScheduleConfigError, match="Weekdays are required for weekly schedules"):
ScheduleService.visual_to_cron("weekly", visual_config)
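# A sketch of visual_to_cron consistent with every assertion in the class
# above. The weekday mapping and error strings are taken from the tests;
# the structure and the leading-underscore names are assumptions.
_WEEKDAYS = {"sun": 0, "mon": 1, "tue": 2, "wed": 3, "thu": 4, "fri": 5, "sat": 6}

def _visual_to_cron(frequency: str, cfg: VisualConfig) -> str:
    if frequency == "hourly":
        return f"{cfg.on_minute} * * * *"  # on_minute defaults to 0
    if frequency not in ("daily", "weekly", "monthly"):
        raise ScheduleConfigError(f"Unsupported frequency: {frequency}")
    if not cfg.time:
        raise ScheduleConfigError(f"time is required for {frequency} schedules")
    hour, minute = convert_12h_to_24h(cfg.time)
    if frequency == "daily":
        return f"{minute} {hour} * * *"
    if frequency == "weekly":
        if not cfg.weekdays:
            raise ScheduleConfigError("Weekdays are required for weekly schedules")
        days = ",".join(str(d) for d in sorted(_WEEKDAYS[w] for w in cfg.weekdays))
        return f"{minute} {hour} * * {days}"
    # monthly: numeric days are deduplicated and sorted, "last" becomes a trailing L
    if not cfg.monthly_days:
        raise ScheduleConfigError("Monthly days are required for monthly schedules")
    numeric = sorted({d for d in cfg.monthly_days if d != "last"})
    parts = [str(d) for d in numeric] + (["L"] if "last" in cfg.monthly_days else [])
    return f"{minute} {hour} {','.join(parts)} * *"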
class TestParseTime(unittest.TestCase):
"""Test cases for time parsing function."""
def test_parse_time_am(self):
"""Test parsing AM time."""
hour, minute = convert_12h_to_24h("9:30 AM")
assert hour == 9
assert minute == 30
def test_parse_time_pm(self):
"""Test parsing PM time."""
hour, minute = convert_12h_to_24h("2:45 PM")
assert hour == 14
assert minute == 45
def test_parse_time_noon(self):
"""Test parsing 12:00 PM (noon)."""
hour, minute = convert_12h_to_24h("12:00 PM")
assert hour == 12
assert minute == 0
def test_parse_time_midnight(self):
"""Test parsing 12:00 AM (midnight)."""
hour, minute = convert_12h_to_24h("12:00 AM")
assert hour == 0
assert minute == 0
def test_parse_time_invalid_format(self):
"""Test parsing invalid time format."""
with pytest.raises(ValueError, match="Invalid time format"):
convert_12h_to_24h("25:00")
def test_parse_time_invalid_hour(self):
"""Test parsing invalid hour."""
with pytest.raises(ValueError, match="Invalid hour: 13"):
convert_12h_to_24h("13:00 PM")
def test_parse_time_invalid_minute(self):
"""Test parsing invalid minute."""
with pytest.raises(ValueError, match="Invalid minute: 60"):
convert_12h_to_24h("10:60 AM")
def test_parse_time_empty_string(self):
"""Test parsing empty string."""
with pytest.raises(ValueError, match="Time string cannot be empty"):
convert_12h_to_24h("")
def test_parse_time_invalid_period(self):
"""Test parsing invalid period."""
with pytest.raises(ValueError, match="Invalid period"):
convert_12h_to_24h("10:30 XM")
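# A sketch of the 12-hour parser exercised by TestParseTime; the error
# messages mirror the tests, while the parsing details are assumptions.
def _convert_12h_to_24h(time_str: str) -> tuple[int, int]:
    if not time_str:
        raise ValueError("Time string cannot be empty")
    parts = time_str.strip().split()
    if len(parts) != 2 or ":" not in parts[0]:
        raise ValueError(f"Invalid time format: {time_str}")
    clock, period = parts[0], parts[1].upper()
    if period not in ("AM", "PM"):
        raise ValueError(f"Invalid period: {period}")
    hour, minute = (int(x) for x in clock.split(":", 1))
    if not 1 <= hour <= 12:
        raise ValueError(f"Invalid hour: {hour}")
    if not 0 <= minute <= 59:
        raise ValueError(f"Invalid minute: {minute}")
    if period == "AM":
        return (0 if hour == 12 else hour, minute)    # 12 AM is midnight
    return (12 if hour == 12 else hour + 12, minute)  # 12 PM is noon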
class TestExtractScheduleConfig(unittest.TestCase):
"""Test cases for extracting schedule configuration from workflow."""
def test_extract_schedule_config_with_cron_mode(self):
"""Test extracting schedule config in cron mode."""
workflow = Mock(spec=Workflow)
workflow.graph_dict = {
"nodes": [
{
"id": "schedule-node",
"data": {
"type": "trigger-schedule",
"mode": "cron",
"cron_expression": "0 10 * * *",
"timezone": "America/New_York",
},
}
]
}
config = ScheduleService.extract_schedule_config(workflow)
assert config is not None
assert config.node_id == "schedule-node"
assert config.cron_expression == "0 10 * * *"
assert config.timezone == "America/New_York"
def test_extract_schedule_config_with_visual_mode(self):
"""Test extracting schedule config in visual mode."""
workflow = Mock(spec=Workflow)
workflow.graph_dict = {
"nodes": [
{
"id": "schedule-node",
"data": {
"type": "trigger-schedule",
"mode": "visual",
"frequency": "daily",
"visual_config": {"time": "10:30 AM"},
"timezone": "UTC",
},
}
]
}
config = ScheduleService.extract_schedule_config(workflow)
assert config is not None
assert config.node_id == "schedule-node"
assert config.cron_expression == "30 10 * * *"
assert config.timezone == "UTC"
def test_extract_schedule_config_no_schedule_node(self):
"""Test extracting config when no schedule node exists."""
workflow = Mock(spec=Workflow)
workflow.graph_dict = {
"nodes": [
{
"id": "other-node",
"data": {"type": "llm"},
}
]
}
config = ScheduleService.extract_schedule_config(workflow)
assert config is None
def test_extract_schedule_config_invalid_graph(self):
"""Test extracting config with invalid graph data."""
workflow = Mock(spec=Workflow)
workflow.graph_dict = None
with pytest.raises(ScheduleConfigError, match="Workflow graph is empty"):
ScheduleService.extract_schedule_config(workflow)
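# A sketch of extract_schedule_config matching the four cases above. The
# "trigger-schedule" node type and field names come from the test graphs;
# the traversal details are assumptions.
def _extract_schedule_config(workflow: Workflow) -> ScheduleConfig | None:
    graph = workflow.graph_dict
    if not graph:
        raise ScheduleConfigError("Workflow graph is empty")
    for node in graph.get("nodes", []):
        data = node.get("data", {})
        if data.get("type") != "trigger-schedule":
            continue
        if data.get("mode") == "cron":
            cron = data["cron_expression"]
        else:  # visual mode: derive the cron expression from the visual config
            cron = ScheduleService.visual_to_cron(
                data["frequency"], VisualConfig(**data.get("visual_config", {}))
            )
        return ScheduleConfig(node_id=node["id"], cron_expression=cron, timezone=data["timezone"])
    return None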
class TestScheduleWithTimezone(unittest.TestCase):
"""Test cases for schedule with timezone handling."""
def test_visual_schedule_with_timezone_integration(self):
"""Test complete flow: visual config → cron → execution in different timezones.
This test verifies that when a user in Shanghai sets a schedule for 10:30 AM,
it runs at 10:30 AM Shanghai time, not 10:30 AM UTC.
"""
# User in Shanghai wants to run a task at 10:30 AM local time
visual_config = VisualConfig(
time="10:30 AM", # This is Shanghai time
monthly_days=[1],
)
# Convert to cron expression
cron_expr = ScheduleService.visual_to_cron("monthly", visual_config)
assert cron_expr is not None
assert cron_expr == "30 10 1 * *" # Direct conversion
# Now test execution with Shanghai timezone
shanghai_tz = "Asia/Shanghai"
# Base time: 2025-01-01 00:00:00 UTC (08:00:00 Shanghai)
base_time = datetime(2025, 1, 1, 0, 0, 0, tzinfo=UTC)
next_run = calculate_next_run_at(cron_expr, shanghai_tz, base_time)
assert next_run is not None
# Should run at 10:30 AM Shanghai time on Jan 1
# 10:30 AM Shanghai = 02:30 AM UTC (Shanghai is UTC+8)
assert next_run.year == 2025
assert next_run.month == 1
assert next_run.day == 1
assert next_run.hour == 2 # 02:30 UTC
assert next_run.minute == 30
def test_visual_schedule_different_timezones_same_local_time(self):
"""Test that same visual config in different timezones runs at different UTC times.
This verifies that a schedule set for "9:00 AM" runs at 9 AM local time
regardless of the timezone.
"""
visual_config = VisualConfig(
time="9:00 AM",
weekdays=["mon"],
)
cron_expr = ScheduleService.visual_to_cron("weekly", visual_config)
assert cron_expr is not None
assert cron_expr == "0 9 * * 1"
# Base time: Sunday 2025-01-05 12:00:00 UTC
base_time = datetime(2025, 1, 5, 12, 0, 0, tzinfo=UTC)
# Test New York (UTC-5 in January)
ny_next = calculate_next_run_at(cron_expr, "America/New_York", base_time)
assert ny_next is not None
# Monday 9 AM EST = Monday 14:00 UTC
assert ny_next.day == 6
assert ny_next.hour == 14 # 9 AM EST = 2 PM UTC
# Test Tokyo (UTC+9)
tokyo_next = calculate_next_run_at(cron_expr, "Asia/Tokyo", base_time)
assert tokyo_next is not None
# Monday 9 AM JST = Monday 00:00 UTC
assert tokyo_next.day == 6
        assert tokyo_next.hour == 0  # 9 AM JST = 00:00 UTC
def test_visual_schedule_daily_across_dst_change(self):
"""Test that daily schedules adjust correctly during DST changes.
A schedule set for "10:00 AM" should always run at 10 AM local time,
even when DST changes.
"""
visual_config = VisualConfig(
time="10:00 AM",
)
cron_expr = ScheduleService.visual_to_cron("daily", visual_config)
assert cron_expr is not None
assert cron_expr == "0 10 * * *"
# Test before DST (EST - UTC-5)
winter_base = datetime(2025, 2, 1, 0, 0, 0, tzinfo=UTC)
winter_next = calculate_next_run_at(cron_expr, "America/New_York", winter_base)
assert winter_next is not None
# 10 AM EST = 15:00 UTC
assert winter_next.hour == 15
# Test during DST (EDT - UTC-4)
summer_base = datetime(2025, 6, 1, 0, 0, 0, tzinfo=UTC)
summer_next = calculate_next_run_at(cron_expr, "America/New_York", summer_base)
assert summer_next is not None
# 10 AM EDT = 14:00 UTC
assert summer_next.hour == 14
class TestSyncScheduleFromWorkflow(unittest.TestCase):
"""Test cases for syncing schedule from workflow."""
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
def test_sync_schedule_create_new(self, mock_select, mock_service, mock_db):
"""Test creating new schedule when none exists."""
mock_session = MagicMock()
mock_db.engine = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=None)
Session = MagicMock(return_value=mock_session)
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
mock_session.scalar.return_value = None # No existing plan
# Mock extract_schedule_config to return a ScheduleConfig object
mock_config = Mock(spec=ScheduleConfig)
mock_config.node_id = "start"
mock_config.cron_expression = "30 10 * * *"
mock_config.timezone = "UTC"
mock_service.extract_schedule_config.return_value = mock_config
mock_new_plan = Mock(spec=WorkflowSchedulePlan)
mock_service.create_schedule.return_value = mock_new_plan
workflow = Mock(spec=Workflow)
result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)
assert result == mock_new_plan
mock_service.create_schedule.assert_called_once()
mock_session.commit.assert_called_once()
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
def test_sync_schedule_update_existing(self, mock_select, mock_service, mock_db):
"""Test updating existing schedule."""
mock_session = MagicMock()
mock_db.engine = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=None)
Session = MagicMock(return_value=mock_session)
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
mock_existing_plan.id = "existing-plan-id"
mock_session.scalar.return_value = mock_existing_plan
# Mock extract_schedule_config to return a ScheduleConfig object
mock_config = Mock(spec=ScheduleConfig)
mock_config.node_id = "start"
mock_config.cron_expression = "0 12 * * *"
mock_config.timezone = "America/New_York"
mock_service.extract_schedule_config.return_value = mock_config
mock_updated_plan = Mock(spec=WorkflowSchedulePlan)
mock_service.update_schedule.return_value = mock_updated_plan
workflow = Mock(spec=Workflow)
result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)
assert result == mock_updated_plan
mock_service.update_schedule.assert_called_once()
# Verify the arguments passed to update_schedule
call_args = mock_service.update_schedule.call_args
assert call_args.kwargs["session"] == mock_session
assert call_args.kwargs["schedule_id"] == "existing-plan-id"
updates_obj = call_args.kwargs["updates"]
assert isinstance(updates_obj, SchedulePlanUpdate)
assert updates_obj.node_id == "start"
assert updates_obj.cron_expression == "0 12 * * *"
assert updates_obj.timezone == "America/New_York"
mock_session.commit.assert_called_once()
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
def test_sync_schedule_remove_when_no_config(self, mock_select, mock_service, mock_db):
"""Test removing schedule when no schedule config in workflow."""
mock_session = MagicMock()
mock_db.engine = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=None)
Session = MagicMock(return_value=mock_session)
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
mock_existing_plan.id = "existing-plan-id"
mock_session.scalar.return_value = mock_existing_plan
mock_service.extract_schedule_config.return_value = None # No schedule config
workflow = Mock(spec=Workflow)
result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)
assert result is None
# Now using ScheduleService.delete_schedule instead of session.delete
mock_service.delete_schedule.assert_called_once_with(session=mock_session, schedule_id="existing-plan-id")
mock_session.commit.assert_called_once()
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,629 @@
"""
Comprehensive unit tests for VariableTruncator class based on current implementation.
This test suite covers all functionality of the current VariableTruncator including:
- JSON size calculation for different data types
- String, array, and object truncation logic
- Segment-based truncation interface
- Helper methods for budget-based truncation
- Edge cases and error handling
"""
import functools
import json
import uuid
from typing import Any
from uuid import uuid4
import pytest
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables.segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from services.variable_truncator import (
DummyVariableTruncator,
MaxDepthExceededError,
TruncationResult,
UnknownTypeError,
VariableTruncator,
)
@pytest.fixture
def file() -> File:
return File(
id=str(uuid4()), # Generate new UUID for File.id
tenant_id=str(uuid.uuid4()),
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id=str(uuid.uuid4()),
filename="test_file.txt",
extension=".txt",
mime_type="text/plain",
size=1024,
storage_key="initial_key",
)
_compact_json_dumps = functools.partial(json.dumps, separators=(",", ":"))
class TestCalculateJsonSize:
"""Test calculate_json_size method with different data types."""
@pytest.fixture
def truncator(self):
return VariableTruncator()
def test_string_size_calculation(self):
"""Test JSON size calculation for strings."""
# Simple ASCII string
assert VariableTruncator.calculate_json_size("hello") == 7 # "hello" + 2 quotes
# Empty string
assert VariableTruncator.calculate_json_size("") == 2 # Just quotes
        # Unicode string: sizes count characters, not UTF-8 bytes (2 chars + 2 quotes)
        assert VariableTruncator.calculate_json_size("你好") == 4
def test_number_size_calculation(self, truncator):
"""Test JSON size calculation for numbers."""
assert truncator.calculate_json_size(123) == 3
assert truncator.calculate_json_size(12.34) == 5
assert truncator.calculate_json_size(-456) == 4
assert truncator.calculate_json_size(0) == 1
def test_boolean_size_calculation(self, truncator):
"""Test JSON size calculation for booleans."""
assert truncator.calculate_json_size(True) == 4 # "true"
assert truncator.calculate_json_size(False) == 5 # "false"
def test_null_size_calculation(self, truncator):
"""Test JSON size calculation for None/null."""
assert truncator.calculate_json_size(None) == 4 # "null"
def test_array_size_calculation(self, truncator):
"""Test JSON size calculation for arrays."""
# Empty array
assert truncator.calculate_json_size([]) == 2 # "[]"
# Simple array
simple_array = [1, 2, 3]
# [1,2,3] = 1 + 1 + 1 + 1 + 1 + 2 = 7 (numbers + commas + brackets)
assert truncator.calculate_json_size(simple_array) == 7
# Array with strings
string_array = ["a", "b"]
# ["a","b"] = 3 + 3 + 1 + 2 = 9 (quoted strings + comma + brackets)
assert truncator.calculate_json_size(string_array) == 9
def test_object_size_calculation(self, truncator):
"""Test JSON size calculation for objects."""
# Empty object
assert truncator.calculate_json_size({}) == 2 # "{}"
# Simple object
simple_obj = {"a": 1}
# {"a":1} = 3 + 1 + 1 + 2 = 7 (key + colon + value + brackets)
assert truncator.calculate_json_size(simple_obj) == 7
# Multiple keys
multi_obj = {"a": 1, "b": 2}
# {"a":1,"b":2} = 3 + 1 + 1 + 1 + 3 + 1 + 1 + 2 = 13
assert truncator.calculate_json_size(multi_obj) == 13
def test_nested_structure_size_calculation(self, truncator):
"""Test JSON size calculation for nested structures."""
nested = {"items": [1, 2, {"nested": "value"}]}
size = truncator.calculate_json_size(nested)
assert size > 0 # Should calculate without error
# Verify it matches actual JSON length roughly
actual_json = _compact_json_dumps(nested)
# Should be close but not exact due to UTF-8 encoding considerations
assert abs(size - len(actual_json.encode())) <= 5
def test_calculate_json_size_max_depth_exceeded(self, truncator):
"""Test that calculate_json_size handles deep nesting gracefully."""
# Create deeply nested structure
nested: dict[str, Any] = {"level": 0}
current = nested
for i in range(105): # Create deep nesting
current["next"] = {"level": i + 1}
current = current["next"]
# Should either raise an error or handle gracefully
with pytest.raises(MaxDepthExceededError):
truncator.calculate_json_size(nested)
def test_calculate_json_size_unknown_type(self, truncator):
"""Test that calculate_json_size raises error for unknown types."""
class CustomType:
pass
with pytest.raises(UnknownTypeError):
truncator.calculate_json_size(CustomType())
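# A recursive size calculator consistent with the expectations above: strings
# count characters plus two quotes, numbers use their decimal repr, and
# containers add commas plus brackets. The depth cap of 100 and the exact
# structure are assumptions about the real implementation.
def _json_size(value, depth: int = 0) -> int:
    if depth > 100:
        raise MaxDepthExceededError()
    if value is None:
        return 4  # null
    if isinstance(value, bool):  # bool before int: bool is an int subclass
        return 4 if value else 5  # true / false
    if isinstance(value, (int, float)):
        return len(str(value))
    if isinstance(value, str):
        return len(value) + 2  # surrounding quotes
    if isinstance(value, list):
        inner = sum(_json_size(v, depth + 1) for v in value)
        return inner + max(len(value) - 1, 0) + 2  # commas + brackets
    if isinstance(value, dict):
        inner = sum(len(k) + 3 + _json_size(v, depth + 1) for k, v in value.items())
        return inner + max(len(value) - 1, 0) + 2  # commas + braces
    raise UnknownTypeError(f"unsupported type: {type(value)!r}")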
class TestStringTruncation:
    """Test string truncation functionality."""

    LENGTH_LIMIT = 10
@pytest.fixture
def small_truncator(self):
return VariableTruncator(string_length_limit=10)
def test_short_string_no_truncation(self, small_truncator):
"""Test that short strings are not truncated."""
short_str = "hello"
result = small_truncator._truncate_string(short_str, self.LENGTH_LIMIT)
assert result.value == short_str
assert result.truncated is False
assert result.value_size == VariableTruncator.calculate_json_size(short_str)
def test_long_string_truncation(self, small_truncator: VariableTruncator):
"""Test that long strings are truncated with ellipsis."""
long_str = "this is a very long string that exceeds the limit"
result = small_truncator._truncate_string(long_str, self.LENGTH_LIMIT)
assert result.truncated is True
assert result.value == long_str[:5] + "..."
        assert result.value_size == 10  # 5 kept chars + "..." + 2 JSON quotes = 10
def test_exact_limit_string(self, small_truncator: VariableTruncator):
"""Test string exactly at limit."""
exact_str = "1234567890" # Exactly 10 chars
result = small_truncator._truncate_string(exact_str, self.LENGTH_LIMIT)
assert result.value == "12345..."
assert result.truncated is True
assert result.value_size == 10
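    # A sketch inferred from the two tests above (not a documented contract):
    # the kept prefix is limit - 3 (ellipsis) - 2 (JSON quotes) characters, so
    # the truncated value's JSON size lands exactly on the limit.
    def test_truncation_arithmetic_sketch(self, small_truncator):
        """With a limit of 10, an 11-char string keeps 10 - 3 - 2 = 5 chars."""
        result = small_truncator._truncate_string("abcdefghijk", self.LENGTH_LIMIT)
        assert result.value == "abcde..."
        assert result.value_size == self.LENGTH_LIMIT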
class TestArrayTruncation:
"""Test array truncation functionality."""
@pytest.fixture
def small_truncator(self):
return VariableTruncator(array_element_limit=3, max_size_bytes=100)
def test_small_array_no_truncation(self, small_truncator: VariableTruncator):
"""Test that small arrays are not truncated."""
small_array = [1, 2]
result = small_truncator._truncate_array(small_array, 1000)
assert result.value == small_array
assert result.truncated is False
def test_array_element_limit_truncation(self, small_truncator: VariableTruncator):
"""Test that arrays over element limit are truncated."""
large_array = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
result = small_truncator._truncate_array(large_array, 1000)
assert result.truncated is True
assert result.value == [1, 2, 3]
def test_array_size_budget_truncation(self, small_truncator: VariableTruncator):
"""Test array truncation due to size budget constraints."""
# Create array with strings that will exceed size budget
large_strings = ["very long string " * 5, "another long string " * 5]
result = small_truncator._truncate_array(large_strings, 50)
assert result.truncated is True
# Should have truncated the strings within the array
for item in result.value:
assert isinstance(item, str)
assert VariableTruncator.calculate_json_size(result.value) <= 50
def test_array_with_nested_objects(self, small_truncator):
"""Test array truncation with nested objects."""
nested_array = [
{"name": "item1", "data": "some data"},
{"name": "item2", "data": "more data"},
{"name": "item3", "data": "even more data"},
]
result = small_truncator._truncate_array(nested_array, 30)
assert isinstance(result.value, list)
assert len(result.value) <= 3
for item in result.value:
assert isinstance(item, dict)
class TestObjectTruncation:
"""Test object truncation functionality."""
@pytest.fixture
def small_truncator(self):
return VariableTruncator(max_size_bytes=100)
def test_small_object_no_truncation(self, small_truncator):
"""Test that small objects are not truncated."""
small_obj = {"a": 1, "b": 2}
result = small_truncator._truncate_object(small_obj, 1000)
assert result.value == small_obj
assert result.truncated is False
def test_empty_object_no_truncation(self, small_truncator):
"""Test that empty objects are not truncated."""
empty_obj = {}
result = small_truncator._truncate_object(empty_obj, 100)
assert result.value == empty_obj
assert result.truncated is False
def test_object_value_truncation(self, small_truncator):
"""Test object truncation where values are truncated to fit budget."""
obj_with_long_values = {
"key1": "very long string " * 10,
"key2": "another long string " * 10,
"key3": "third long string " * 10,
}
result = small_truncator._truncate_object(obj_with_long_values, 80)
assert result.truncated is True
assert isinstance(result.value, dict)
assert set(result.value.keys()).issubset(obj_with_long_values.keys())
# Values should be truncated if they exist
for key, value in result.value.items():
if isinstance(value, str):
original_value = obj_with_long_values[key]
                # Each kept value should be no longer than the original
assert len(value) <= len(original_value)
def test_object_key_dropping(self, small_truncator):
"""Test object truncation where keys are dropped due to size constraints."""
large_obj = {f"key{i:02d}": f"value{i}" for i in range(20)}
result = small_truncator._truncate_object(large_obj, 50)
assert result.truncated is True
assert len(result.value) < len(large_obj)
# Should maintain sorted key order
result_keys = list(result.value.keys())
assert result_keys == sorted(result_keys)
def test_object_with_nested_structures(self, small_truncator):
"""Test object truncation with nested arrays and objects."""
nested_obj = {"simple": "value", "array": [1, 2, 3, 4, 5], "nested": {"inner": "data", "more": ["a", "b", "c"]}}
result = small_truncator._truncate_object(nested_obj, 60)
assert isinstance(result.value, dict)
class TestSegmentBasedTruncation:
"""Test the main truncate method that works with Segments."""
@pytest.fixture
def truncator(self):
return VariableTruncator()
@pytest.fixture
def small_truncator(self):
return VariableTruncator(string_length_limit=20, array_element_limit=3, max_size_bytes=200)
def test_integer_segment_no_truncation(self, truncator):
"""Test that integer segments are never truncated."""
segment = IntegerSegment(value=12345)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_boolean_as_integer_segment(self, truncator):
"""Test boolean values in IntegerSegment are converted to int."""
segment = IntegerSegment(value=True)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert isinstance(result.result, IntegerSegment)
assert result.result.value == 1 # True converted to 1
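    # Companion to the case above (an assumed symmetry, not asserted by the
    # original suite): since bool is an int subclass, False should normalize
    # to 0 the same way True normalizes to 1.
    def test_boolean_false_as_integer_segment(self, truncator):
        """Assumed symmetric normalization: False in an IntegerSegment -> 0."""
        segment = IntegerSegment(value=False)
        result = truncator.truncate(segment)
        assert result.truncated is False
        assert isinstance(result.result, IntegerSegment)
        assert result.result.value == 0  # assumed: False normalized to 0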
def test_float_segment_no_truncation(self, truncator):
"""Test that float segments are never truncated."""
segment = FloatSegment(value=123.456)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_none_segment_no_truncation(self, truncator):
"""Test that None segments are never truncated."""
segment = NoneSegment()
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_file_segment_no_truncation(self, truncator, file):
"""Test that file segments are never truncated."""
file_segment = FileSegment(value=file)
result = truncator.truncate(file_segment)
assert result.result == file_segment
assert result.truncated is False
def test_array_file_segment_no_truncation(self, truncator, file):
"""Test that array file segments are never truncated."""
array_file_segment = ArrayFileSegment(value=[file] * 20)
result = truncator.truncate(array_file_segment)
assert result.result == array_file_segment
assert result.truncated is False
def test_string_segment_small_no_truncation(self, truncator):
"""Test small string segments are not truncated."""
segment = StringSegment(value="hello world")
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_string_segment_large_truncation(self, small_truncator):
"""Test large string segments are truncated."""
long_text = "this is a very long string that will definitely exceed the limit"
segment = StringSegment(value=long_text)
result = small_truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, StringSegment)
assert len(result.result.value) < len(long_text)
assert result.result.value.endswith("...")
def test_array_segment_small_no_truncation(self, truncator):
"""Test small array segments are not truncated."""
from factories.variable_factory import build_segment
segment = build_segment([1, 2, 3])
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_array_segment_large_truncation(self, small_truncator):
"""Test large array segments are truncated."""
from factories.variable_factory import build_segment
large_array = list(range(10)) # Exceeds element limit of 3
segment = build_segment(large_array)
result = small_truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, ArraySegment)
assert len(result.result.value) <= 3
def test_object_segment_small_no_truncation(self, truncator):
"""Test small object segments are not truncated."""
segment = ObjectSegment(value={"key": "value"})
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is False
assert result.result == segment
def test_object_segment_large_truncation(self, small_truncator):
"""Test large object segments are truncated."""
large_obj = {f"key{i}": f"very long value {i}" * 5 for i in range(5)}
segment = ObjectSegment(value=large_obj)
result = small_truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, ObjectSegment)
        # The truncated object should be no larger than the original
original_size = small_truncator.calculate_json_size(large_obj)
result_size = small_truncator.calculate_json_size(result.result.value)
assert result_size <= original_size
def test_final_size_fallback_to_json_string(self, small_truncator):
"""Test final fallback when truncated result still exceeds size limit."""
# Create data that will still be large after initial truncation
large_nested_data = {"data": ["very long string " * 5] * 5, "more": {"nested": "content " * 20}}
segment = ObjectSegment(value=large_nested_data)
# Use very small limit to force JSON string fallback
tiny_truncator = VariableTruncator(max_size_bytes=50)
result = tiny_truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, StringSegment)
# Should be JSON string with possible truncation
assert len(result.result.value) <= 53 # 50 + "..." = 53
def test_final_size_fallback_string_truncation(self, small_truncator):
"""Test final fallback for string that still exceeds limit."""
# Create very long string that exceeds string length limit
very_long_string = "x" * 6000 # Exceeds default string_length_limit of 5000
segment = StringSegment(value=very_long_string)
# Use small limit to test string fallback path
tiny_truncator = VariableTruncator(string_length_limit=100, max_size_bytes=50)
result = tiny_truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, StringSegment)
# Should be truncated due to string limit or final size limit
assert len(result.result.value) <= 1000 # Much smaller than original
class TestEdgeCases:
"""Test edge cases and error conditions."""
def test_empty_inputs(self):
"""Test truncator with empty inputs."""
truncator = VariableTruncator()
# Empty string
result = truncator.truncate(StringSegment(value=""))
assert not result.truncated
assert result.result.value == ""
# Empty array
from factories.variable_factory import build_segment
result = truncator.truncate(build_segment([]))
assert not result.truncated
assert result.result.value == []
# Empty object
result = truncator.truncate(ObjectSegment(value={}))
assert not result.truncated
assert result.result.value == {}
    def test_zero_and_negative_limits(self):
        """Test that the truncator rejects zero or too-small limits."""
        # A string limit this small cannot even hold the ellipsis marker
        with pytest.raises(ValueError):
            VariableTruncator(string_length_limit=3)
        with pytest.raises(ValueError):
            VariableTruncator(array_element_limit=0)
        with pytest.raises(ValueError):
            VariableTruncator(max_size_bytes=0)
def test_unicode_and_special_characters(self):
"""Test truncator with unicode and special characters."""
truncator = VariableTruncator(string_length_limit=10)
# Unicode characters
unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀" # Each emoji counts as 1 character
result = truncator.truncate(StringSegment(value=unicode_text))
if len(unicode_text) > 10:
assert result.truncated is True
# Special JSON characters
special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}'
result = truncator.truncate(StringSegment(value=special_chars))
assert isinstance(result.result, StringSegment)
class TestIntegrationScenarios:
"""Test realistic integration scenarios."""
def test_workflow_output_scenario(self):
"""Test truncation of typical workflow output data."""
truncator = VariableTruncator()
workflow_data = {
"result": "success",
"data": {
"users": [
{"id": 1, "name": "Alice", "email": "alice@example.com"},
{"id": 2, "name": "Bob", "email": "bob@example.com"},
]
* 3, # Multiply to make it larger
"metadata": {
"count": 6,
"processing_time": "1.23s",
"details": "x" * 200, # Long string but not too long
},
},
}
segment = ObjectSegment(value=workflow_data)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert isinstance(result.result, (ObjectSegment, StringSegment))
# Should handle complex nested structure appropriately
def test_large_text_processing_scenario(self):
"""Test truncation of large text data."""
truncator = VariableTruncator(string_length_limit=100)
large_text = "This is a very long text document. " * 20 # Make it larger than limit
segment = StringSegment(value=large_text)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.truncated is True
assert isinstance(result.result, StringSegment)
assert len(result.result.value) <= 103 # 100 + "..."
assert result.result.value.endswith("...")
def test_mixed_data_types_scenario(self):
"""Test truncation with mixed data types in complex structure."""
truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)
mixed_data = {
"strings": ["short", "medium length", "very long string " * 3],
"numbers": [1, 2.5, 999999],
"booleans": [True, False, True],
"nested": {
"more_strings": ["nested string " * 2],
"more_numbers": list(range(5)),
"deep": {"level": 3, "content": "deep content " * 3},
},
"nulls": [None, None],
}
segment = ObjectSegment(value=mixed_data)
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
# Should handle all data types appropriately
if result.truncated:
            # Verify the result is no larger than the original
original_size = truncator.calculate_json_size(mixed_data)
if isinstance(result.result, ObjectSegment):
result_size = truncator.calculate_json_size(result.result.value)
assert result_size <= original_size
def test_file_and_array_file_variable_mapping(self, file):
truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)
mapping = {"array_file": [file]}
truncated_mapping, truncated = truncator.truncate_variable_mapping(mapping)
assert truncated is False
assert truncated_mapping == mapping
def test_dummy_variable_truncator_methods():
"""Test DummyVariableTruncator methods work correctly."""
truncator = DummyVariableTruncator()
# Test truncate_variable_mapping
test_data: dict[str, Any] = {
"key1": "value1",
"key2": ["item1", "item2"],
"large_array": list(range(2000)),
}
result, is_truncated = truncator.truncate_variable_mapping(test_data)
assert result == test_data
assert not is_truncated
# Test truncate method
segment = StringSegment(value="test string")
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.result == segment
assert result.truncated is False
segment = ArrayNumberSegment(value=list(range(2000)))
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.result == segment
assert result.truncated is False

View File

@@ -0,0 +1,501 @@
from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from services.trigger.webhook_service import WebhookService
class TestWebhookServiceUnit:
"""Unit tests for WebhookService focusing on business logic without database dependencies."""
def test_extract_webhook_data_json(self):
"""Test webhook data extraction from JSON request."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/json", "Authorization": "Bearer token"},
query_string="version=1&format=json",
json={"message": "hello", "count": 42},
):
webhook_trigger = MagicMock()
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
assert webhook_data["method"] == "POST"
assert webhook_data["headers"]["Authorization"] == "Bearer token"
# Query params are now extracted as raw strings
assert webhook_data["query_params"]["version"] == "1"
assert webhook_data["query_params"]["format"] == "json"
assert webhook_data["body"]["message"] == "hello"
assert webhook_data["body"]["count"] == 42
assert webhook_data["files"] == {}
def test_extract_webhook_data_query_params_remain_strings(self):
"""Query parameters should be extracted as raw strings without automatic conversion."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="GET",
headers={"Content-Type": "application/json"},
query_string="count=42&threshold=3.14&enabled=true&note=text",
):
webhook_trigger = MagicMock()
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
# After refactoring, raw extraction keeps query params as strings
assert webhook_data["query_params"]["count"] == "42"
assert webhook_data["query_params"]["threshold"] == "3.14"
assert webhook_data["query_params"]["enabled"] == "true"
assert webhook_data["query_params"]["note"] == "text"
def test_extract_webhook_data_form_urlencoded(self):
"""Test webhook data extraction from form URL encoded request."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={"username": "test", "password": "secret"},
):
webhook_trigger = MagicMock()
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["username"] == "test"
assert webhook_data["body"]["password"] == "secret"
def test_extract_webhook_data_multipart_with_files(self):
"""Test webhook data extraction from multipart form with files."""
app = Flask(__name__)
# Create a mock file
file_content = b"test file content"
file_storage = FileStorage(stream=BytesIO(file_content), filename="test.txt", content_type="text/plain")
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "multipart/form-data"},
data={"message": "test", "upload": file_storage},
):
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
mock_process_files.return_value = {"upload": "mocked_file_obj"}
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["message"] == "test"
assert webhook_data["files"]["upload"] == "mocked_file_obj"
mock_process_files.assert_called_once()
def test_extract_webhook_data_raw_text(self):
"""Test webhook data extraction from raw text request."""
app = Flask(__name__)
with app.test_request_context(
"/webhook", method="POST", headers={"Content-Type": "text/plain"}, data="raw text content"
):
webhook_trigger = MagicMock()
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["raw"] == "raw text content"
def test_extract_webhook_data_invalid_json(self):
"""Test webhook data extraction with invalid JSON."""
app = Flask(__name__)
with app.test_request_context(
"/webhook", method="POST", headers={"Content-Type": "application/json"}, data="invalid json"
):
webhook_trigger = MagicMock()
with pytest.raises(ValueError, match="Invalid JSON body"):
WebhookService.extract_webhook_data(webhook_trigger)
def test_generate_webhook_response_default(self):
"""Test webhook response generation with default values."""
node_config = {"data": {}}
response_data, status_code = WebhookService.generate_webhook_response(node_config)
assert status_code == 200
assert response_data["status"] == "success"
assert "Webhook processed successfully" in response_data["message"]
def test_generate_webhook_response_custom_json(self):
"""Test webhook response generation with custom JSON response."""
node_config = {"data": {"status_code": 201, "response_body": '{"result": "created", "id": 123}'}}
response_data, status_code = WebhookService.generate_webhook_response(node_config)
assert status_code == 201
assert response_data["result"] == "created"
assert response_data["id"] == 123
def test_generate_webhook_response_custom_text(self):
"""Test webhook response generation with custom text response."""
node_config = {"data": {"status_code": 202, "response_body": "Request accepted for processing"}}
response_data, status_code = WebhookService.generate_webhook_response(node_config)
assert status_code == 202
assert response_data["message"] == "Request accepted for processing"
def test_generate_webhook_response_invalid_json(self):
"""Test webhook response generation with invalid JSON response."""
node_config = {"data": {"status_code": 400, "response_body": '{"invalid": json}'}}
response_data, status_code = WebhookService.generate_webhook_response(node_config)
assert status_code == 400
assert response_data["message"] == '{"invalid": json}'
def test_generate_webhook_response_empty_response_body(self):
"""Test webhook response generation with empty response body."""
node_config = {"data": {"status_code": 204, "response_body": ""}}
response_data, status_code = WebhookService.generate_webhook_response(node_config)
assert status_code == 204
assert response_data["status"] == "success"
assert "Webhook processed successfully" in response_data["message"]
def test_generate_webhook_response_array_json(self):
"""Test webhook response generation with JSON array response."""
node_config = {"data": {"status_code": 200, "response_body": '[{"id": 1}, {"id": 2}]'}}
response_data, status_code = WebhookService.generate_webhook_response(node_config)
assert status_code == 200
assert isinstance(response_data, list)
assert len(response_data) == 2
assert response_data[0]["id"] == 1
assert response_data[1]["id"] == 2
@patch("services.trigger.webhook_service.ToolFileManager")
@patch("services.trigger.webhook_service.file_factory")
def test_process_file_uploads_success(self, mock_file_factory, mock_tool_file_manager):
"""Test successful file upload processing."""
# Mock ToolFileManager
mock_tool_file_instance = MagicMock()
mock_tool_file_manager.return_value = mock_tool_file_instance
# Mock file creation
mock_tool_file = MagicMock()
mock_tool_file.id = "test_file_id"
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
# Mock file factory
mock_file_obj = MagicMock()
mock_file_factory.build_from_mapping.return_value = mock_file_obj
# Create mock files
files = {
"file1": MagicMock(filename="test1.txt", content_type="text/plain"),
"file2": MagicMock(filename="test2.jpg", content_type="image/jpeg"),
}
# Mock file reads
files["file1"].read.return_value = b"content1"
files["file2"].read.return_value = b"content2"
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
result = WebhookService._process_file_uploads(files, webhook_trigger)
assert len(result) == 2
assert "file1" in result
assert "file2" in result
# Verify file processing was called for each file
assert mock_tool_file_manager.call_count == 2
assert mock_file_factory.build_from_mapping.call_count == 2
@patch("services.trigger.webhook_service.ToolFileManager")
@patch("services.trigger.webhook_service.file_factory")
def test_process_file_uploads_with_errors(self, mock_file_factory, mock_tool_file_manager):
"""Test file upload processing with errors."""
# Mock ToolFileManager
mock_tool_file_instance = MagicMock()
mock_tool_file_manager.return_value = mock_tool_file_instance
# Mock file creation
mock_tool_file = MagicMock()
mock_tool_file.id = "test_file_id"
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
# Mock file factory
mock_file_obj = MagicMock()
mock_file_factory.build_from_mapping.return_value = mock_file_obj
# Create mock files, one will fail
files = {
"good_file": MagicMock(filename="test.txt", content_type="text/plain"),
"bad_file": MagicMock(filename="test.bad", content_type="text/plain"),
}
files["good_file"].read.return_value = b"content"
files["bad_file"].read.side_effect = Exception("Read error")
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
result = WebhookService._process_file_uploads(files, webhook_trigger)
# Should process the good file and skip the bad one
assert len(result) == 1
assert "good_file" in result
assert "bad_file" not in result
def test_process_file_uploads_empty_filename(self):
"""Test file upload processing with empty filename."""
files = {
"no_filename": MagicMock(filename="", content_type="text/plain"),
"none_filename": MagicMock(filename=None, content_type="text/plain"),
}
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
result = WebhookService._process_file_uploads(files, webhook_trigger)
# Should skip files without filenames
assert len(result) == 0
def test_validate_json_value_string(self):
"""Test JSON value validation for string type."""
# Valid string
result = WebhookService._validate_json_value("name", "hello", "string")
assert result == "hello"
# Invalid string (number) - should raise ValueError
with pytest.raises(ValueError, match="Expected string, got int"):
WebhookService._validate_json_value("name", 123, "string")
def test_validate_json_value_number(self):
"""Test JSON value validation for number type."""
# Valid integer
result = WebhookService._validate_json_value("count", 42, "number")
assert result == 42
# Valid float
result = WebhookService._validate_json_value("price", 19.99, "number")
assert result == 19.99
# Invalid number (string) - should raise ValueError
with pytest.raises(ValueError, match="Expected number, got str"):
WebhookService._validate_json_value("count", "42", "number")
def test_validate_json_value_bool(self):
"""Test JSON value validation for boolean type."""
# Valid boolean
result = WebhookService._validate_json_value("enabled", True, "boolean")
assert result is True
result = WebhookService._validate_json_value("enabled", False, "boolean")
assert result is False
# Invalid boolean (string) - should raise ValueError
with pytest.raises(ValueError, match="Expected boolean, got str"):
WebhookService._validate_json_value("enabled", "true", "boolean")
def test_validate_json_value_object(self):
"""Test JSON value validation for object type."""
# Valid object
result = WebhookService._validate_json_value("user", {"name": "John", "age": 30}, "object")
assert result == {"name": "John", "age": 30}
# Invalid object (string) - should raise ValueError
with pytest.raises(ValueError, match="Expected object, got str"):
WebhookService._validate_json_value("user", "not_an_object", "object")
def test_validate_json_value_array_string(self):
"""Test JSON value validation for array[string] type."""
# Valid array of strings
result = WebhookService._validate_json_value("tags", ["tag1", "tag2", "tag3"], "array[string]")
assert result == ["tag1", "tag2", "tag3"]
# Invalid - not an array
with pytest.raises(ValueError, match="Expected array of strings, got str"):
WebhookService._validate_json_value("tags", "not_an_array", "array[string]")
# Invalid - array with non-strings
with pytest.raises(ValueError, match="Expected array of strings, got list"):
WebhookService._validate_json_value("tags", ["tag1", 123, "tag3"], "array[string]")
def test_validate_json_value_array_number(self):
"""Test JSON value validation for array[number] type."""
# Valid array of numbers
result = WebhookService._validate_json_value("scores", [1, 2.5, 3, 4.7], "array[number]")
assert result == [1, 2.5, 3, 4.7]
# Invalid - array with non-numbers
with pytest.raises(ValueError, match="Expected array of numbers, got list"):
WebhookService._validate_json_value("scores", [1, "2", 3], "array[number]")
def test_validate_json_value_array_bool(self):
"""Test JSON value validation for array[boolean] type."""
# Valid array of booleans
result = WebhookService._validate_json_value("flags", [True, False, True], "array[boolean]")
assert result == [True, False, True]
# Invalid - array with non-booleans
with pytest.raises(ValueError, match="Expected array of booleans, got list"):
WebhookService._validate_json_value("flags", [True, "false", True], "array[boolean]")
def test_validate_json_value_array_object(self):
"""Test JSON value validation for array[object] type."""
# Valid array of objects
result = WebhookService._validate_json_value("users", [{"name": "John"}, {"name": "Jane"}], "array[object]")
assert result == [{"name": "John"}, {"name": "Jane"}]
# Invalid - array with non-objects
with pytest.raises(ValueError, match="Expected array of objects, got list"):
WebhookService._validate_json_value("users", [{"name": "John"}, "not_object"], "array[object]")
def test_convert_form_value_string(self):
"""Test form value conversion for string type."""
result = WebhookService._convert_form_value("test", "hello", "string")
assert result == "hello"
def test_convert_form_value_number(self):
"""Test form value conversion for number type."""
# Integer
result = WebhookService._convert_form_value("count", "42", "number")
assert result == 42
# Float
result = WebhookService._convert_form_value("price", "19.99", "number")
assert result == 19.99
# Invalid number
with pytest.raises(ValueError, match="Cannot convert 'not_a_number' to number"):
WebhookService._convert_form_value("count", "not_a_number", "number")
def test_convert_form_value_boolean(self):
"""Test form value conversion for boolean type."""
# True values
assert WebhookService._convert_form_value("flag", "true", "boolean") is True
assert WebhookService._convert_form_value("flag", "1", "boolean") is True
assert WebhookService._convert_form_value("flag", "yes", "boolean") is True
# False values
assert WebhookService._convert_form_value("flag", "false", "boolean") is False
assert WebhookService._convert_form_value("flag", "0", "boolean") is False
assert WebhookService._convert_form_value("flag", "no", "boolean") is False
# Invalid boolean
with pytest.raises(ValueError, match="Cannot convert 'maybe' to boolean"):
WebhookService._convert_form_value("flag", "maybe", "boolean")
def test_extract_and_validate_webhook_data_success(self):
"""Test successful unified data extraction and validation."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/json"},
query_string="count=42&enabled=true",
json={"message": "hello", "age": 25},
):
webhook_trigger = MagicMock()
node_config = {
"data": {
"method": "post",
"content_type": "application/json",
"params": [
{"name": "count", "type": "number", "required": True},
{"name": "enabled", "type": "boolean", "required": True},
],
"body": [
{"name": "message", "type": "string", "required": True},
{"name": "age", "type": "number", "required": True},
],
}
}
result = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
# Check that types are correctly converted
assert result["query_params"]["count"] == 42 # Converted to int
assert result["query_params"]["enabled"] is True # Converted to bool
assert result["body"]["message"] == "hello" # Already string
assert result["body"]["age"] == 25 # Already number
def test_extract_and_validate_webhook_data_invalid_json_error(self):
"""Invalid JSON should bubble up as a ValueError with details."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="POST",
headers={"Content-Type": "application/json"},
data='{"invalid": }',
):
webhook_trigger = MagicMock()
node_config = {
"data": {
"method": "post",
"content_type": "application/json",
}
}
with pytest.raises(ValueError, match="Invalid JSON body"):
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
def test_extract_and_validate_webhook_data_validation_error(self):
"""Test unified data extraction with validation error."""
app = Flask(__name__)
with app.test_request_context(
"/webhook",
method="GET", # Wrong method
headers={"Content-Type": "application/json"},
):
webhook_trigger = MagicMock()
node_config = {
"data": {
"method": "post", # Expects POST
"content_type": "application/json",
}
}
with pytest.raises(ValueError, match="HTTP method mismatch"):
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
def test_debug_mode_parameter_handling(self):
"""Test that the debug mode parameter is properly handled in _prepare_webhook_execution."""
from controllers.trigger.webhook import _prepare_webhook_execution
# Mock the WebhookService methods
with (
patch.object(WebhookService, "get_webhook_trigger_and_workflow") as mock_get_trigger,
patch.object(WebhookService, "extract_and_validate_webhook_data") as mock_extract,
):
mock_trigger = MagicMock()
mock_workflow = MagicMock()
mock_config = {"data": {"test": "config"}}
mock_data = {"test": "data"}
mock_get_trigger.return_value = (mock_trigger, mock_workflow, mock_config)
mock_extract.return_value = mock_data
result = _prepare_webhook_execution("test_webhook", is_debug=False)
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
# Reset mock
mock_get_trigger.reset_mock()
result = _prepare_webhook_execution("test_webhook", is_debug=True)
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)

View File

@@ -0,0 +1,200 @@
"""Comprehensive unit tests for WorkflowRunService class.
This test suite covers all pause state management operations including:
- Retrieving pause state for workflow runs
- Saving pause state with file uploads
- Marking paused workflows as resumed
- Error handling and edge cases
- Database transaction management
- Repository-based approach testing
"""
from datetime import datetime
from unittest.mock import MagicMock, create_autospec, patch
import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.enums import WorkflowExecutionStatus
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
from services.workflow_run_service import (
WorkflowRunService,
)
class TestDataFactory:
"""Factory class for creating test data objects."""
@staticmethod
def create_workflow_run_mock(
id: str = "workflow-run-123",
tenant_id: str = "tenant-456",
app_id: str = "app-789",
workflow_id: str = "workflow-101",
status: str | WorkflowExecutionStatus = "paused",
pause_id: str | None = None,
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowRun object."""
mock_run = MagicMock()
mock_run.id = id
mock_run.tenant_id = tenant_id
mock_run.app_id = app_id
mock_run.workflow_id = workflow_id
mock_run.status = status
mock_run.pause_id = pause_id
for key, value in kwargs.items():
setattr(mock_run, key, value)
return mock_run
@staticmethod
def create_workflow_pause_mock(
id: str = "pause-123",
tenant_id: str = "tenant-456",
app_id: str = "app-789",
workflow_id: str = "workflow-101",
workflow_execution_id: str = "workflow-execution-123",
state_file_id: str = "file-456",
resumed_at: datetime | None = None,
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowPauseModel object."""
mock_pause = MagicMock()
mock_pause.id = id
mock_pause.tenant_id = tenant_id
mock_pause.app_id = app_id
mock_pause.workflow_id = workflow_id
mock_pause.workflow_execution_id = workflow_execution_id
mock_pause.state_file_id = state_file_id
mock_pause.resumed_at = resumed_at
for key, value in kwargs.items():
setattr(mock_pause, key, value)
return mock_pause
@staticmethod
def create_upload_file_mock(
id: str = "file-456",
key: str = "upload_files/test/state.json",
name: str = "state.json",
tenant_id: str = "tenant-456",
**kwargs,
) -> MagicMock:
"""Create a mock UploadFile object."""
mock_file = MagicMock()
mock_file.id = id
mock_file.key = key
mock_file.name = name
mock_file.tenant_id = tenant_id
        for attr, value in kwargs.items():
            setattr(mock_file, attr, value)
return mock_file
@staticmethod
def create_pause_entity_mock(
pause_model: MagicMock | None = None,
upload_file: MagicMock | None = None,
) -> _PrivateWorkflowPauseEntity:
"""Create a mock _PrivateWorkflowPauseEntity object."""
if pause_model is None:
pause_model = TestDataFactory.create_workflow_pause_mock()
if upload_file is None:
upload_file = TestDataFactory.create_upload_file_mock()
return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)
class TestWorkflowRunService:
"""Comprehensive unit tests for WorkflowRunService class."""
@pytest.fixture
def mock_session_factory(self):
"""Create a mock session factory with proper session management."""
mock_session = create_autospec(Session)
# Create a mock context manager for the session
mock_session_cm = MagicMock()
mock_session_cm.__enter__ = MagicMock(return_value=mock_session)
mock_session_cm.__exit__ = MagicMock(return_value=None)
# Create a mock context manager for the transaction
mock_transaction_cm = MagicMock()
mock_transaction_cm.__enter__ = MagicMock(return_value=mock_session)
mock_transaction_cm.__exit__ = MagicMock(return_value=None)
mock_session.begin = MagicMock(return_value=mock_transaction_cm)
# Create mock factory that returns the context manager
mock_factory = MagicMock(spec=sessionmaker)
mock_factory.return_value = mock_session_cm
return mock_factory, mock_session
@pytest.fixture
def mock_workflow_run_repository(self):
"""Create a mock APIWorkflowRunRepository."""
mock_repo = create_autospec(APIWorkflowRunRepository)
return mock_repo
@pytest.fixture
def workflow_run_service(self, mock_session_factory, mock_workflow_run_repository):
"""Create WorkflowRunService instance with mocked dependencies."""
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(session_factory)
return service
@pytest.fixture
def workflow_run_service_with_engine(self, mock_session_factory, mock_workflow_run_repository):
"""Create WorkflowRunService instance with Engine input."""
mock_engine = create_autospec(Engine)
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(mock_engine)
return service
# ==================== Initialization Tests ====================
def test_init_with_session_factory(self, mock_session_factory, mock_workflow_run_repository):
"""Test WorkflowRunService initialization with session_factory."""
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(session_factory)
assert service._session_factory == session_factory
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
def test_init_with_engine(self, mock_session_factory, mock_workflow_run_repository):
"""Test WorkflowRunService initialization with Engine (should convert to sessionmaker)."""
mock_engine = create_autospec(Engine)
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
service = WorkflowRunService(mock_engine)
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
assert service._session_factory == session_factory
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
def test_init_with_default_dependencies(self, mock_session_factory):
"""Test WorkflowRunService initialization with default dependencies."""
session_factory, _ = mock_session_factory
service = WorkflowRunService(session_factory)
assert service._session_factory == session_factory

View File

@@ -0,0 +1,253 @@
"""Test cases for MCP tool transformation functionality."""
from unittest.mock import Mock
import pytest
from core.mcp.types import Tool as MCPTool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
@pytest.fixture
def mock_user():
"""Provides a mock user object."""
user = Mock()
user.name = "Test User"
return user
@pytest.fixture
def mock_provider(mock_user):
"""Provides a mock MCPToolProvider with a loaded user."""
provider = Mock(spec=MCPToolProvider)
provider.load_user.return_value = mock_user
return provider
@pytest.fixture
def mock_provider_no_user():
"""Provides a mock MCPToolProvider with no user."""
provider = Mock(spec=MCPToolProvider)
provider.load_user.return_value = None
return provider
@pytest.fixture
def mock_provider_full(mock_user):
"""Provides a fully configured mock MCPToolProvider for detailed tests."""
provider = Mock(spec=MCPToolProvider)
provider.id = "provider-id-123"
provider.server_identifier = "server-identifier-456"
provider.name = "Test MCP Provider"
provider.provider_icon = "icon.png"
provider.authed = True
provider.masked_server_url = "https://*****.com/mcp"
provider.timeout = 30
provider.sse_read_timeout = 300
provider.masked_headers = {"Authorization": "Bearer *****"}
provider.decrypted_headers = {"Authorization": "Bearer secret-token"}
# Mock timestamp
mock_updated_at = Mock()
mock_updated_at.timestamp.return_value = 1234567890
provider.updated_at = mock_updated_at
provider.load_user.return_value = mock_user
return provider
@pytest.fixture
def sample_mcp_tools():
"""Provides sample MCP tools for testing."""
return {
"simple": MCPTool(
name="simple_tool", description="A simple test tool", inputSchema={"type": "object", "properties": {}}
),
"none_desc": MCPTool(name="tool_none_desc", description=None, inputSchema={"type": "object", "properties": {}}),
"complex": MCPTool(
name="complex_tool",
description="A tool with complex parameters",
inputSchema={
"type": "object",
"properties": {
"text": {"type": "string", "description": "Input text"},
"count": {"type": "integer", "description": "Number of items", "minimum": 1, "maximum": 100},
"options": {"type": "array", "items": {"type": "string"}, "description": "List of options"},
},
"required": ["text"],
},
),
}
class TestMCPToolTransform:
"""Test cases for MCP tool transformation methods."""
def test_mcp_tool_to_user_tool_with_none_description(self, mock_provider):
"""Test that mcp_tool_to_user_tool handles None description correctly."""
# Create MCP tools with None description
tools = [
MCPTool(
name="tool1",
                description=None,  # None descriptions previously caused a transform error
inputSchema={"type": "object", "properties": {}},
),
MCPTool(
name="tool2",
description=None,
inputSchema={
"type": "object",
"properties": {"param1": {"type": "string", "description": "A parameter"}},
},
),
]
# Call the method
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider, tools)
# Verify the result
assert len(result) == 2
assert all(isinstance(tool, ToolApiEntity) for tool in result)
# Check first tool
assert result[0].name == "tool1"
assert result[0].author == "Test User"
assert isinstance(result[0].label, I18nObject)
assert result[0].label.en_US == "tool1"
assert isinstance(result[0].description, I18nObject)
assert result[0].description.en_US == "" # Should be empty string, not None
assert result[0].description.zh_Hans == ""
# Check second tool
assert result[1].name == "tool2"
assert result[1].description.en_US == ""
assert result[1].description.zh_Hans == ""
def test_mcp_tool_to_user_tool_with_description(self, mock_provider):
"""Test that mcp_tool_to_user_tool handles normal description correctly."""
# Create MCP tools with description
tools = [
MCPTool(
name="tool_with_desc",
description="This is a test tool that does something useful",
inputSchema={"type": "object", "properties": {}},
)
]
# Call the method
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider, tools)
# Verify the result
assert len(result) == 1
assert isinstance(result[0], ToolApiEntity)
assert result[0].name == "tool_with_desc"
assert result[0].description.en_US == "This is a test tool that does something useful"
assert result[0].description.zh_Hans == "This is a test tool that does something useful"
def test_mcp_tool_to_user_tool_with_no_user(self, mock_provider_no_user):
"""Test that mcp_tool_to_user_tool handles None user correctly."""
# Create MCP tool
tools = [MCPTool(name="tool1", description="Test tool", inputSchema={"type": "object", "properties": {}})]
# Call the method
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider_no_user, tools)
# Verify the result
assert len(result) == 1
assert result[0].author == "Anonymous"
def test_mcp_tool_to_user_tool_with_complex_schema(self, mock_provider, sample_mcp_tools):
"""Test that mcp_tool_to_user_tool correctly converts complex input schemas."""
# Use complex tool from fixtures
tools = [sample_mcp_tools["complex"]]
# Call the method
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider, tools)
# Verify the result
assert len(result) == 1
assert result[0].name == "complex_tool"
assert result[0].parameters is not None
# The actual parameter conversion is handled by convert_mcp_schema_to_parameter
# which should be tested separately
def test_mcp_provider_to_user_provider_for_list(self, mock_provider_full):
"""Test mcp_provider_to_user_provider with for_list=True."""
# Set tools data with null description
mock_provider_full.tools = '[{"name": "tool1", "description": null, "inputSchema": {}}]'
# Mock the to_entity and to_api_response methods
mock_entity = Mock()
mock_entity.to_api_response.return_value = {
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
"is_team_authorization": True,
"server_url": "https://*****.com/mcp",
"provider_icon": "icon.png",
"masked_headers": {"Authorization": "Bearer *****"},
"updated_at": 1234567890,
"labels": [],
"author": "Test User",
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
"icon": "icon.png",
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
"masked_credentials": {},
}
mock_provider_full.to_entity.return_value = mock_entity
# Call the method with for_list=True
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=True)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == "provider-id-123" # Should use provider.id when for_list=True
assert result.name == "Test MCP Provider"
assert result.type == ToolProviderType.MCP
assert result.is_team_authorization is True
assert result.server_url == "https://*****.com/mcp"
assert len(result.tools) == 1
assert result.tools[0].description.en_US == "" # Should handle None description
def test_mcp_provider_to_user_provider_not_for_list(self, mock_provider_full):
"""Test mcp_provider_to_user_provider with for_list=False."""
# Set tools data with description
mock_provider_full.tools = '[{"name": "tool1", "description": "Tool description", "inputSchema": {}}]'
# Mock the to_entity and to_api_response methods
mock_entity = Mock()
mock_entity.to_api_response.return_value = {
"name": "Test MCP Provider",
"type": ToolProviderType.MCP,
"is_team_authorization": True,
"server_url": "https://*****.com/mcp",
"provider_icon": "icon.png",
"masked_headers": {"Authorization": "Bearer *****"},
"updated_at": 1234567890,
"labels": [],
"configuration": {"timeout": "30", "sse_read_timeout": "300"},
"original_headers": {"Authorization": "Bearer secret-token"},
"author": "Test User",
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
"icon": "icon.png",
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
"masked_credentials": {},
}
mock_provider_full.to_entity.return_value = mock_entity
# Call the method with for_list=False
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=False)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False
assert result.server_identifier == "server-identifier-456"
assert result.configuration is not None
assert result.configuration.timeout == 30
assert result.configuration.sse_read_timeout == 300
assert result.original_headers == {"Authorization": "Bearer secret-token"}
assert len(result.tools) == 1
assert result.tools[0].description.en_US == "Tool description"

View File

@@ -0,0 +1,301 @@
from unittest.mock import Mock
from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter
from services.tools.tools_transform_service import ToolTransformService
class TestToolTransformService:
"""Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""
def test_convert_tool_with_parameter_override(self):
"""Test that runtime parameters correctly override base parameters"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
base_param2 = Mock(spec=ToolParameter)
base_param2.name = "param2"
base_param2.form = ToolParameter.ToolParameterForm.FORM
base_param2.type = "string"
base_param2.label = "Base Param 2"
# Create mock runtime parameters that override base parameters
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1" # Different label to verify override
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1, base_param2]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.author == "test_author"
assert result.name == "test_tool"
assert result.parameters is not None
assert len(result.parameters) == 2
# Find the overridden parameter
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
assert overridden_param is not None
assert overridden_param.label == "Runtime Param 1" # Should be runtime version
# Find the non-overridden parameter
original_param = next((p for p in result.parameters if p.name == "param2"), None)
assert original_param is not None
assert original_param.label == "Base Param 2" # Should be base version
def test_convert_tool_with_additional_runtime_parameters(self):
"""Test that additional runtime parameters are added to the final list"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
# Create mock runtime parameters - one that overrides and one that's new
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1"
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "runtime_only"
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
runtime_param2.type = "string"
runtime_param2.label = "Runtime Only Param"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 2
# Check that both parameters are present
param_names = [p.name for p in result.parameters]
assert "param1" in param_names
assert "runtime_only" in param_names
# Verify the overridden parameter has runtime version
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
assert overridden_param is not None
assert overridden_param.label == "Runtime Param 1"
# Verify the new runtime parameter is included
new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
assert new_param is not None
assert new_param.label == "Runtime Only Param"
def test_convert_tool_with_non_form_runtime_parameters(self):
"""Test that non-FORM runtime parameters are not added as new parameters"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
# Create mock runtime parameters with different forms
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1"
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "llm_param"
runtime_param2.form = ToolParameter.ToolParameterForm.LLM
runtime_param2.type = "string"
runtime_param2.label = "LLM Param"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 1 # Only the FORM parameter should be present
# Check that only the FORM parameter is present
param_names = [p.name for p in result.parameters]
assert "param1" in param_names
assert "llm_param" not in param_names
def test_convert_tool_with_empty_parameters(self):
"""Test conversion with empty base and runtime parameters"""
# Create mock tool with no parameters
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = []
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = []
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 0
def test_convert_tool_with_none_parameters(self):
"""Test conversion when base parameters is None"""
# Create mock tool with None parameters
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = None
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = []
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 0
def test_convert_tool_parameter_order_preserved(self):
"""Test that parameter order is preserved correctly"""
# Create mock base parameters in specific order
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
base_param2 = Mock(spec=ToolParameter)
base_param2.name = "param2"
base_param2.form = ToolParameter.ToolParameterForm.FORM
base_param2.type = "string"
base_param2.label = "Base Param 2"
base_param3 = Mock(spec=ToolParameter)
base_param3.name = "param3"
base_param3.form = ToolParameter.ToolParameterForm.FORM
base_param3.type = "string"
base_param3.label = "Base Param 3"
# Create runtime parameter that overrides middle parameter
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "param2"
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
runtime_param2.type = "string"
runtime_param2.label = "Runtime Param 2"
# Create new runtime parameter
runtime_param4 = Mock(spec=ToolParameter)
runtime_param4.name = "param4"
runtime_param4.form = ToolParameter.ToolParameterForm.FORM
runtime_param4.type = "string"
runtime_param4.label = "Runtime Param 4"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 4
# Check that order is maintained: base parameters first, then new runtime parameters
param_names = [p.name for p in result.parameters]
assert param_names == ["param1", "param2", "param3", "param4"]
# Verify that param2 was overridden with runtime version
param2 = result.parameters[1]
assert param2.name == "param2"
assert param2.label == "Runtime Param 2"
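# Taken together, these tests pin down the assumed merge rule inside
# convert_tool_entity_to_api_entity: a runtime FORM parameter with the same
# name replaces its base parameter in place, new FORM runtime parameters are
# appended after the base list, and non-FORM (e.g. LLM) runtime parameters
# are dropped. A hypothetical sketch of that rule:
#
#     merged = {p.name: p for p in base_params}
#     for rp in runtime_params:
#         if rp.form != ToolParameter.ToolParameterForm.FORM:
#             continue
#         merged[rp.name] = rp  # dict keeps base order; new names append at the end
#     result_parameters = list(merged.values())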

View File

@@ -0,0 +1,377 @@
"""Simplified unit tests for DraftVarLoader focusing on core functionality."""
import json
from unittest.mock import Mock, patch
import pytest
from sqlalchemy import Engine
from core.variables.segments import ObjectSegment, StringSegment
from core.variables.types import SegmentType
from models.model import UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from services.workflow_draft_variable_service import DraftVarLoader
class TestDraftVarLoaderSimple:
"""Simplified unit tests for DraftVarLoader core methods."""
@pytest.fixture
def mock_engine(self) -> Engine:
return Mock(spec=Engine)
@pytest.fixture
def draft_var_loader(self, mock_engine):
"""Create DraftVarLoader instance for testing."""
return DraftVarLoader(
engine=mock_engine, app_id="test-app-id", tenant_id="test-tenant-id", fallback_variables=[]
)
def test_load_offloaded_variable_string_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with string type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test.txt"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.STRING
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_variable"
draft_var.description = "test description"
draft_var.get_selector.return_value = ["test-node-id", "test_variable"]
draft_var.variable_file = variable_file
test_content = "This is the full string content"
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_content.encode()
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_variable"
mock_variable.value = StringSegment(value=test_content)
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_variable")
assert variable.id == "draft-var-id"
assert variable.name == "test_variable"
assert variable.description == "test description"
assert variable.value == test_content
# Verify storage was called correctly
mock_storage.load.assert_called_once_with("storage/key/test.txt")
def test_load_offloaded_variable_object_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with object type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test.json"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.OBJECT
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_object"
draft_var.description = "test description"
draft_var.get_selector.return_value = ["test-node-id", "test_object"]
draft_var.variable_file = variable_file
test_object = {"key1": "value1", "key2": 42}
test_json_content = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
mock_segment = ObjectSegment(value=test_object)
mock_build_segment.return_value = mock_segment
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_object"
mock_variable.value = mock_segment
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_object")
assert variable.id == "draft-var-id"
assert variable.name == "test_object"
assert variable.description == "test description"
assert variable.value == test_object
# Verify method calls
mock_storage.load.assert_called_once_with("storage/key/test.json")
mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object)
def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader):
"""Test that assertion error is raised when variable_file is None."""
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.variable_file = None
with pytest.raises(AssertionError):
draft_var_loader._load_offloaded_variable(draft_var)
def test_load_offloaded_variable_missing_upload_file_unit(self, draft_var_loader):
"""Test that assertion error is raised when upload_file is None."""
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.upload_file = None
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.variable_file = variable_file
with pytest.raises(AssertionError):
draft_var_loader._load_offloaded_variable(draft_var)
def test_load_variables_empty_selectors_unit(self, draft_var_loader):
"""Test load_variables returns empty list for empty selectors."""
result = draft_var_loader.load_variables([])
assert result == []
def test_selector_to_tuple_unit(self, draft_var_loader):
"""Test _selector_to_tuple method."""
selector = ["node_id", "var_name", "extra_field"]
result = draft_var_loader._selector_to_tuple(selector)
assert result == ("node_id", "var_name")
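# _selector_to_tuple is assumed to keep only the (node_id, variable_name)
# prefix of a selector and discard any trailing path segments, e.g.:
#
#     def _selector_to_tuple(selector: Sequence[str]) -> tuple[str, str]:
#         return (selector[0], selector[1])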
def test_load_offloaded_variable_number_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with number type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test_number.json"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.NUMBER
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_number"
draft_var.description = "test number description"
draft_var.get_selector.return_value = ["test-node-id", "test_number"]
draft_var.variable_file = variable_file
test_number = 123.45
test_json_content = json.dumps(test_number)
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
from core.variables.segments import FloatSegment
mock_segment = FloatSegment(value=test_number)
mock_build_segment.return_value = mock_segment
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_number"
mock_variable.value = mock_segment
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_number")
assert variable.id == "draft-var-id"
assert variable.name == "test_number"
assert variable.description == "test number description"
# Verify method calls
mock_storage.load.assert_called_once_with("storage/key/test_number.json")
mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number)
def test_load_offloaded_variable_array_type_unit(self, draft_var_loader):
"""Test _load_offloaded_variable with array type - isolated unit test."""
# Create mock objects
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/test_array.json"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.ARRAY_ANY
variable_file.upload_file = upload_file
draft_var = Mock(spec=WorkflowDraftVariable)
draft_var.id = "draft-var-id"
draft_var.node_id = "test-node-id"
draft_var.name = "test_array"
draft_var.description = "test array description"
draft_var.get_selector.return_value = ["test-node-id", "test_array"]
draft_var.variable_file = variable_file
test_array = ["item1", "item2", "item3"]
test_json_content = json.dumps(test_array)
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = test_json_content.encode()
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
from core.variables.segments import ArrayAnySegment
mock_segment = ArrayAnySegment(value=test_array)
mock_build_segment.return_value = mock_segment
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
mock_variable = Mock()
mock_variable.id = "draft-var-id"
mock_variable.name = "test_array"
mock_variable.value = mock_segment
mock_segment_to_variable.return_value = mock_variable
# Execute the method
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
# Verify results
assert selector_tuple == ("test-node-id", "test_array")
assert variable.id == "draft-var-id"
assert variable.name == "test_array"
assert variable.description == "test array description"
# Verify method calls
mock_storage.load.assert_called_once_with("storage/key/test_array.json")
mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array)
def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader):
"""Test load_variables method with mix of regular and offloaded variables."""
selectors = [["node1", "regular_var"], ["node2", "offloaded_var"]]
# Mock regular variable
regular_draft_var = Mock(spec=WorkflowDraftVariable)
regular_draft_var.is_truncated.return_value = False
regular_draft_var.node_id = "node1"
regular_draft_var.name = "regular_var"
regular_draft_var.get_value.return_value = StringSegment(value="regular_value")
regular_draft_var.get_selector.return_value = ["node1", "regular_var"]
regular_draft_var.id = "regular-var-id"
regular_draft_var.description = "regular description"
# Mock offloaded variable
upload_file = Mock(spec=UploadFile)
upload_file.key = "storage/key/offloaded.txt"
variable_file = Mock(spec=WorkflowDraftVariableFile)
variable_file.value_type = SegmentType.STRING
variable_file.upload_file = upload_file
offloaded_draft_var = Mock(spec=WorkflowDraftVariable)
offloaded_draft_var.is_truncated.return_value = True
offloaded_draft_var.node_id = "node2"
offloaded_draft_var.name = "offloaded_var"
offloaded_draft_var.get_selector.return_value = ["node2", "offloaded_var"]
offloaded_draft_var.variable_file = variable_file
offloaded_draft_var.id = "offloaded-var-id"
offloaded_draft_var.description = "offloaded description"
draft_vars = [regular_draft_var, offloaded_draft_var]
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
mock_session = Mock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_service = Mock()
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
with patch(
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
):
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
# Mock regular variable creation
regular_variable = Mock()
regular_variable.selector = ["node1", "regular_var"]
# Mock offloaded variable creation
offloaded_variable = Mock()
offloaded_variable.selector = ["node2", "offloaded_var"]
mock_segment_to_variable.return_value = regular_variable
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
mock_storage.load.return_value = b"offloaded_content"
with patch.object(draft_var_loader, "_load_offloaded_variable") as mock_load_offloaded:
mock_load_offloaded.return_value = (("node2", "offloaded_var"), offloaded_variable)
with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor_cls:
mock_executor = Mock()
mock_executor_cls.return_value.__enter__.return_value = mock_executor
mock_executor.map.return_value = [(("node2", "offloaded_var"), offloaded_variable)]
# Execute the method
result = draft_var_loader.load_variables(selectors)
# Verify results
assert len(result) == 2
# Verify service method was called
mock_service.get_draft_variables_by_selectors.assert_called_once_with(
draft_var_loader._app_id, selectors
)
# Verify offloaded variable loading was called
mock_load_offloaded.assert_called_once_with(offloaded_draft_var)
def test_load_variables_all_offloaded_variables_unit(self, draft_var_loader):
"""Test load_variables method with only offloaded variables."""
selectors = [["node1", "offloaded_var1"], ["node2", "offloaded_var2"]]
# Mock first offloaded variable
offloaded_var1 = Mock(spec=WorkflowDraftVariable)
offloaded_var1.is_truncated.return_value = True
offloaded_var1.node_id = "node1"
offloaded_var1.name = "offloaded_var1"
# Mock second offloaded variable
offloaded_var2 = Mock(spec=WorkflowDraftVariable)
offloaded_var2.is_truncated.return_value = True
offloaded_var2.node_id = "node2"
offloaded_var2.name = "offloaded_var2"
draft_vars = [offloaded_var1, offloaded_var2]
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
mock_session = Mock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_service = Mock()
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
with patch(
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
):
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
with patch("services.workflow_draft_variable_service.ThreadPoolExecutor") as mock_executor_cls:
mock_executor = Mock()
mock_executor_cls.return_value.__enter__.return_value = mock_executor
mock_executor.map.return_value = [
(("node1", "offloaded_var1"), Mock()),
(("node2", "offloaded_var2"), Mock()),
]
# Execute the method
result = draft_var_loader.load_variables(selectors)
# Verify results - since we have only offloaded variables, should have 2 results
assert len(result) == 2
# Verify ThreadPoolExecutor was used
mock_executor_cls.assert_called_once_with(max_workers=10)
mock_executor.map.assert_called_once()
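# The offloaded-variable path exercised above is assumed to work roughly as
# follows: load the raw bytes from storage, decode them according to the
# stored SegmentType, rebuild a segment, then wrap it as a variable. A sketch
# using the names mocked in this module (all assumed):
#
#     raw = storage.load(draft_var.variable_file.upload_file.key)
#     if value_type == SegmentType.STRING:
#         segment = StringSegment(value=raw.decode())
#     else:
#         segment = WorkflowDraftVariable.build_segment_with_type(value_type, json.loads(raw))
#     variable = variable_factory.segment_to_variable(segment=segment, ...)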

View File

@@ -0,0 +1,432 @@
# test for api/services/workflow/workflow_converter.py
import json
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity,
DatasetEntity,
DatasetRetrieveConfigEntity,
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import AppMode
from services.workflow.workflow_converter import WorkflowConverter
@pytest.fixture
def default_variables():
value = [
VariableEntity(
variable="text_input",
label="text-input",
type=VariableEntityType.TEXT_INPUT,
),
VariableEntity(
variable="paragraph",
label="paragraph",
type=VariableEntityType.PARAGRAPH,
),
VariableEntity(
variable="select",
label="select",
type=VariableEntityType.SELECT,
),
]
return value
def test__convert_to_start_node(default_variables):
# act
result = WorkflowConverter()._convert_to_start_node(default_variables)
# assert
assert isinstance(result["data"]["variables"][0]["type"], str)
assert result["data"]["variables"][0]["type"] == "text-input"
assert result["data"]["variables"][0]["variable"] == "text_input"
assert result["data"]["variables"][1]["variable"] == "paragraph"
assert result["data"]["variables"][2]["variable"] == "select"
def test__convert_to_http_request_node_for_chatbot(default_variables):
"""
Test convert to http request nodes for chatbot
:return:
"""
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.CHAT
api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(
tenant_id="tenant_id",
name="api-1",
api_key="encrypted_api_key",
api_endpoint="https://dify.ai",
)
mock_api_based_extension.id = api_based_extension_id
workflow_converter = WorkflowConverter()
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
encrypter.decrypt_token = MagicMock(return_value="api_key")
external_data_variables = [
ExternalDataVariableEntity(
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
)
]
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
)
assert len(nodes) == 2
assert nodes[0]["data"]["type"] == "http-request"
http_request_node = nodes[0]
assert http_request_node["data"]["method"] == "post"
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
assert http_request_node["data"]["authorization"]["type"] == "api-key"
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
assert http_request_node["data"]["body"]["type"] == "json"
body_data = http_request_node["data"]["body"]["data"]
assert body_data
body_data_json = json.loads(body_data)
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
body_params = body_data_json["params"]
assert body_params["app_id"] == app_model.id
assert body_params["tool_variable"] == external_data_variables[0].variable
assert len(body_params["inputs"]) == 3
assert body_params["query"] == "{{#sys.query#}}" # for chatbot
code_node = nodes[1]
assert code_node["data"]["type"] == "code"
def test__convert_to_http_request_node_for_workflow_app(default_variables):
"""
Test convert to http request nodes for workflow app
:return:
"""
app_model = MagicMock()
app_model.id = "app_id"
app_model.tenant_id = "tenant_id"
app_model.mode = AppMode.WORKFLOW
api_based_extension_id = "api_based_extension_id"
mock_api_based_extension = APIBasedExtension(
tenant_id="tenant_id",
name="api-1",
api_key="encrypted_api_key",
api_endpoint="https://dify.ai",
)
mock_api_based_extension.id = api_based_extension_id
workflow_converter = WorkflowConverter()
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
encrypter.decrypt_token = MagicMock(return_value="api_key")
external_data_variables = [
ExternalDataVariableEntity(
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
)
]
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
)
assert len(nodes) == 2
assert nodes[0]["data"]["type"] == "http-request"
http_request_node = nodes[0]
assert http_request_node["data"]["method"] == "post"
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
assert http_request_node["data"]["authorization"]["type"] == "api-key"
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
assert http_request_node["data"]["body"]["type"] == "json"
body_data = http_request_node["data"]["body"]["data"]
assert body_data
body_data_json = json.loads(body_data)
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
body_params = body_data_json["params"]
assert body_params["app_id"] == app_model.id
assert body_params["tool_variable"] == external_data_variables[0].variable
assert len(body_params["inputs"]) == 3
assert body_params["query"] == ""
code_node = nodes[1]
assert code_node["data"]["type"] == "code"
def test__convert_to_knowledge_retrieval_node_for_chatbot():
new_app_mode = AppMode.ADVANCED_CHAT
dataset_config = DatasetEntity(
dataset_ids=["dataset_id_1", "dataset_id_2"],
retrieve_config=DatasetRetrieveConfigEntity(
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=5,
score_threshold=0.8,
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
)
assert node is not None
assert node["data"]["type"] == "knowledge-retrieval"
assert node["data"]["query_variable_selector"] == ["sys", "query"]
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
assert node["data"]["multiple_retrieval_config"] == {
"top_k": dataset_config.retrieve_config.top_k,
"score_threshold": dataset_config.retrieve_config.score_threshold,
"reranking_model": dataset_config.retrieve_config.reranking_model,
}
def test__convert_to_knowledge_retrieval_node_for_workflow_app():
new_app_mode = AppMode.WORKFLOW
dataset_config = DatasetEntity(
dataset_ids=["dataset_id_1", "dataset_id_2"],
retrieve_config=DatasetRetrieveConfigEntity(
query_variable="query",
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=5,
score_threshold=0.8,
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
)
assert node is not None
assert node["data"]["type"] == "knowledge-retrieval"
assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable]
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
assert node["data"]["multiple_retrieval_config"] == {
"top_k": dataset_config.retrieve_config.top_k,
"score_threshold": dataset_config.retrieve_config.score_threshold,
"reranking_model": dataset_config.retrieve_config.reranking_model,
}
def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-4"
model_mode = LLMMode.CHAT
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
template = prompt_template.simple_prompt_template
assert template is not None
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n"
assert llm_node["data"]["context"]["enabled"] is False
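# The replacement loop above mirrors the variable rewrite the converter is
# assumed to perform: "{{text_input}}" becomes "{{#start.text_input#}}", i.e.
# legacy template placeholders are requalified with the start node id.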
def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-3.5-turbo-instruct"
model_mode = LLMMode.COMPLETION
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
template = prompt_template.simple_prompt_template
assert template is not None
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"]["text"] == template + "\n"
assert llm_node["data"]["context"]["enabled"] is False
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-4"
model_mode = LLMMode.CHAT
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
messages=[
AdvancedChatMessageEntity(
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
role=PromptMessageRole.SYSTEM,
),
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
]
),
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
assert isinstance(llm_node["data"]["prompt_template"], list)
assert prompt_template.advanced_chat_prompt_template is not None
assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages)
template = prompt_template.advanced_chat_prompt_template.messages[0].text
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"][0]["text"] == template
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables):
new_app_mode = AppMode.ADVANCED_CHAT
model = "gpt-3.5-turbo-instruct"
model_mode = LLMMode.COMPLETION
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
model_config_mock.stop = []
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ",
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
),
)
llm_node = workflow_converter._convert_to_llm_node(
original_app_mode=AppMode.CHAT,
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
assert isinstance(llm_node["data"]["prompt_template"], dict)
assert prompt_template.advanced_completion_prompt_template is not None
template = prompt_template.advanced_completion_prompt_template.prompt
for v in default_variables:
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"]["text"] == template

View File

@@ -0,0 +1,127 @@
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import Session
from models.model import App
from models.workflow import Workflow
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
@pytest.fixture
def workflow_setup():
mock_session_maker = MagicMock()
workflow_service = WorkflowService(mock_session_maker)
session = MagicMock(spec=Session)
tenant_id = "test-tenant-id"
workflow_id = "test-workflow-id"
# Mock workflow
workflow = MagicMock(spec=Workflow)
workflow.id = workflow_id
workflow.tenant_id = tenant_id
workflow.version = "1.0" # Not a draft
workflow.tool_published = False # Not published as a tool by default
# Mock app
app = MagicMock(spec=App)
app.id = "test-app-id"
app.name = "Test App"
app.workflow_id = None # Not used by an app by default
return {
"workflow_service": workflow_service,
"session": session,
"tenant_id": tenant_id,
"workflow_id": workflow_id,
"workflow": workflow,
"app": app,
}
def test_delete_workflow_success(workflow_setup):
# Setup mocks
# Mock the tool provider query to return None (not published as a tool)
workflow_setup["session"].query.return_value.where.return_value.first.return_value = None
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None]
) # Return workflow first, then None for app
# Call the method
result = workflow_setup["workflow_service"].delete_workflow(
session=workflow_setup["session"],
workflow_id=workflow_setup["workflow_id"],
tenant_id=workflow_setup["tenant_id"],
)
# Verify
assert result is True
workflow_setup["session"].delete.assert_called_once_with(workflow_setup["workflow"])
def test_delete_workflow_draft_error(workflow_setup):
# Setup mocks
workflow_setup["workflow"].version = "draft"
workflow_setup["session"].scalar = MagicMock(return_value=workflow_setup["workflow"])
# Call the method and verify exception
with pytest.raises(DraftWorkflowDeletionError):
workflow_setup["workflow_service"].delete_workflow(
session=workflow_setup["session"],
workflow_id=workflow_setup["workflow_id"],
tenant_id=workflow_setup["tenant_id"],
)
# Verify
workflow_setup["session"].delete.assert_not_called()
def test_delete_workflow_in_use_by_app_error(workflow_setup):
# Setup mocks
workflow_setup["app"].workflow_id = workflow_setup["workflow_id"]
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], workflow_setup["app"]]
) # Return workflow first, then app
# Call the method and verify exception
with pytest.raises(WorkflowInUseError) as excinfo:
workflow_setup["workflow_service"].delete_workflow(
session=workflow_setup["session"],
workflow_id=workflow_setup["workflow_id"],
tenant_id=workflow_setup["tenant_id"],
)
# Verify error message contains app name
assert "Cannot delete workflow that is currently in use by app" in str(excinfo.value)
# Verify
workflow_setup["session"].delete.assert_not_called()
def test_delete_workflow_published_as_tool_error(workflow_setup):
# Setup mocks
from models.tools import WorkflowToolProvider
# Mock the tool provider query
mock_tool_provider = MagicMock(spec=WorkflowToolProvider)
workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None]
) # Return workflow first, then None for app
# Call the method and verify exception
with pytest.raises(WorkflowInUseError) as excinfo:
workflow_setup["workflow_service"].delete_workflow(
session=workflow_setup["session"],
workflow_id=workflow_setup["workflow_id"],
tenant_id=workflow_setup["tenant_id"],
)
# Verify error message
assert "Cannot delete workflow that is published as a tool" in str(excinfo.value)
# Verify
workflow_setup["session"].delete.assert_not_called()

View File

@@ -0,0 +1,477 @@
import dataclasses
import secrets
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session
from core.variables.segments import StringSegment
from core.variables.types import SegmentType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeType
from libs.uuid_utils import uuidv7
from models.account import Account
from models.enums import DraftVariableType
from models.workflow import (
Workflow,
WorkflowDraftVariable,
WorkflowDraftVariableFile,
WorkflowNodeExecutionModel,
is_system_variable_editable,
)
from services.workflow_draft_variable_service import (
DraftVariableSaver,
VariableResetError,
WorkflowDraftVariableService,
)
@pytest.fixture
def mock_engine() -> Engine:
return Mock(spec=Engine)
@pytest.fixture
def mock_session(mock_engine) -> Session:
mock_session = Mock(spec=Session)
mock_session.get_bind.return_value = mock_engine
return mock_session
class TestDraftVariableSaver:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test__should_variable_be_visible(self):
mock_session = MagicMock(spec=Session)
mock_user = Account(name="test", email="test@example.com")
mock_user.id = str(uuid.uuid4())
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
app_id=test_app_id,
node_id="test_node_id",
node_type=NodeType.START,
node_execution_id="test_execution_id",
user=mock_user,
)
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") is False
assert saver._should_variable_be_visible("123", NodeType.START, "output") is True
def test__normalize_variable_for_start_node(self):
@dataclasses.dataclass(frozen=True)
class TestCase:
name: str
input_node_id: str
input_name: str
expected_node_id: str
expected_name: str
_NODE_ID = "1747228642872"
cases = [
TestCase(
name="name with `sys.` prefix should return the system node_id",
input_node_id=_NODE_ID,
input_name="sys.workflow_id",
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
expected_name="workflow_id",
),
TestCase(
name="name without `sys.` prefix should return the original input node_id",
input_node_id=_NODE_ID,
input_name="start_input",
expected_node_id=_NODE_ID,
expected_name="start_input",
),
TestCase(
name="dummy_variable should return the original input node_id",
input_node_id=_NODE_ID,
input_name="__dummy__",
expected_node_id=_NODE_ID,
expected_name="__dummy__",
),
]
mock_session = MagicMock(spec=Session)
mock_user = MagicMock()
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
app_id=test_app_id,
node_id=_NODE_ID,
node_type=NodeType.START,
node_execution_id="test_execution_id",
user=mock_user,
)
for idx, c in enumerate(cases, 1):
fail_msg = f"Test case {c.name} failed, index={idx}"
node_id, name = saver._normalize_variable_for_start_node(c.input_name)
assert node_id == c.expected_node_id, fail_msg
assert name == c.expected_name, fail_msg
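# _normalize_variable_for_start_node is assumed to strip a "sys." prefix and
# redirect such names to the shared system node, leaving everything else on
# the start node, e.g.:
#
#     def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]:
#         if name.startswith("sys."):
#             return SYSTEM_VARIABLE_NODE_ID, name.removeprefix("sys.")
#         return self._node_id, name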
@pytest.fixture
def mock_session(self):
"""Mock SQLAlchemy session."""
mock_session = MagicMock(spec=Session)
mock_engine = MagicMock(spec=Engine)
mock_session.get_bind.return_value = mock_engine
return mock_session
@pytest.fixture
def draft_saver(self, mock_session):
"""Create DraftVariableSaver instance with user context."""
# Create a mock user
mock_user = MagicMock(spec=Account)
mock_user.id = "test-user-id"
mock_user.tenant_id = "test-tenant-id"
return DraftVariableSaver(
session=mock_session,
app_id="test-app-id",
node_id="test-node-id",
node_type=NodeType.LLM,
node_execution_id="test-execution-id",
user=mock_user,
)
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
with patch(
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
) as _mock_try_offload:
_mock_try_offload.return_value = None
mock_segment = StringSegment(value="small value")
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
# Should not have large variable metadata
assert draft_var.file_id is None
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
with patch(
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
) as _mock_try_offload:
mock_segment = StringSegment(value="small value")
mock_draft_var_file = WorkflowDraftVariableFile(
id=str(uuidv7()),
size=1024,
length=10,
value_type=SegmentType.ARRAY_STRING,
upload_file_id=str(uuid.uuid4()),
)
_mock_try_offload.return_value = mock_segment, mock_draft_var_file
draft_var = draft_saver._create_draft_variable(name="large_var", value=mock_segment, visible=True)
# Should carry the offloaded variable file's metadata
assert draft_var.file_id == mock_draft_var_file.id
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
"""Test complete save workflow."""
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
draft_saver.save(outputs=outputs)
# Should batch upsert draft variables
mock_batch_upsert.assert_called_once()
draft_vars = mock_batch_upsert.call_args[0][1]
assert len(draft_vars) == 2
class TestWorkflowDraftVariableService:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def _create_test_workflow(self, app_id: str) -> Workflow:
"""Create a real Workflow instance for testing"""
return Workflow.new(
tenant_id="test_tenant_id",
app_id=app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features="{}",
created_by="test_user_id",
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
def test_reset_conversation_variable(self, mock_session):
"""Test resetting a conversation variable"""
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real conversation variable
test_value = StringSegment(value="test_value")
variable = WorkflowDraftVariable.new_conversation_variable(
app_id=test_app_id, name="test_var", value=test_value, description="Test conversation variable"
)
# Mock the _reset_conv_var method
expected_result = WorkflowDraftVariable.new_conversation_variable(
app_id=test_app_id,
name="test_var",
value=StringSegment(value="reset_value"),
)
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
result = service.reset_variable(workflow, variable)
mock_reset_conv.assert_called_once_with(workflow, variable)
assert result == expected_result
def test_reset_node_variable_with_no_execution_id(self, mock_session):
"""Test resetting a node variable with no execution ID - should delete variable"""
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real node variable with no execution ID
test_value = StringSegment(value="test_value")
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id,
node_id="test_node_id",
name="test_var",
value=test_value,
node_execution_id="exec-id", # Set initially
)
# Manually set to None to simulate the test condition
variable.node_execution_id = None
result = service._reset_node_var_or_sys_var(workflow, variable)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=variable)
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_missing_execution_record(
self,
mock_engine,
mock_session,
monkeypatch,
):
"""Test resetting a node variable when execution record doesn't exist"""
mock_repo_session = Mock(spec=Session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
mock_session_maker.return_value.__exit__.return_value = None
monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
service = WorkflowDraftVariableService(mock_session)
# Mock the repository to return None (no execution record found)
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = None
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real node variable with execution ID
test_value = StringSegment(value="test_value")
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Variable is editable by default from factory method
result = service._reset_node_var_or_sys_var(workflow, variable)
mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=variable)
mock_session.flush.assert_called_once()
assert result is None
def test_reset_node_variable_with_valid_execution_record(
self,
mock_session,
monkeypatch,
):
"""Test resetting a node variable with valid execution record - should restore from execution"""
mock_repo_session = Mock(spec=Session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
mock_session_maker.return_value.__exit__.return_value = None
monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
service = WorkflowDraftVariableService(mock_session)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.load_full_outputs.return_value = {"test_var": "output_value"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create real node variable with execution ID
test_value = StringSegment(value="original_value")
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Variable is editable by default from factory method
# Mock workflow methods
mock_node_config = {"type": "test_node"}
with (
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
):
result = service._reset_node_var_or_sys_var(workflow, variable)
# Verify last_edited_at was reset
assert variable.last_edited_at is None
# Verify session.flush was called
mock_session.flush.assert_called()
# Should return the updated variable
assert result == variable
def test_reset_non_editable_system_variable_raises_error(self, mock_session):
"""Test that resetting a non-editable system variable raises an error"""
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create a non-editable system variable (workflow_id is not editable)
test_value = StringSegment(value="test_workflow_id")
variable = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id,
name="workflow_id", # This is not in _EDITABLE_SYSTEM_VARIABLE
value=test_value,
node_execution_id="exec-id",
editable=False, # Non-editable system variable
)
with pytest.raises(VariableResetError) as exc_info:
service.reset_variable(workflow, variable)
assert "cannot reset system variable" in str(exc_info.value)
assert f"variable_id={variable.id}" in str(exc_info.value)
def test_reset_editable_system_variable_succeeds(self, mock_session):
"""Test that resetting an editable system variable succeeds"""
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create an editable system variable (files is editable)
test_value = StringSegment(value="[]")
variable = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id,
name="files", # This is in _EDITABLE_SYSTEM_VARIABLE
value=test_value,
node_execution_id="exec-id",
editable=True, # Editable system variable
)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.load_full_outputs.return_value = {"sys.files": "[]"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)
# Should succeed and return the variable
assert result == variable
assert variable.last_edited_at is None
mock_session.flush.assert_called()
def test_reset_query_system_variable_succeeds(self, mock_session):
"""Test that resetting query system variable (another editable one) succeeds"""
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
# Create an editable system variable (query is editable)
test_value = StringSegment(value="original query")
variable = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id,
name="query", # This is in _EDITABLE_SYSTEM_VARIABLE
value=test_value,
node_execution_id="exec-id",
editable=True, # Editable system variable
)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.load_full_outputs.return_value = {"sys.query": "reset query"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)
# Should succeed and return the variable
assert result == variable
assert variable.last_edited_at is None
mock_session.flush.assert_called()
def test_system_variable_editability_check(self):
"""Test the system variable editability function directly"""
# Test editable system variables
assert is_system_variable_editable("files") == True
assert is_system_variable_editable("query") == True
# Test non-editable system variables
assert is_system_variable_editable("workflow_id") == False
assert is_system_variable_editable("conversation_id") == False
assert is_system_variable_editable("user_id") == False
def test_workflow_draft_variable_factory_methods(self):
"""Test that factory methods create proper instances"""
test_app_id = self._get_test_app_id()
test_value = StringSegment(value="test_value")
# Test conversation variable factory
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=test_app_id, name="conv_var", value=test_value, description="Test conversation variable"
)
assert conv_var.get_variable_type() == DraftVariableType.CONVERSATION
assert conv_var.editable is True
assert conv_var.node_execution_id is None
# Test system variable factory
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=test_app_id, name="workflow_id", value=test_value, node_execution_id="exec-id", editable=False
)
assert sys_var.get_variable_type() == DraftVariableType.SYS
assert sys_var.editable is False
assert sys_var.node_execution_id == "exec-id"
# Test node variable factory
node_var = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id,
node_id="node-id",
name="node_var",
value=test_value,
node_execution_id="exec-id",
visible=True,
editable=True,
)
assert node_var.get_variable_type() == DraftVariableType.NODE
assert node_var.visible is True
assert node_var.editable is True
assert node_var.node_execution_id == "exec-id"

View File

@@ -0,0 +1,288 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
@pytest.fixture
def repository(self):
mock_session_maker = MagicMock()
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
@pytest.fixture
def mock_execution(self):
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.id = str(uuid4())
execution.tenant_id = "tenant-123"
execution.app_id = "app-456"
execution.workflow_id = "workflow-789"
execution.workflow_run_id = "run-101"
execution.node_id = "node-202"
execution.index = 1
execution.created_at = "2023-01-01T00:00:00Z"
return execution
def test_get_node_last_execution_found(self, repository, mock_execution):
"""Test getting the last execution for a node when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.scalar.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
def test_get_node_last_execution_not_found(self, repository):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_get_executions_by_workflow_run(self, repository, mock_execution):
"""Test getting all executions for a workflow run."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
executions = [mock_execution]
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == executions
mock_session.execute.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.execute.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
def test_get_executions_by_workflow_run_empty(self, repository):
"""Test getting executions for a workflow run when none exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalars.return_value.all.return_value = []
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == []
mock_session.execute.assert_called_once()
def test_get_execution_by_id_found(self, repository, mock_execution):
"""Test getting execution by ID when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_execution_by_id(mock_execution.id)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
def test_get_execution_by_id_not_found(self, repository):
"""Test getting execution by ID when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_execution_by_id("non-existent-id")
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_repository_implements_protocol(self, repository):
"""Test that the repository implements the required protocol methods."""
# Verify all protocol methods are implemented
assert hasattr(repository, "get_node_last_execution")
assert hasattr(repository, "get_executions_by_workflow_run")
assert hasattr(repository, "get_execution_by_id")
# Verify methods are callable
assert callable(repository.get_node_last_execution)
assert callable(repository.get_executions_by_workflow_run)
assert callable(repository.get_execution_by_id)
assert callable(repository.delete_expired_executions)
assert callable(repository.delete_executions_by_app)
assert callable(repository.get_expired_executions_batch)
assert callable(repository.delete_executions_by_ids)
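
    # hasattr/callable checks give a lightweight structural check against the repository
    # protocol (duck typing); they assert the interface exists, not its runtime behavior.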

    def test_delete_expired_executions(self, repository):
"""Test deleting expired executions."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
        # Mock the select query to return a batch smaller than batch_size
        execution_ids = ["id1", "id2"]  # fewer than batch_size, so the loop exits after one pass
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
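
        # A callable side_effect lets a single mock answer the select and the delete differently.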
mock_session.execute.side_effect = mock_execute
before_date = datetime(2023, 1, 1)
# Act
result = repository.delete_expired_executions(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
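
    # For context: the mocks above emulate a select-then-delete batching loop along these
    # lines (a hypothetical sketch; the real repository's statements and model may differ):
    #
    #     total = 0
    #     while True:
    #         ids = session.execute(select(Model.id).where(...).limit(batch_size)).scalars().all()
    #         if not ids:
    #             break
    #         total += session.execute(delete(Model).where(Model.id.in_(ids))).rowcount
    #         session.commit()
    #         if len(ids) < batch_size:
    #             break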

    def test_delete_executions_by_app(self, repository):
"""Test deleting executions by app."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
        # Mock the select query to return a batch smaller than batch_size, so the loop exits after one pass
        execution_ids = ["id1", "id2"]
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
# Act
result = repository.delete_executions_by_app(
tenant_id="tenant-123",
app_id="app-456",
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()

    def test_get_expired_executions_batch(self, repository):
"""Test getting expired executions batch for backup."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Create mock execution objects
mock_execution1 = MagicMock()
mock_execution1.id = "exec-1"
mock_execution2 = MagicMock()
mock_execution2.id = "exec-2"
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
before_date = datetime(2023, 1, 1)
# Act
result = repository.get_expired_executions_batch(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert len(result) == 2
assert result[0].id == "exec-1"
assert result[1].id == "exec-2"
mock_session.execute.assert_called_once()

    def test_delete_executions_by_ids(self, repository):
"""Test deleting executions by IDs."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the delete query result
mock_result = MagicMock()
mock_result.rowcount = 3
mock_session.execute.return_value = mock_result
execution_ids = ["id1", "id2", "id3"]
# Act
result = repository.delete_executions_by_ids(execution_ids)
# Assert
assert result == 3
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()

    def test_delete_executions_by_ids_empty_list(self, repository):
"""Test deleting executions with empty ID list."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Act
result = repository.delete_executions_by_ids([])
# Assert
assert result == 0
        mock_session.execute.assert_not_called()
mock_session.commit.assert_not_called()

View File

@@ -0,0 +1,163 @@
from unittest.mock import MagicMock

import pytest

from models.model import App
from models.workflow import Workflow
from services.workflow_service import WorkflowService


class TestWorkflowService:
@pytest.fixture
def workflow_service(self):
mock_session_maker = MagicMock()
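        # The service presumably keeps this session maker, but every test below passes an
        # explicit session, so the maker itself is never exercised here.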
return WorkflowService(mock_session_maker)

    @pytest.fixture
def mock_app(self):
app = MagicMock(spec=App)
app.id = "app-id-1"
app.workflow_id = "workflow-id-1"
app.tenant_id = "tenant-id-1"
return app

    @pytest.fixture
def mock_workflows(self):
workflows = []
for i in range(5):
workflow = MagicMock(spec=Workflow)
workflow.id = f"workflow-id-{i}"
workflow.app_id = "app-id-1"
workflow.created_at = f"2023-01-0{5 - i}" # Descending date order
workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2"
workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else ""
workflows.append(workflow)
return workflows

    def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
mock_app.workflow_id = None
mock_session = MagicMock()
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
)
assert workflows == []
assert has_more is False
mock_session.scalars.assert_not_called()

    def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
mock_scalar_result.all.return_value = mock_workflows[:3]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
)
assert workflows == mock_workflows[:3]
assert has_more is False
mock_session.scalars.assert_called_once()

    def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Return 4 items when limit is 3, which should indicate has_more=True
mock_scalar_result.all.return_value = mock_workflows[:4]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
)
# Should return only the first 3 items
assert len(workflows) == 3
assert workflows == mock_workflows[:3]
assert has_more is True
# Test page 2
mock_scalar_result.all.return_value = mock_workflows[3:]
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None
)
assert len(workflows) == 2
assert has_more is False
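
    # The has_more behavior asserted above matches the common "limit + 1 probe" pattern,
    # sketched here hypothetically (WorkflowService may implement it differently):
    #
    #     rows = session.scalars(stmt.limit(limit + 1).offset((page - 1) * limit)).all()
    #     return rows[:limit], len(rows) > limit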

    def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Filter workflows for user-id-1
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"]
mock_scalar_result.all.return_value = filtered_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1"
)
assert workflows == filtered_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that the select contains a user filter clause
args = mock_session.scalars.call_args[0][0]
assert "created_by" in str(args)

    def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Filter workflows that have a marked_name
named_workflows = [w for w in mock_workflows if w.marked_name]
mock_scalar_result.all.return_value = named_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True
)
assert workflows == named_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that the select contains a named_only filter clause
args = mock_session.scalars.call_args[0][0]
assert "marked_name !=" in str(args)

    def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
# Combined filter: user-id-1 and has marked_name
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name]
mock_scalar_result.all.return_value = filtered_workflows
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True
)
assert workflows == filtered_workflows
assert has_more is False
mock_session.scalars.assert_called_once()
# Verify that both filters are applied
args = mock_session.scalars.call_args[0][0]
assert "created_by" in str(args)
assert "marked_name !=" in str(args)

    def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app):
mock_session = MagicMock()
mock_scalar_result = MagicMock()
mock_scalar_result.all.return_value = []
mock_session.scalars.return_value = mock_scalar_result
workflows, has_more = workflow_service.get_all_published_workflow(
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
)
assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()