dify
This commit is contained in:
0
dify/api/tests/unit_tests/services/__init__.py
Normal file
0
dify/api/tests/unit_tests/services/__init__.py
Normal file
1
dify/api/tests/unit_tests/services/auth/__init__.py
Normal file
1
dify/api/tests/unit_tests/services/auth/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# API authentication service test module
|
||||
@@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
||||
class ConcreteApiKeyAuth(ApiKeyAuthBase):
|
||||
"""Concrete implementation for testing abstract base class"""
|
||||
|
||||
def validate_credentials(self):
|
||||
return True
|
||||
|
||||
|
||||
class TestApiKeyAuthBase:
|
||||
def test_should_store_credentials_on_init(self):
|
||||
"""Test that credentials are properly stored during initialization"""
|
||||
credentials = {"api_key": "test_key", "auth_type": "bearer"}
|
||||
auth = ConcreteApiKeyAuth(credentials)
|
||||
assert auth.credentials == credentials
|
||||
|
||||
def test_should_not_instantiate_abstract_class(self):
|
||||
"""Test that ApiKeyAuthBase cannot be instantiated directly"""
|
||||
credentials = {"api_key": "test_key"}
|
||||
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
ApiKeyAuthBase(credentials)
|
||||
|
||||
assert "Can't instantiate abstract class" in str(exc_info.value)
|
||||
assert "validate_credentials" in str(exc_info.value)
|
||||
|
||||
def test_should_allow_subclass_implementation(self):
|
||||
"""Test that subclasses can properly implement the abstract method"""
|
||||
credentials = {"api_key": "test_key", "auth_type": "bearer"}
|
||||
auth = ConcreteApiKeyAuth(credentials)
|
||||
|
||||
# Should not raise any exception
|
||||
result = auth.validate_credentials()
|
||||
assert result is True
|
||||
|
||||
def test_should_handle_empty_credentials(self):
|
||||
"""Test initialization with empty credentials"""
|
||||
credentials = {}
|
||||
auth = ConcreteApiKeyAuth(credentials)
|
||||
assert auth.credentials == {}
|
||||
|
||||
def test_should_handle_none_credentials(self):
|
||||
"""Test initialization with None credentials"""
|
||||
credentials = None
|
||||
auth = ConcreteApiKeyAuth(credentials)
|
||||
assert auth.credentials is None
|
||||
@@ -0,0 +1,81 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
from services.auth.auth_type import AuthType
|
||||
|
||||
|
||||
class TestApiKeyAuthFactory:
|
||||
"""Test cases for ApiKeyAuthFactory"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "auth_class_path"),
|
||||
[
|
||||
(AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
|
||||
(AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
|
||||
(AuthType.JINA, "services.auth.jina.jina.JinaAuth"),
|
||||
],
|
||||
)
|
||||
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
|
||||
"""Test getting auth factory for all valid providers"""
|
||||
with patch(auth_class_path) as mock_auth:
|
||||
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
|
||||
assert auth_class == mock_auth
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_provider",
|
||||
[
|
||||
"invalid_provider",
|
||||
"",
|
||||
None,
|
||||
123,
|
||||
"UNSUPPORTED",
|
||||
],
|
||||
)
|
||||
def test_get_apikey_auth_factory_invalid_providers(self, invalid_provider):
|
||||
"""Test getting auth factory with various invalid providers"""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ApiKeyAuthFactory.get_apikey_auth_factory(invalid_provider)
|
||||
assert str(exc_info.value) == "Invalid provider"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("credentials_return_value", "expected_result"),
|
||||
[
|
||||
(True, True),
|
||||
(False, False),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
|
||||
def test_validate_credentials_delegates_to_auth_instance(
|
||||
self, mock_get_factory, credentials_return_value, expected_result
|
||||
):
|
||||
"""Test that validate_credentials delegates to auth instance correctly"""
|
||||
# Arrange
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_auth_instance.validate_credentials.return_value = credentials_return_value
|
||||
mock_auth_class = MagicMock(return_value=mock_auth_instance)
|
||||
mock_get_factory.return_value = mock_auth_class
|
||||
|
||||
# Act
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
|
||||
result = factory.validate_credentials()
|
||||
|
||||
# Assert
|
||||
assert result is expected_result
|
||||
mock_auth_instance.validate_credentials.assert_called_once()
|
||||
|
||||
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
|
||||
def test_validate_credentials_propagates_exceptions(self, mock_get_factory):
|
||||
"""Test that exceptions from auth instance are propagated"""
|
||||
# Arrange
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_auth_instance.validate_credentials.side_effect = Exception("Authentication error")
|
||||
mock_auth_class = MagicMock(return_value=mock_auth_instance)
|
||||
mock_get_factory.return_value = mock_auth_class
|
||||
|
||||
# Act & Assert
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
factory.validate_credentials()
|
||||
assert str(exc_info.value) == "Authentication error"
|
||||
@@ -0,0 +1,387 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
|
||||
class TestApiKeyAuthService:
|
||||
"""API key authentication service security tests"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test fixtures"""
|
||||
self.tenant_id = "test_tenant_123"
|
||||
self.category = "search"
|
||||
self.provider = "google"
|
||||
self.binding_id = "binding_123"
|
||||
self.mock_credentials = {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}}
|
||||
self.mock_args = {"category": self.category, "provider": self.provider, "credentials": self.mock_credentials}
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_success(self, mock_session):
|
||||
"""Test get provider auth list - success scenario"""
|
||||
# Mock database query result
|
||||
mock_binding = Mock()
|
||||
mock_binding.tenant_id = self.tenant_id
|
||||
mock_binding.provider = self.provider
|
||||
mock_binding.disabled = False
|
||||
|
||||
mock_session.scalars.return_value.all.return_value = [mock_binding]
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].tenant_id == self.tenant_id
|
||||
assert mock_session.scalars.call_count == 1
|
||||
select_arg = mock_session.scalars.call_args[0][0]
|
||||
assert "data_source_api_key_auth_binding" in str(select_arg).lower()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_empty(self, mock_session):
|
||||
"""Test get provider auth list - empty result"""
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_filters_disabled(self, mock_session):
|
||||
"""Test get provider auth list - filters disabled items"""
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
select_stmt = mock_session.scalars.call_args[0][0]
|
||||
where_clauses = list(getattr(select_stmt, "_where_criteria", []) or [])
|
||||
# Ensure both tenant filter and disabled filter exist
|
||||
where_strs = [str(c).lower() for c in where_clauses]
|
||||
assert any("tenant_id" in s for s in where_strs)
|
||||
assert any("disabled" in s for s in where_strs)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
def test_create_provider_auth_success(self, mock_encrypter, mock_factory, mock_session):
|
||||
"""Test create provider auth - success scenario"""
|
||||
# Mock successful auth validation
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = True
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
|
||||
# Mock encryption
|
||||
encrypted_key = "encrypted_test_key_123"
|
||||
mock_encrypter.encrypt_token.return_value = encrypted_key
|
||||
|
||||
# Mock database operations
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
|
||||
|
||||
# Verify factory class calls
|
||||
mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
|
||||
mock_auth_instance.validate_credentials.assert_called_once()
|
||||
|
||||
# Verify encryption calls
|
||||
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, "test_secret_key_123")
|
||||
|
||||
# Verify database operations
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
def test_create_provider_auth_validation_failed(self, mock_factory, mock_session):
|
||||
"""Test create provider auth - validation failed"""
|
||||
# Mock failed auth validation
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = False
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
|
||||
|
||||
# Verify no database operations when validation fails
|
||||
mock_session.add.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
def test_create_provider_auth_encrypts_api_key(self, mock_encrypter, mock_factory, mock_session):
|
||||
"""Test create provider auth - ensures API key is encrypted"""
|
||||
# Mock successful auth validation
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = True
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
|
||||
# Mock encryption
|
||||
encrypted_key = "encrypted_test_key_123"
|
||||
mock_encrypter.encrypt_token.return_value = encrypted_key
|
||||
|
||||
# Mock database operations
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
args_copy = self.mock_args.copy()
|
||||
original_key = args_copy["credentials"]["config"]["api_key"]
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
|
||||
|
||||
# Verify original key is replaced with encrypted key
|
||||
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
|
||||
assert args_copy["credentials"]["config"]["api_key"] != original_key
|
||||
|
||||
# Verify encryption function is called correctly
|
||||
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_auth_credentials_success(self, mock_session):
|
||||
"""Test get auth credentials - success scenario"""
|
||||
# Mock database query result
|
||||
mock_binding = Mock()
|
||||
mock_binding.credentials = json.dumps(self.mock_credentials)
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
|
||||
|
||||
assert result == self.mock_credentials
|
||||
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_auth_credentials_not_found(self, mock_session):
|
||||
"""Test get auth credentials - not found"""
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_auth_credentials_filters_correctly(self, mock_session):
|
||||
"""Test get auth credentials - applies correct filters"""
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
|
||||
|
||||
# Verify where conditions are correct
|
||||
where_call = mock_session.query.return_value.where.call_args[0]
|
||||
assert len(where_call) == 4 # tenant_id, category, provider, disabled
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_auth_credentials_json_parsing(self, mock_session):
|
||||
"""Test get auth credentials - JSON parsing"""
|
||||
# Mock credentials with special characters
|
||||
special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}}
|
||||
|
||||
mock_binding = Mock()
|
||||
mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False)
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
|
||||
|
||||
assert result == special_credentials
|
||||
assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_delete_provider_auth_success(self, mock_session):
|
||||
"""Test delete provider auth - success scenario"""
|
||||
# Mock database query result
|
||||
mock_binding = Mock()
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
|
||||
|
||||
# Verify delete operations
|
||||
mock_session.delete.assert_called_once_with(mock_binding)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_delete_provider_auth_not_found(self, mock_session):
|
||||
"""Test delete provider auth - not found"""
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
|
||||
|
||||
# Verify no delete operations when not found
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_delete_provider_auth_filters_by_tenant(self, mock_session):
|
||||
"""Test delete provider auth - filters by tenant"""
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
|
||||
|
||||
# Verify where conditions include tenant_id and binding_id
|
||||
where_call = mock_session.query.return_value.where.call_args[0]
|
||||
assert len(where_call) == 2
|
||||
|
||||
def test_validate_api_key_auth_args_success(self):
|
||||
"""Test API key auth args validation - success scenario"""
|
||||
# Should not raise any exception
|
||||
ApiKeyAuthService.validate_api_key_auth_args(self.mock_args)
|
||||
|
||||
def test_validate_api_key_auth_args_missing_category(self):
|
||||
"""Test API key auth args validation - missing category"""
|
||||
args = self.mock_args.copy()
|
||||
del args["category"]
|
||||
|
||||
with pytest.raises(ValueError, match="category is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_empty_category(self):
|
||||
"""Test API key auth args validation - empty category"""
|
||||
args = self.mock_args.copy()
|
||||
args["category"] = ""
|
||||
|
||||
with pytest.raises(ValueError, match="category is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_missing_provider(self):
|
||||
"""Test API key auth args validation - missing provider"""
|
||||
args = self.mock_args.copy()
|
||||
del args["provider"]
|
||||
|
||||
with pytest.raises(ValueError, match="provider is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_empty_provider(self):
|
||||
"""Test API key auth args validation - empty provider"""
|
||||
args = self.mock_args.copy()
|
||||
args["provider"] = ""
|
||||
|
||||
with pytest.raises(ValueError, match="provider is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_missing_credentials(self):
|
||||
"""Test API key auth args validation - missing credentials"""
|
||||
args = self.mock_args.copy()
|
||||
del args["credentials"]
|
||||
|
||||
with pytest.raises(ValueError, match="credentials is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_empty_credentials(self):
|
||||
"""Test API key auth args validation - empty credentials"""
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"] = None
|
||||
|
||||
with pytest.raises(ValueError, match="credentials is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_invalid_credentials_type(self):
|
||||
"""Test API key auth args validation - invalid credentials type"""
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"] = "not_a_dict"
|
||||
|
||||
with pytest.raises(ValueError, match="credentials must be a dictionary"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_missing_auth_type(self):
|
||||
"""Test API key auth args validation - missing auth_type"""
|
||||
args = self.mock_args.copy()
|
||||
del args["credentials"]["auth_type"]
|
||||
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
def test_validate_api_key_auth_args_empty_auth_type(self):
|
||||
"""Test API key auth args validation - empty auth_type"""
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"]["auth_type"] = ""
|
||||
|
||||
with pytest.raises(ValueError, match="auth_type is required"):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"malicious_input",
|
||||
[
|
||||
"<script>alert('xss')</script>",
|
||||
"'; DROP TABLE users; --",
|
||||
"../../../etc/passwd",
|
||||
"\\x00\\x00", # null bytes
|
||||
"A" * 10000, # very long input
|
||||
],
|
||||
)
|
||||
def test_validate_api_key_auth_args_malicious_input(self, malicious_input):
|
||||
"""Test API key auth args validation - malicious input"""
|
||||
args = self.mock_args.copy()
|
||||
args["category"] = malicious_input
|
||||
|
||||
# Verify parameter validator doesn't crash on malicious input
|
||||
# Should validate normally rather than raising security-related exceptions
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
def test_create_provider_auth_database_error_handling(self, mock_encrypter, mock_factory, mock_session):
|
||||
"""Test create provider auth - database error handling"""
|
||||
# Mock successful auth validation
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = True
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
|
||||
# Mock encryption
|
||||
mock_encrypter.encrypt_token.return_value = "encrypted_key"
|
||||
|
||||
# Mock database error
|
||||
mock_session.commit.side_effect = Exception("Database error")
|
||||
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_auth_credentials_invalid_json(self, mock_session):
|
||||
"""Test get auth credentials - invalid JSON"""
|
||||
# Mock database returning invalid JSON
|
||||
mock_binding = Mock()
|
||||
mock_binding.credentials = "invalid json content"
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_binding
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
def test_create_provider_auth_factory_exception(self, mock_factory, mock_session):
|
||||
"""Test create provider auth - factory exception"""
|
||||
# Mock factory raising exception
|
||||
mock_factory.side_effect = Exception("Factory error")
|
||||
|
||||
with pytest.raises(Exception, match="Factory error"):
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session):
|
||||
"""Test create provider auth - encryption exception"""
|
||||
# Mock successful auth validation
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = True
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
|
||||
# Mock encryption exception
|
||||
mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")
|
||||
|
||||
with pytest.raises(Exception, match="Encryption error"):
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
|
||||
|
||||
def test_validate_api_key_auth_args_none_input(self):
|
||||
"""Test API key auth args validation - None input"""
|
||||
with pytest.raises(TypeError):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(None)
|
||||
|
||||
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
|
||||
"""Test API key auth args validation - dict credentials with list auth_type"""
|
||||
args = self.mock_args.copy()
|
||||
args["credentials"]["auth_type"] = ["api_key"]
|
||||
|
||||
# Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
|
||||
# So this should not raise exception, this test should pass
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
231
dify/api/tests/unit_tests/services/auth/test_auth_integration.py
Normal file
231
dify/api/tests/unit_tests/services/auth/test_auth_integration.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
API Key Authentication System Integration Tests
|
||||
"""
|
||||
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
from services.auth.auth_type import AuthType
|
||||
|
||||
|
||||
class TestAuthIntegration:
|
||||
def setup_method(self):
|
||||
self.tenant_id_1 = "tenant_123"
|
||||
self.tenant_id_2 = "tenant_456" # For multi-tenant isolation testing
|
||||
self.category = "search"
|
||||
|
||||
# Realistic authentication configurations
|
||||
self.firecrawl_credentials = {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}
|
||||
self.jina_credentials = {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}}
|
||||
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
|
||||
"""Test complete authentication flow: request → validation → encryption → storage"""
|
||||
mock_http.return_value = self._create_success_response()
|
||||
mock_encrypt.return_value = "encrypted_fc_test_key_123"
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
|
||||
|
||||
mock_http.assert_called_once()
|
||||
call_args = mock_http.call_args
|
||||
assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0]
|
||||
assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123"
|
||||
|
||||
mock_encrypt.assert_called_once_with(self.tenant_id_1, "fc_test_key_123")
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_cross_component_integration(self, mock_http):
|
||||
"""Test factory → provider → HTTP call integration"""
|
||||
mock_http.return_value = self._create_success_response()
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
result = factory.validate_credentials()
|
||||
|
||||
assert result is True
|
||||
mock_http.assert_called_once()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_multi_tenant_isolation(self, mock_session):
|
||||
"""Ensure complete tenant data isolation"""
|
||||
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
|
||||
|
||||
mock_session.scalars.return_value.all.return_value = [tenant1_binding]
|
||||
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
|
||||
|
||||
mock_session.scalars.return_value.all.return_value = [tenant2_binding]
|
||||
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
|
||||
|
||||
assert len(result1) == 1
|
||||
assert result1[0].tenant_id == self.tenant_id_1
|
||||
assert len(result2) == 1
|
||||
assert result2[0].tenant_id == self.tenant_id_2
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_cross_tenant_access_prevention(self, mock_session):
|
||||
"""Test prevention of cross-tenant credential access"""
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_sensitive_data_protection(self):
|
||||
"""Ensure API keys don't leak to logs"""
|
||||
credentials_with_secrets = {
|
||||
"auth_type": "bearer",
|
||||
"config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"},
|
||||
}
|
||||
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets)
|
||||
factory_str = str(factory)
|
||||
|
||||
assert "super_secret_key_do_not_log" not in factory_str
|
||||
assert "another_secret" not in factory_str
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
|
||||
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
|
||||
"""Test concurrent authentication creation safety"""
|
||||
mock_http.return_value = self._create_success_response()
|
||||
mock_encrypt.return_value = "encrypted_key"
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
|
||||
|
||||
results = []
|
||||
exceptions = []
|
||||
|
||||
def create_auth():
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
|
||||
results.append("success")
|
||||
except Exception as e:
|
||||
exceptions.append(e)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(create_auth) for _ in range(5)]
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
assert len(results) == 5
|
||||
assert len(exceptions) == 0
|
||||
assert mock_session.add.call_count == 5
|
||||
assert mock_session.commit.call_count == 5
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_input",
|
||||
[
|
||||
None, # Null input
|
||||
{}, # Empty dictionary - missing required fields
|
||||
{"auth_type": "bearer"}, # Missing config section
|
||||
{"auth_type": "bearer", "config": {}}, # Missing api_key
|
||||
],
|
||||
)
|
||||
def test_invalid_input_boundary(self, invalid_input):
|
||||
"""Test boundary handling for invalid inputs"""
|
||||
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
|
||||
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_http_error_handling(self, mock_http):
|
||||
"""Test proper HTTP error handling"""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = '{"error": "Unauthorized"}'
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
|
||||
mock_http.return_value = mock_response
|
||||
|
||||
# PT012: Split into single statement for pytest.raises
|
||||
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
with pytest.raises((httpx.HTTPError, Exception)):
|
||||
factory.validate_credentials()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_network_failure_recovery(self, mock_http, mock_session):
|
||||
"""Test system recovery from network failures"""
|
||||
mock_http.side_effect = httpx.RequestError("Network timeout")
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
|
||||
|
||||
with pytest.raises(httpx.RequestError):
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "credentials"),
|
||||
[
|
||||
(AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}),
|
||||
(AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}),
|
||||
(AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}),
|
||||
],
|
||||
)
|
||||
def test_all_providers_factory_creation(self, provider, credentials):
|
||||
"""Test factory creation for all supported providers"""
|
||||
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
|
||||
assert auth_class is not None
|
||||
|
||||
factory = ApiKeyAuthFactory(provider, credentials)
|
||||
assert factory.auth is not None
|
||||
|
||||
def _create_success_response(self, status_code=200):
|
||||
"""Create successful HTTP response mock"""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.json.return_value = {"status": "success"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
return mock_response
|
||||
|
||||
def _create_mock_binding(self, tenant_id: str, provider: str, credentials: dict) -> Mock:
|
||||
"""Create realistic database binding mock"""
|
||||
mock_binding = Mock()
|
||||
mock_binding.id = f"binding_{provider}_{tenant_id}"
|
||||
mock_binding.tenant_id = tenant_id
|
||||
mock_binding.category = self.category
|
||||
mock_binding.provider = provider
|
||||
mock_binding.credentials = json.dumps(credentials, ensure_ascii=False)
|
||||
mock_binding.disabled = False
|
||||
|
||||
mock_binding.created_at = Mock()
|
||||
mock_binding.created_at.timestamp.return_value = 1640995200
|
||||
mock_binding.updated_at = Mock()
|
||||
mock_binding.updated_at.timestamp.return_value = 1640995200
|
||||
|
||||
return mock_binding
|
||||
|
||||
def test_integration_coverage_validation(self):
|
||||
"""Validate integration test coverage meets quality standards"""
|
||||
core_scenarios = {
|
||||
"business_logic": ["end_to_end_auth_flow", "cross_component_integration"],
|
||||
"security": ["multi_tenant_isolation", "cross_tenant_access_prevention", "sensitive_data_protection"],
|
||||
"reliability": ["concurrent_creation_safety", "network_failure_recovery"],
|
||||
"compatibility": ["all_providers_factory_creation"],
|
||||
"boundaries": ["invalid_input_boundary", "http_error_handling"],
|
||||
}
|
||||
|
||||
total_scenarios = sum(len(scenarios) for scenarios in core_scenarios.values())
|
||||
assert total_scenarios >= 10
|
||||
|
||||
security_tests = core_scenarios["security"]
|
||||
assert "multi_tenant_isolation" in security_tests
|
||||
assert "sensitive_data_protection" in security_tests
|
||||
assert True
|
||||
150
dify/api/tests/unit_tests/services/auth/test_auth_type.py
Normal file
150
dify/api/tests/unit_tests/services/auth/test_auth_type.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import pytest
|
||||
|
||||
from services.auth.auth_type import AuthType
|
||||
|
||||
|
||||
class TestAuthType:
|
||||
"""Test cases for AuthType enum"""
|
||||
|
||||
def test_auth_type_is_str_enum(self):
|
||||
"""Test that AuthType is properly a StrEnum"""
|
||||
assert issubclass(AuthType, str)
|
||||
assert hasattr(AuthType, "__members__")
|
||||
|
||||
def test_auth_type_has_expected_values(self):
|
||||
"""Test that all expected auth types exist with correct values"""
|
||||
expected_values = {
|
||||
"FIRECRAWL": "firecrawl",
|
||||
"WATERCRAWL": "watercrawl",
|
||||
"JINA": "jinareader",
|
||||
}
|
||||
|
||||
# Verify all expected members exist
|
||||
for member_name, expected_value in expected_values.items():
|
||||
assert hasattr(AuthType, member_name)
|
||||
assert getattr(AuthType, member_name).value == expected_value
|
||||
|
||||
# Verify no extra members exist
|
||||
assert len(AuthType) == len(expected_values)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("auth_type", "expected_string"),
|
||||
[
|
||||
(AuthType.FIRECRAWL, "firecrawl"),
|
||||
(AuthType.WATERCRAWL, "watercrawl"),
|
||||
(AuthType.JINA, "jinareader"),
|
||||
],
|
||||
)
|
||||
def test_auth_type_string_representation(self, auth_type, expected_string):
|
||||
"""Test string representation of auth types"""
|
||||
assert str(auth_type) == expected_string
|
||||
assert auth_type.value == expected_string
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("auth_type", "compare_value", "expected_result"),
|
||||
[
|
||||
(AuthType.FIRECRAWL, "firecrawl", True),
|
||||
(AuthType.WATERCRAWL, "watercrawl", True),
|
||||
(AuthType.JINA, "jinareader", True),
|
||||
(AuthType.FIRECRAWL, "FIRECRAWL", False), # Case sensitive
|
||||
(AuthType.FIRECRAWL, "watercrawl", False),
|
||||
(AuthType.JINA, "jina", False), # Full value mismatch
|
||||
],
|
||||
)
|
||||
def test_auth_type_comparison(self, auth_type, compare_value, expected_result):
|
||||
"""Test auth type comparison with strings"""
|
||||
assert (auth_type == compare_value) is expected_result
|
||||
|
||||
def test_auth_type_iteration(self):
|
||||
"""Test that AuthType can be iterated over"""
|
||||
auth_types = list(AuthType)
|
||||
assert len(auth_types) == 3
|
||||
assert AuthType.FIRECRAWL in auth_types
|
||||
assert AuthType.WATERCRAWL in auth_types
|
||||
assert AuthType.JINA in auth_types
|
||||
|
||||
def test_auth_type_membership(self):
|
||||
"""Test membership checking for AuthType"""
|
||||
assert "firecrawl" in [auth.value for auth in AuthType]
|
||||
assert "watercrawl" in [auth.value for auth in AuthType]
|
||||
assert "jinareader" in [auth.value for auth in AuthType]
|
||||
assert "invalid" not in [auth.value for auth in AuthType]
|
||||
|
||||
def test_auth_type_invalid_attribute_access(self):
|
||||
"""Test accessing non-existent auth type raises AttributeError"""
|
||||
with pytest.raises(AttributeError):
|
||||
_ = AuthType.INVALID_TYPE
|
||||
|
||||
def test_auth_type_immutability(self):
|
||||
"""Test that enum values cannot be modified"""
|
||||
# In Python 3.11+, enum members are read-only
|
||||
with pytest.raises(AttributeError):
|
||||
AuthType.FIRECRAWL = "modified"
|
||||
|
||||
def test_auth_type_from_value(self):
|
||||
"""Test creating AuthType from string value"""
|
||||
assert AuthType("firecrawl") == AuthType.FIRECRAWL
|
||||
assert AuthType("watercrawl") == AuthType.WATERCRAWL
|
||||
assert AuthType("jinareader") == AuthType.JINA
|
||||
|
||||
# Test invalid value
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AuthType("invalid_auth_type")
|
||||
assert "invalid_auth_type" in str(exc_info.value)
|
||||
|
||||
def test_auth_type_name_property(self):
|
||||
"""Test the name property of enum members"""
|
||||
assert AuthType.FIRECRAWL.name == "FIRECRAWL"
|
||||
assert AuthType.WATERCRAWL.name == "WATERCRAWL"
|
||||
assert AuthType.JINA.name == "JINA"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"auth_type",
|
||||
[AuthType.FIRECRAWL, AuthType.WATERCRAWL, AuthType.JINA],
|
||||
)
|
||||
def test_auth_type_isinstance_checks(self, auth_type):
|
||||
"""Test isinstance checks for auth types"""
|
||||
assert isinstance(auth_type, AuthType)
|
||||
assert isinstance(auth_type, str)
|
||||
assert isinstance(auth_type.value, str)
|
||||
|
||||
def test_auth_type_hash(self):
|
||||
"""Test that auth types are hashable and can be used in sets/dicts"""
|
||||
auth_set = {AuthType.FIRECRAWL, AuthType.WATERCRAWL, AuthType.JINA}
|
||||
assert len(auth_set) == 3
|
||||
|
||||
auth_dict = {
|
||||
AuthType.FIRECRAWL: "firecrawl_handler",
|
||||
AuthType.WATERCRAWL: "watercrawl_handler",
|
||||
AuthType.JINA: "jina_handler",
|
||||
}
|
||||
assert auth_dict[AuthType.FIRECRAWL] == "firecrawl_handler"
|
||||
|
||||
def test_auth_type_json_serializable(self):
|
||||
"""Test that auth types can be JSON serialized"""
|
||||
import json
|
||||
|
||||
auth_data = {
|
||||
"provider": AuthType.FIRECRAWL,
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
# Should serialize to string value
|
||||
json_str = json.dumps(auth_data, default=str)
|
||||
assert '"provider": "firecrawl"' in json_str
|
||||
|
||||
def test_auth_type_matches_factory_usage(self):
|
||||
"""Test that all AuthType values are handled by ApiKeyAuthFactory"""
|
||||
# This test verifies that the enum values match what's expected
|
||||
# by the factory implementation
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
for auth_type in AuthType:
|
||||
# Should not raise ValueError for valid auth types
|
||||
try:
|
||||
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(auth_type)
|
||||
assert auth_class is not None
|
||||
except ImportError:
|
||||
# It's OK if the actual auth implementation doesn't exist
|
||||
# We're just testing that the enum value is recognized
|
||||
pass
|
||||
191
dify/api/tests/unit_tests/services/auth/test_firecrawl_auth.py
Normal file
191
dify/api/tests/unit_tests/services/auth/test_firecrawl_auth.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
||||
|
||||
|
||||
class TestFirecrawlAuth:
|
||||
@pytest.fixture
|
||||
def valid_credentials(self):
|
||||
"""Fixture for valid bearer credentials"""
|
||||
return {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
|
||||
@pytest.fixture
|
||||
def auth_instance(self, valid_credentials):
|
||||
"""Fixture for FirecrawlAuth instance with valid credentials"""
|
||||
return FirecrawlAuth(valid_credentials)
|
||||
|
||||
def test_should_initialize_with_valid_bearer_credentials(self, valid_credentials):
|
||||
"""Test successful initialization with valid bearer credentials"""
|
||||
auth = FirecrawlAuth(valid_credentials)
|
||||
assert auth.api_key == "test_api_key_123"
|
||||
assert auth.base_url == "https://api.firecrawl.dev"
|
||||
assert auth.credentials == valid_credentials
|
||||
|
||||
def test_should_initialize_with_custom_base_url(self):
|
||||
"""Test initialization with custom base URL"""
|
||||
credentials = {
|
||||
"auth_type": "bearer",
|
||||
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
|
||||
}
|
||||
auth = FirecrawlAuth(credentials)
|
||||
assert auth.api_key == "test_api_key_123"
|
||||
assert auth.base_url == "https://custom.firecrawl.dev"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("auth_type", "expected_error"),
|
||||
[
|
||||
("basic", "Invalid auth type, Firecrawl auth type must be Bearer"),
|
||||
("x-api-key", "Invalid auth type, Firecrawl auth type must be Bearer"),
|
||||
("", "Invalid auth type, Firecrawl auth type must be Bearer"),
|
||||
],
|
||||
)
|
||||
def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error):
|
||||
"""Test that non-bearer auth types raise ValueError"""
|
||||
credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}}
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
FirecrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("credentials", "expected_error"),
|
||||
[
|
||||
({"auth_type": "bearer", "config": {}}, "No API key provided"),
|
||||
({"auth_type": "bearer"}, "No API key provided"),
|
||||
({"auth_type": "bearer", "config": {"api_key": ""}}, "No API key provided"),
|
||||
({"auth_type": "bearer", "config": {"api_key": None}}, "No API key provided"),
|
||||
],
|
||||
)
|
||||
def test_should_raise_error_for_missing_api_key(self, credentials, expected_error):
|
||||
"""Test that missing or empty API key raises ValueError"""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
FirecrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = auth_instance.validate_credentials()
|
||||
|
||||
assert result is True
|
||||
expected_data = {
|
||||
"url": "https://example.com",
|
||||
"includePaths": [],
|
||||
"excludePaths": [],
|
||||
"limit": 1,
|
||||
"scrapeOptions": {"onlyMainContent": True},
|
||||
}
|
||||
mock_post.assert_called_once_with(
|
||||
"https://api.firecrawl.dev/v1/crawl",
|
||||
headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"},
|
||||
json=expected_data,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "error_message"),
|
||||
[
|
||||
(402, "Payment required"),
|
||||
(409, "Conflict error"),
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.json.return_value = {"error": error_message}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "response_text", "has_json_error", "expected_error_contains"),
|
||||
[
|
||||
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
|
||||
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
"""Test handling of unexpected errors with various response formats"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.text = response_text
|
||||
if has_json_error:
|
||||
mock_response.json.side_effect = Exception("Not JSON")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert expected_error_contains in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
[
|
||||
(httpx.ConnectError, "Network error"),
|
||||
(httpx.TimeoutException, "Request timeout"),
|
||||
(httpx.ReadTimeout, "Read timeout"),
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_post.side_effect = exception_type(exception_message)
|
||||
|
||||
with pytest.raises(exception_type) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert exception_message in str(exc_info.value)
|
||||
|
||||
def test_should_not_expose_api_key_in_error_messages(self):
|
||||
"""Test that API key is not exposed in error messages"""
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}
|
||||
auth = FirecrawlAuth(credentials)
|
||||
|
||||
# Verify API key is stored but not in any error message
|
||||
assert auth.api_key == "super_secret_key_12345"
|
||||
|
||||
# Test various error scenarios don't expose the key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_post):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
credentials = {
|
||||
"auth_type": "bearer",
|
||||
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
|
||||
}
|
||||
auth = FirecrawlAuth(credentials)
|
||||
result = auth.validate_credentials()
|
||||
|
||||
assert result is True
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
|
||||
with pytest.raises(httpx.TimeoutException) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
|
||||
# Verify the timeout exception is raised with original message
|
||||
assert "timed out" in str(exc_info.value)
|
||||
155
dify/api/tests/unit_tests/services/auth/test_jina_auth.py
Normal file
155
dify/api/tests/unit_tests/services/auth/test_jina_auth.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from services.auth.jina.jina import JinaAuth
|
||||
|
||||
|
||||
class TestJinaAuth:
|
||||
def test_should_initialize_with_valid_bearer_credentials(self):
|
||||
"""Test successful initialization with valid bearer credentials"""
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
assert auth.api_key == "test_api_key_123"
|
||||
assert auth.credentials == credentials
|
||||
|
||||
def test_should_raise_error_for_invalid_auth_type(self):
|
||||
"""Test that non-bearer auth type raises ValueError"""
|
||||
credentials = {"auth_type": "basic", "config": {"api_key": "test_api_key_123"}}
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
JinaAuth(credentials)
|
||||
assert str(exc_info.value) == "Invalid auth type, Jina Reader auth type must be Bearer"
|
||||
|
||||
def test_should_raise_error_for_missing_api_key(self):
|
||||
"""Test that missing API key raises ValueError"""
|
||||
credentials = {"auth_type": "bearer", "config": {}}
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
JinaAuth(credentials)
|
||||
assert str(exc_info.value) == "No API key provided"
|
||||
|
||||
def test_should_raise_error_for_missing_config(self):
|
||||
"""Test that missing config section raises ValueError"""
|
||||
credentials = {"auth_type": "bearer"}
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
JinaAuth(credentials)
|
||||
assert str(exc_info.value) == "No API key provided"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
result = auth.validate_credentials()
|
||||
|
||||
assert result is True
|
||||
mock_post.assert_called_once_with(
|
||||
"https://r.jina.ai",
|
||||
headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"},
|
||||
json={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_http_402_error(self, mock_post):
|
||||
"""Test handling of 402 Payment Required error"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 402
|
||||
mock_response.json.return_value = {"error": "Payment required"}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_http_409_error(self, mock_post):
|
||||
"""Test handling of 409 Conflict error"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 409
|
||||
mock_response.json.return_value = {"error": "Conflict error"}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_http_500_error(self, mock_post):
|
||||
"""Test handling of 500 Internal Server Error"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.json.return_value = {"error": "Internal server error"}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
|
||||
"""Test handling of unexpected errors with text response"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 403
|
||||
mock_response.text = '{"error": "Forbidden"}'
|
||||
mock_response.json.side_effect = Exception("Not JSON")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_unexpected_error_without_text(self, mock_post):
|
||||
"""Test handling of unexpected errors without text response"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_response.text = ""
|
||||
mock_response.json.side_effect = Exception("Not JSON")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
def test_should_handle_network_errors(self, mock_post):
|
||||
"""Test handling of network connection errors"""
|
||||
mock_post.side_effect = httpx.ConnectError("Network error")
|
||||
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
|
||||
auth = JinaAuth(credentials)
|
||||
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
auth.validate_credentials()
|
||||
|
||||
def test_should_not_expose_api_key_in_error_messages(self):
|
||||
"""Test that API key is not exposed in error messages"""
|
||||
credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}
|
||||
auth = JinaAuth(credentials)
|
||||
|
||||
# Verify API key is stored but not in any error message
|
||||
assert auth.api_key == "super_secret_key_12345"
|
||||
|
||||
# Test various error scenarios don't expose the key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
JinaAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
205
dify/api/tests/unit_tests/services/auth/test_watercrawl_auth.py
Normal file
205
dify/api/tests/unit_tests/services/auth/test_watercrawl_auth.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from services.auth.watercrawl.watercrawl import WatercrawlAuth
|
||||
|
||||
|
||||
class TestWatercrawlAuth:
|
||||
@pytest.fixture
|
||||
def valid_credentials(self):
|
||||
"""Fixture for valid x-api-key credentials"""
|
||||
return {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123"}}
|
||||
|
||||
@pytest.fixture
|
||||
def auth_instance(self, valid_credentials):
|
||||
"""Fixture for WatercrawlAuth instance with valid credentials"""
|
||||
return WatercrawlAuth(valid_credentials)
|
||||
|
||||
def test_should_initialize_with_valid_x_api_key_credentials(self, valid_credentials):
|
||||
"""Test successful initialization with valid x-api-key credentials"""
|
||||
auth = WatercrawlAuth(valid_credentials)
|
||||
assert auth.api_key == "test_api_key_123"
|
||||
assert auth.base_url == "https://app.watercrawl.dev"
|
||||
assert auth.credentials == valid_credentials
|
||||
|
||||
def test_should_initialize_with_custom_base_url(self):
|
||||
"""Test initialization with custom base URL"""
|
||||
credentials = {
|
||||
"auth_type": "x-api-key",
|
||||
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"},
|
||||
}
|
||||
auth = WatercrawlAuth(credentials)
|
||||
assert auth.api_key == "test_api_key_123"
|
||||
assert auth.base_url == "https://custom.watercrawl.dev"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("auth_type", "expected_error"),
|
||||
[
|
||||
("bearer", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
|
||||
("basic", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
|
||||
("", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
|
||||
],
|
||||
)
|
||||
def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error):
|
||||
"""Test that non-x-api-key auth types raise ValueError"""
|
||||
credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}}
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WatercrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("credentials", "expected_error"),
|
||||
[
|
||||
({"auth_type": "x-api-key", "config": {}}, "No API key provided"),
|
||||
({"auth_type": "x-api-key"}, "No API key provided"),
|
||||
({"auth_type": "x-api-key", "config": {"api_key": ""}}, "No API key provided"),
|
||||
({"auth_type": "x-api-key", "config": {"api_key": None}}, "No API key provided"),
|
||||
],
|
||||
)
|
||||
def test_should_raise_error_for_missing_api_key(self, credentials, expected_error):
|
||||
"""Test that missing or empty API key raises ValueError"""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WatercrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = auth_instance.validate_credentials()
|
||||
|
||||
assert result is True
|
||||
mock_get.assert_called_once_with(
|
||||
"https://app.watercrawl.dev/api/v1/core/crawl-requests/",
|
||||
headers={"Content-Type": "application/json", "X-API-KEY": "test_api_key_123"},
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "error_message"),
|
||||
[
|
||||
(402, "Payment required"),
|
||||
(409, "Conflict error"),
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.json.return_value = {"error": error_message}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "response_text", "has_json_error", "expected_error_contains"),
|
||||
[
|
||||
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
|
||||
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
"""Test handling of unexpected errors with various response formats"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.text = response_text
|
||||
if has_json_error:
|
||||
mock_response.json.side_effect = Exception("Not JSON")
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert expected_error_contains in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
[
|
||||
(httpx.ConnectError, "Network error"),
|
||||
(httpx.TimeoutException, "Request timeout"),
|
||||
(httpx.ReadTimeout, "Read timeout"),
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_get.side_effect = exception_type(exception_message)
|
||||
|
||||
with pytest.raises(exception_type) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert exception_message in str(exc_info.value)
|
||||
|
||||
def test_should_not_expose_api_key_in_error_messages(self):
|
||||
"""Test that API key is not exposed in error messages"""
|
||||
credentials = {"auth_type": "x-api-key", "config": {"api_key": "super_secret_key_12345"}}
|
||||
auth = WatercrawlAuth(credentials)
|
||||
|
||||
# Verify API key is stored but not in any error message
|
||||
assert auth.api_key == "super_secret_key_12345"
|
||||
|
||||
# Test various error scenarios don't expose the key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_get):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
credentials = {
|
||||
"auth_type": "x-api-key",
|
||||
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"},
|
||||
}
|
||||
auth = WatercrawlAuth(credentials)
|
||||
result = auth.validate_credentials()
|
||||
|
||||
assert result is True
|
||||
assert mock_get.call_args[0][0] == "https://custom.watercrawl.dev/api/v1/core/crawl-requests/"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("base_url", "expected_url"),
|
||||
[
|
||||
("https://app.watercrawl.dev", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
("https://app.watercrawl.dev/", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
|
||||
"""Test that urljoin is used correctly for URL construction with various base URLs"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
credentials = {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123", "base_url": base_url}}
|
||||
auth = WatercrawlAuth(credentials)
|
||||
auth.validate_credentials()
|
||||
|
||||
# Verify the correct URL was called
|
||||
assert mock_get.call_args[0][0] == expected_url
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
|
||||
with pytest.raises(httpx.TimeoutException) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
|
||||
# Verify the timeout exception is raised with original message
|
||||
assert "timed out" in str(exc_info.value)
|
||||
1093
dify/api/tests/unit_tests/services/segment_service.py
Normal file
1093
dify/api/tests/unit_tests/services/segment_service.py
Normal file
File diff suppressed because it is too large
Load Diff
59
dify/api/tests/unit_tests/services/services_test_help.py
Normal file
59
dify/api/tests/unit_tests/services/services_test_help.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class ServiceDbTestHelper:
|
||||
"""
|
||||
Helper class for service database query tests.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def setup_db_query_filter_by_mock(mock_db, query_results):
|
||||
"""
|
||||
Smart database query mock that responds based on model type and query parameters.
|
||||
|
||||
Args:
|
||||
mock_db: Mock database session
|
||||
query_results: Dict mapping (model_name, filter_key, filter_value) to return value
|
||||
Example: {('Account', 'email', 'test@example.com'): mock_account}
|
||||
"""
|
||||
|
||||
def query_side_effect(model):
|
||||
mock_query = MagicMock()
|
||||
|
||||
def filter_by_side_effect(**kwargs):
|
||||
mock_filter_result = MagicMock()
|
||||
|
||||
def first_side_effect():
|
||||
# Find matching result based on model and filter parameters
|
||||
for (model_name, filter_key, filter_value), result in query_results.items():
|
||||
if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value:
|
||||
return result
|
||||
return None
|
||||
|
||||
mock_filter_result.first.side_effect = first_side_effect
|
||||
|
||||
# Handle order_by calls for complex queries
|
||||
def order_by_side_effect(*args, **kwargs):
|
||||
mock_order_result = MagicMock()
|
||||
|
||||
def order_first_side_effect():
|
||||
# Look for order_by results in the same query_results dict
|
||||
for (model_name, filter_key, filter_value), result in query_results.items():
|
||||
if (
|
||||
model.__name__ == model_name
|
||||
and filter_key == "order_by"
|
||||
and filter_value == "first_available"
|
||||
):
|
||||
return result
|
||||
return None
|
||||
|
||||
mock_order_result.first.side_effect = order_first_side_effect
|
||||
return mock_order_result
|
||||
|
||||
mock_filter_result.order_by.side_effect = order_by_side_effect
|
||||
return mock_filter_result
|
||||
|
||||
mock_query.filter_by.side_effect = filter_by_side_effect
|
||||
return mock_query
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
1545
dify/api/tests/unit_tests/services/test_account_service.py
Normal file
1545
dify/api/tests/unit_tests/services/test_account_service.py
Normal file
File diff suppressed because it is too large
Load Diff
106
dify/api/tests/unit_tests/services/test_app_task_service.py
Normal file
106
dify/api/tests/unit_tests/services/test_app_task_service.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import AppMode
|
||||
from services.app_task_service import AppTaskService
|
||||
|
||||
|
||||
class TestAppTaskService:
|
||||
"""Test suite for AppTaskService.stop_task method."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("app_mode", "should_call_graph_engine"),
|
||||
[
|
||||
(AppMode.CHAT, False),
|
||||
(AppMode.COMPLETION, False),
|
||||
(AppMode.AGENT_CHAT, False),
|
||||
(AppMode.CHANNEL, False),
|
||||
(AppMode.RAG_PIPELINE, False),
|
||||
(AppMode.ADVANCED_CHAT, True),
|
||||
(AppMode.WORKFLOW, True),
|
||||
],
|
||||
)
|
||||
@patch("services.app_task_service.AppQueueManager")
|
||||
@patch("services.app_task_service.GraphEngineManager")
|
||||
def test_stop_task_with_different_app_modes(
|
||||
self, mock_graph_engine_manager, mock_app_queue_manager, app_mode, should_call_graph_engine
|
||||
):
|
||||
"""Test stop_task behavior with different app modes.
|
||||
|
||||
Verifies that:
|
||||
- Legacy Redis flag is always set via AppQueueManager
|
||||
- GraphEngine stop command is only sent for ADVANCED_CHAT and WORKFLOW modes
|
||||
"""
|
||||
# Arrange
|
||||
task_id = "task-123"
|
||||
invoke_from = InvokeFrom.WEB_APP
|
||||
user_id = "user-456"
|
||||
|
||||
# Act
|
||||
AppTaskService.stop_task(task_id, invoke_from, user_id, app_mode)
|
||||
|
||||
# Assert
|
||||
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
|
||||
if should_call_graph_engine:
|
||||
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
|
||||
else:
|
||||
mock_graph_engine_manager.send_stop_command.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invoke_from",
|
||||
[
|
||||
InvokeFrom.WEB_APP,
|
||||
InvokeFrom.SERVICE_API,
|
||||
InvokeFrom.DEBUGGER,
|
||||
InvokeFrom.EXPLORE,
|
||||
],
|
||||
)
|
||||
@patch("services.app_task_service.AppQueueManager")
|
||||
@patch("services.app_task_service.GraphEngineManager")
|
||||
def test_stop_task_with_different_invoke_sources(
|
||||
self, mock_graph_engine_manager, mock_app_queue_manager, invoke_from
|
||||
):
|
||||
"""Test stop_task behavior with different invoke sources.
|
||||
|
||||
Verifies that the method works correctly regardless of the invoke source.
|
||||
"""
|
||||
# Arrange
|
||||
task_id = "task-789"
|
||||
user_id = "user-999"
|
||||
app_mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
# Act
|
||||
AppTaskService.stop_task(task_id, invoke_from, user_id, app_mode)
|
||||
|
||||
# Assert
|
||||
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
|
||||
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
|
||||
|
||||
@patch("services.app_task_service.GraphEngineManager")
|
||||
@patch("services.app_task_service.AppQueueManager")
|
||||
def test_stop_task_legacy_mechanism_called_even_if_graph_engine_fails(
|
||||
self, mock_app_queue_manager, mock_graph_engine_manager
|
||||
):
|
||||
"""Test that legacy Redis flag is set even if GraphEngine fails.
|
||||
|
||||
This ensures backward compatibility: the legacy mechanism should complete
|
||||
before attempting the GraphEngine command, so the stop flag is set
|
||||
regardless of GraphEngine success.
|
||||
"""
|
||||
# Arrange
|
||||
task_id = "task-123"
|
||||
invoke_from = InvokeFrom.WEB_APP
|
||||
user_id = "user-456"
|
||||
app_mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
# Simulate GraphEngine failure
|
||||
mock_graph_engine_manager.send_stop_command.side_effect = Exception("GraphEngine error")
|
||||
|
||||
# Act & Assert - should raise the exception since it's not caught
|
||||
with pytest.raises(Exception, match="GraphEngine error"):
|
||||
AppTaskService.stop_task(task_id, invoke_from, user_id, app_mode)
|
||||
|
||||
# Verify legacy mechanism was still called before the exception
|
||||
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
|
||||
236
dify/api/tests/unit_tests/services/test_billing_service.py
Normal file
236
dify/api/tests/unit_tests/services/test_billing_service.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
class TestBillingServiceSendRequest:
|
||||
"""Unit tests for BillingService._send_request method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_request(self):
|
||||
"""Mock httpx.request for testing."""
|
||||
with patch("services.billing_service.httpx.request") as mock_request:
|
||||
yield mock_request
|
||||
|
||||
@pytest.fixture
|
||||
def mock_billing_config(self):
|
||||
"""Mock BillingService configuration."""
|
||||
with (
|
||||
patch.object(BillingService, "base_url", "https://billing-api.example.com"),
|
||||
patch.object(BillingService, "secret_key", "test-secret-key"),
|
||||
):
|
||||
yield
|
||||
|
||||
def test_get_request_success(self, mock_httpx_request, mock_billing_config):
|
||||
"""Test successful GET request."""
|
||||
# Arrange
|
||||
expected_response = {"result": "success", "data": {"info": "test"}}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = httpx.codes.OK
|
||||
mock_response.json.return_value = expected_response
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = BillingService._send_request("GET", "/test", params={"key": "value"})
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
mock_httpx_request.assert_called_once()
|
||||
call_args = mock_httpx_request.call_args
|
||||
assert call_args[0][0] == "GET"
|
||||
assert call_args[0][1] == "https://billing-api.example.com/test"
|
||||
assert call_args[1]["params"] == {"key": "value"}
|
||||
assert call_args[1]["headers"]["Billing-Api-Secret-Key"] == "test-secret-key"
|
||||
assert call_args[1]["headers"]["Content-Type"] == "application/json"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code", [httpx.codes.NOT_FOUND, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.BAD_REQUEST]
|
||||
)
|
||||
def test_get_request_non_200_status_code(self, mock_httpx_request, mock_billing_config, status_code):
|
||||
"""Test GET request with non-200 status code raises ValueError."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
BillingService._send_request("GET", "/test")
|
||||
assert "Unable to retrieve billing information" in str(exc_info.value)
|
||||
|
||||
def test_put_request_success(self, mock_httpx_request, mock_billing_config):
|
||||
"""Test successful PUT request."""
|
||||
# Arrange
|
||||
expected_response = {"result": "success"}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = httpx.codes.OK
|
||||
mock_response.json.return_value = expected_response
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = BillingService._send_request("PUT", "/test", json={"key": "value"})
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
call_args = mock_httpx_request.call_args
|
||||
assert call_args[0][0] == "PUT"
|
||||
|
||||
def test_put_request_internal_server_error(self, mock_httpx_request, mock_billing_config):
|
||||
"""Test PUT request with INTERNAL_SERVER_ERROR raises InternalServerError."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = httpx.codes.INTERNAL_SERVER_ERROR
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(InternalServerError) as exc_info:
|
||||
BillingService._send_request("PUT", "/test", json={"key": "value"})
|
||||
assert exc_info.value.code == 500
|
||||
assert "Unable to process billing request" in str(exc_info.value.description)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.NOT_FOUND, httpx.codes.UNAUTHORIZED, httpx.codes.FORBIDDEN]
|
||||
)
|
||||
def test_put_request_non_200_non_500(self, mock_httpx_request, mock_billing_config, status_code):
|
||||
"""Test PUT request with non-200 and non-500 status code raises ValueError."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
BillingService._send_request("PUT", "/test", json={"key": "value"})
|
||||
assert "Invalid arguments." in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize("method", ["POST", "DELETE"])
|
||||
def test_non_get_non_put_request_success(self, mock_httpx_request, mock_billing_config, method):
|
||||
"""Test successful POST/DELETE request."""
|
||||
# Arrange
|
||||
expected_response = {"result": "success"}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = httpx.codes.OK
|
||||
mock_response.json.return_value = expected_response
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = BillingService._send_request(method, "/test", json={"key": "value"})
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
call_args = mock_httpx_request.call_args
|
||||
assert call_args[0][0] == method
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
|
||||
)
|
||||
def test_post_request_non_200_with_valid_json(self, mock_httpx_request, mock_billing_config, status_code):
|
||||
"""Test POST request with non-200 status code raises ValueError."""
|
||||
# Arrange
|
||||
error_response = {"detail": "Error message"}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.json.return_value = error_response
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
BillingService._send_request("POST", "/test", json={"key": "value"})
|
||||
assert "Unable to send request to" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
|
||||
)
|
||||
def test_delete_request_non_200_with_valid_json(self, mock_httpx_request, mock_billing_config, status_code):
|
||||
"""Test DELETE request with non-200 status code but valid JSON response.
|
||||
|
||||
DELETE doesn't check status code, so it returns the error JSON.
|
||||
"""
|
||||
# Arrange
|
||||
error_response = {"detail": "Error message"}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.json.return_value = error_response
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = BillingService._send_request("DELETE", "/test", json={"key": "value"})
|
||||
|
||||
# Assert
|
||||
assert result == error_response
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
|
||||
)
|
||||
def test_post_request_non_200_with_invalid_json(self, mock_httpx_request, mock_billing_config, status_code):
|
||||
"""Test POST request with non-200 status code raises ValueError before JSON parsing."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.text = ""
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
# POST checks status code before calling response.json(), so ValueError is raised
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
BillingService._send_request("POST", "/test", json={"key": "value"})
|
||||
assert "Unable to send request to" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
|
||||
)
|
||||
def test_delete_request_non_200_with_invalid_json(self, mock_httpx_request, mock_billing_config, status_code):
|
||||
"""Test DELETE request with non-200 status code and invalid JSON response raises exception.
|
||||
|
||||
DELETE doesn't check status code, so it calls response.json() which raises JSONDecodeError
|
||||
when the response cannot be parsed as JSON (e.g., empty response).
|
||||
"""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.text = ""
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
|
||||
mock_httpx_request.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
BillingService._send_request("DELETE", "/test", json={"key": "value"})
|
||||
|
||||
def test_retry_on_request_error(self, mock_httpx_request, mock_billing_config):
|
||||
"""Test that _send_request retries on httpx.RequestError."""
|
||||
# Arrange
|
||||
expected_response = {"result": "success"}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = httpx.codes.OK
|
||||
mock_response.json.return_value = expected_response
|
||||
|
||||
# First call raises RequestError, second succeeds
|
||||
mock_httpx_request.side_effect = [
|
||||
httpx.RequestError("Network error"),
|
||||
mock_response,
|
||||
]
|
||||
|
||||
# Act
|
||||
result = BillingService._send_request("GET", "/test")
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
assert mock_httpx_request.call_count == 2
|
||||
|
||||
def test_retry_exhausted_raises_exception(self, mock_httpx_request, mock_billing_config):
|
||||
"""Test that _send_request raises exception after retries are exhausted."""
|
||||
# Arrange
|
||||
mock_httpx_request.side_effect = httpx.RequestError("Network error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(httpx.RequestError):
|
||||
BillingService._send_request("GET", "/test")
|
||||
|
||||
# Should retry multiple times (wait=2, stop_before_delay=10 means ~5 attempts)
|
||||
assert mock_httpx_request.call_count > 1
|
||||
@@ -0,0 +1,168 @@
|
||||
import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
|
||||
|
||||
class TestClearFreePlanTenantExpiredLogs:
|
||||
"""Unit tests for ClearFreePlanTenantExpiredLogs._clear_message_related_tables method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create a mock database session."""
|
||||
session = Mock(spec=Session)
|
||||
session.query.return_value.filter.return_value.all.return_value = []
|
||||
session.query.return_value.filter.return_value.delete.return_value = 0
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage(self):
|
||||
"""Create a mock storage object."""
|
||||
storage = Mock()
|
||||
storage.save.return_value = None
|
||||
return storage
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message_ids(self):
|
||||
"""Sample message IDs for testing."""
|
||||
return ["msg-1", "msg-2", "msg-3"]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_records(self):
|
||||
"""Sample records for testing."""
|
||||
records = []
|
||||
for i in range(3):
|
||||
record = Mock()
|
||||
record.id = f"record-{i}"
|
||||
record.to_dict.return_value = {
|
||||
"id": f"record-{i}",
|
||||
"message_id": f"msg-{i}",
|
||||
"created_at": datetime.datetime.now().isoformat(),
|
||||
}
|
||||
records.append(record)
|
||||
return records
|
||||
|
||||
def test_clear_message_related_tables_empty_message_ids(self, mock_session):
|
||||
"""Test that method returns early when message_ids is empty."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", [])
|
||||
|
||||
# Should not call any database operations
|
||||
mock_session.query.assert_not_called()
|
||||
mock_storage.save.assert_not_called()
|
||||
|
||||
def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids):
|
||||
"""Test when no related records are found."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call query for each related table but find no records
|
||||
assert mock_session.query.call_count > 0
|
||||
mock_storage.save.assert_not_called()
|
||||
|
||||
def test_clear_message_related_tables_with_records_and_to_dict(
|
||||
self, mock_session, sample_message_ids, sample_records
|
||||
):
|
||||
"""Test when records are found and have to_dict method."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.where.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call to_dict on each record (called once per table, so 7 times total)
|
||||
for record in sample_records:
|
||||
assert record.to_dict.call_count == 7
|
||||
|
||||
# Should save backup data
|
||||
assert mock_storage.save.call_count > 0
|
||||
|
||||
def test_clear_message_related_tables_with_records_no_to_dict(self, mock_session, sample_message_ids):
|
||||
"""Test when records are found but don't have to_dict method."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
# Create records without to_dict method
|
||||
records = []
|
||||
for i in range(2):
|
||||
record = Mock()
|
||||
mock_table = Mock()
|
||||
mock_id_column = Mock()
|
||||
mock_id_column.name = "id"
|
||||
mock_message_id_column = Mock()
|
||||
mock_message_id_column.name = "message_id"
|
||||
mock_table.columns = [mock_id_column, mock_message_id_column]
|
||||
record.__table__ = mock_table
|
||||
record.id = f"record-{i}"
|
||||
record.message_id = f"msg-{i}"
|
||||
del record.to_dict
|
||||
records.append(record)
|
||||
|
||||
# Mock records for first table only, empty for others
|
||||
mock_session.query.return_value.where.return_value.all.side_effect = [
|
||||
records,
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
]
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should save backup data even without to_dict
|
||||
assert mock_storage.save.call_count > 0
|
||||
|
||||
def test_clear_message_related_tables_storage_error_continues(
|
||||
self, mock_session, sample_message_ids, sample_records
|
||||
):
|
||||
"""Test that method continues even when storage.save fails."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_storage.save.side_effect = Exception("Storage error")
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = sample_records
|
||||
|
||||
# Should not raise exception
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should still delete records even if backup fails
|
||||
assert mock_session.query.return_value.where.return_value.delete.called
|
||||
|
||||
def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids):
|
||||
"""Test that method continues even when record serialization fails."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
record = Mock()
|
||||
record.id = "record-1"
|
||||
record.to_dict.side_effect = Exception("Serialization error")
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [record]
|
||||
|
||||
# Should not raise exception
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should still delete records even if serialization fails
|
||||
assert mock_session.query.return_value.where.return_value.delete.called
|
||||
|
||||
def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records):
|
||||
"""Test that deletion is called for found records."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.where.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call delete for each table that has records
|
||||
assert mock_session.query.return_value.where.return_value.delete.called
|
||||
|
||||
def test_clear_message_related_tables_logging_output(
|
||||
self, mock_session, sample_message_ids, sample_records, capsys
|
||||
):
|
||||
"""Test that logging output is generated."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.where.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
pass
|
||||
127
dify/api/tests/unit_tests/services/test_conversation_service.py
Normal file
127
dify/api/tests/unit_tests/services/test_conversation_service.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
|
||||
class TestConversationService:
|
||||
def test_pagination_with_empty_include_ids(self):
|
||||
"""Test that empty include_ids returns empty result"""
|
||||
mock_session = MagicMock()
|
||||
mock_app_model = MagicMock(id=str(uuid.uuid4()))
|
||||
mock_user = MagicMock(id=str(uuid.uuid4()))
|
||||
|
||||
result = ConversationService.pagination_by_last_id(
|
||||
session=mock_session,
|
||||
app_model=mock_app_model,
|
||||
user=mock_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
include_ids=[], # Empty include_ids should return empty result
|
||||
exclude_ids=None,
|
||||
)
|
||||
|
||||
assert result.data == []
|
||||
assert result.has_more is False
|
||||
assert result.limit == 20
|
||||
|
||||
def test_pagination_with_non_empty_include_ids(self):
|
||||
"""Test that non-empty include_ids filters properly"""
|
||||
mock_session = MagicMock()
|
||||
mock_app_model = MagicMock(id=str(uuid.uuid4()))
|
||||
mock_user = MagicMock(id=str(uuid.uuid4()))
|
||||
|
||||
# Mock the query results
|
||||
mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
|
||||
mock_session.scalars.return_value.all.return_value = mock_conversations
|
||||
mock_session.scalar.return_value = 0
|
||||
|
||||
with patch("services.conversation_service.select") as mock_select:
|
||||
mock_stmt = MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
mock_stmt.order_by.return_value = mock_stmt
|
||||
mock_stmt.limit.return_value = mock_stmt
|
||||
mock_stmt.subquery.return_value = MagicMock()
|
||||
|
||||
result = ConversationService.pagination_by_last_id(
|
||||
session=mock_session,
|
||||
app_model=mock_app_model,
|
||||
user=mock_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
include_ids=["conv1", "conv2"], # Non-empty include_ids
|
||||
exclude_ids=None,
|
||||
)
|
||||
|
||||
# Verify the where clause was called with id.in_
|
||||
assert mock_stmt.where.called
|
||||
|
||||
def test_pagination_with_empty_exclude_ids(self):
|
||||
"""Test that empty exclude_ids doesn't filter"""
|
||||
mock_session = MagicMock()
|
||||
mock_app_model = MagicMock(id=str(uuid.uuid4()))
|
||||
mock_user = MagicMock(id=str(uuid.uuid4()))
|
||||
|
||||
# Mock the query results
|
||||
mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(5)]
|
||||
mock_session.scalars.return_value.all.return_value = mock_conversations
|
||||
mock_session.scalar.return_value = 0
|
||||
|
||||
with patch("services.conversation_service.select") as mock_select:
|
||||
mock_stmt = MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
mock_stmt.order_by.return_value = mock_stmt
|
||||
mock_stmt.limit.return_value = mock_stmt
|
||||
mock_stmt.subquery.return_value = MagicMock()
|
||||
|
||||
result = ConversationService.pagination_by_last_id(
|
||||
session=mock_session,
|
||||
app_model=mock_app_model,
|
||||
user=mock_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
include_ids=None,
|
||||
exclude_ids=[], # Empty exclude_ids should not filter
|
||||
)
|
||||
|
||||
# Result should contain the mocked conversations
|
||||
assert len(result.data) == 5
|
||||
|
||||
def test_pagination_with_non_empty_exclude_ids(self):
|
||||
"""Test that non-empty exclude_ids filters properly"""
|
||||
mock_session = MagicMock()
|
||||
mock_app_model = MagicMock(id=str(uuid.uuid4()))
|
||||
mock_user = MagicMock(id=str(uuid.uuid4()))
|
||||
|
||||
# Mock the query results
|
||||
mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)]
|
||||
mock_session.scalars.return_value.all.return_value = mock_conversations
|
||||
mock_session.scalar.return_value = 0
|
||||
|
||||
with patch("services.conversation_service.select") as mock_select:
|
||||
mock_stmt = MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
mock_stmt.order_by.return_value = mock_stmt
|
||||
mock_stmt.limit.return_value = mock_stmt
|
||||
mock_stmt.subquery.return_value = MagicMock()
|
||||
|
||||
result = ConversationService.pagination_by_last_id(
|
||||
session=mock_session,
|
||||
app_model=mock_app_model,
|
||||
user=mock_user,
|
||||
last_id=None,
|
||||
limit=20,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
include_ids=None,
|
||||
exclude_ids=["conv1", "conv2"], # Non-empty exclude_ids
|
||||
)
|
||||
|
||||
# Verify the where clause was called for exclusion
|
||||
assert mock_stmt.where.called
|
||||
305
dify/api/tests/unit_tests/services/test_dataset_permission.py
Normal file
305
dify/api/tests/unit_tests/services/test_dataset_permission.py
Normal file
@@ -0,0 +1,305 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetPermission, DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
||||
class DatasetPermissionTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset permission tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
tenant_id: str = "test-tenant-123",
|
||||
created_by: str = "creator-456",
|
||||
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.created_by = created_by
|
||||
dataset.permission = permission
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_user_mock(
|
||||
user_id: str = "user-789",
|
||||
tenant_id: str = "test-tenant-123",
|
||||
role: TenantAccountRole = TenantAccountRole.NORMAL,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock user with specified attributes."""
|
||||
user = Mock(spec=Account)
|
||||
user.id = user_id
|
||||
user.current_tenant_id = tenant_id
|
||||
user.current_role = role
|
||||
for key, value in kwargs.items():
|
||||
setattr(user, key, value)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_permission_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
account_id: str = "user-789",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset permission record."""
|
||||
permission = Mock(spec=DatasetPermission)
|
||||
permission.dataset_id = dataset_id
|
||||
permission.account_id = account_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(permission, key, value)
|
||||
return permission
|
||||
|
||||
|
||||
class TestDatasetPermissionService:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.check_dataset_permission method.
|
||||
|
||||
This test suite covers all permission scenarios including:
|
||||
- Cross-tenant access restrictions
|
||||
- Owner privilege checks
|
||||
- Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
|
||||
- Explicit permission checks for PARTIAL_TEAM
|
||||
- Error conditions and logging
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""Common mock setup for dataset service dependencies."""
|
||||
with patch("services.dataset_service.db.session") as mock_session:
|
||||
yield {
|
||||
"db_session": mock_session,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logging_dependencies(self):
|
||||
"""Mock setup for logging tests."""
|
||||
with patch("services.dataset_service.logger") as mock_logging:
|
||||
yield {
|
||||
"logging": mock_logging,
|
||||
}
|
||||
|
||||
def _assert_permission_check_passes(self, dataset: Mock, user: Mock):
|
||||
"""Helper method to verify that permission check passes without raising exceptions."""
|
||||
# Should not raise any exception
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
def _assert_permission_check_fails(
|
||||
self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset."
|
||||
):
|
||||
"""Helper method to verify that permission check fails with expected error."""
|
||||
with pytest.raises(NoPermissionError, match=expected_message):
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str):
|
||||
"""Helper method to verify database query calls for permission checks."""
|
||||
mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id)
|
||||
|
||||
def _assert_database_query_not_called(self, mock_session: Mock):
|
||||
"""Helper method to verify that database query was not called."""
|
||||
mock_session.query.assert_not_called()
|
||||
|
||||
# ==================== Cross-Tenant Access Tests ====================
|
||||
|
||||
def test_permission_check_different_tenant_should_fail(self):
|
||||
"""Test that users from different tenants cannot access dataset regardless of other permissions."""
|
||||
# Create dataset and user from different tenants
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM
|
||||
)
|
||||
user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Should fail due to different tenant
|
||||
self._assert_permission_check_fails(dataset, user)
|
||||
|
||||
# ==================== Owner Privilege Tests ====================
|
||||
|
||||
def test_owner_can_access_any_dataset(self):
|
||||
"""Test that tenant owners can access any dataset regardless of permission level."""
|
||||
# Create dataset with restrictive permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
|
||||
|
||||
# Create owner user
|
||||
owner_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="owner-999", role=TenantAccountRole.OWNER
|
||||
)
|
||||
|
||||
# Owner should have access regardless of dataset permission
|
||||
self._assert_permission_check_passes(dataset, owner_user)
|
||||
|
||||
# ==================== ONLY_ME Permission Tests ====================
|
||||
|
||||
def test_only_me_permission_creator_can_access(self):
|
||||
"""Test ONLY_ME permission allows only the dataset creator to access."""
|
||||
# Create dataset with ONLY_ME permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
|
||||
)
|
||||
|
||||
# Create creator user
|
||||
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="creator-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Creator should be able to access
|
||||
self._assert_permission_check_passes(dataset, creator_user)
|
||||
|
||||
def test_only_me_permission_others_cannot_access(self):
|
||||
"""Test ONLY_ME permission denies access to non-creators."""
|
||||
# Create dataset with ONLY_ME permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
|
||||
)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Non-creator should be denied access
|
||||
self._assert_permission_check_fails(dataset, normal_user)
|
||||
|
||||
# ==================== ALL_TEAM Permission Tests ====================
|
||||
|
||||
def test_all_team_permission_allows_access(self):
|
||||
"""Test ALL_TEAM permission allows any team member to access the dataset."""
|
||||
# Create dataset with ALL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM)
|
||||
|
||||
# Create different types of team members
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
editor_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="editor-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# All team members should have access
|
||||
self._assert_permission_check_passes(dataset, normal_user)
|
||||
self._assert_permission_check_passes(dataset, editor_user)
|
||||
|
||||
# ==================== PARTIAL_TEAM Permission Tests ====================
|
||||
|
||||
def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies):
|
||||
"""Test PARTIAL_TEAM permission allows creator to access without database query."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
|
||||
)
|
||||
|
||||
# Create creator user
|
||||
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="creator-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Creator should have access without database query
|
||||
self._assert_permission_check_passes(dataset, creator_user)
|
||||
self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"])
|
||||
|
||||
def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies):
|
||||
"""Test PARTIAL_TEAM permission allows users with explicit permission records."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return a permission record
|
||||
mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock(
|
||||
dataset_id=dataset.id, account_id=normal_user.id
|
||||
)
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission
|
||||
|
||||
# User with explicit permission should have access
|
||||
self._assert_permission_check_passes(dataset, normal_user)
|
||||
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
|
||||
|
||||
def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies):
|
||||
"""Test PARTIAL_TEAM permission denies users without explicit permission records."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return None (no permission record)
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
|
||||
|
||||
# User without explicit permission should be denied access
|
||||
self._assert_permission_check_fails(dataset, normal_user)
|
||||
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
|
||||
|
||||
def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies):
|
||||
"""Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
|
||||
)
|
||||
|
||||
# Create a different user (not the creator)
|
||||
other_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="other-user-123", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return None (no permission record)
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
|
||||
|
||||
# Non-creator without explicit permission should be denied access
|
||||
self._assert_permission_check_fails(dataset, other_user)
|
||||
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id)
|
||||
|
||||
# ==================== Enum Usage Tests ====================
|
||||
|
||||
def test_partial_team_permission_uses_correct_enum(self):
|
||||
"""Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals."""
|
||||
# Create dataset with PARTIAL_TEAM permission using enum
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
|
||||
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
|
||||
)
|
||||
|
||||
# Create creator user
|
||||
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="creator-456", role=TenantAccountRole.EDITOR
|
||||
)
|
||||
|
||||
# Creator should always have access regardless of permission level
|
||||
self._assert_permission_check_passes(dataset, creator_user)
|
||||
|
||||
# ==================== Logging Tests ====================
|
||||
|
||||
def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies):
|
||||
"""Test that permission denied events are properly logged for debugging purposes."""
|
||||
# Create dataset with PARTIAL_TEAM permission
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
|
||||
|
||||
# Create normal user (not the creator)
|
||||
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
|
||||
user_id="normal-789", role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock database query to return None (no permission record)
|
||||
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
|
||||
|
||||
# Attempt permission check (should fail)
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.check_dataset_permission(dataset, normal_user)
|
||||
|
||||
# Verify debug message was logged with correct user and dataset information
|
||||
mock_logging_dependencies["logging"].debug.assert_called_with(
|
||||
"User %s does not have permission to access dataset %s", normal_user.id, dataset.id
|
||||
)
|
||||
@@ -0,0 +1,800 @@
|
||||
import datetime
|
||||
|
||||
# Mock redis_client before importing dataset_service
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.dataset import Dataset, Document
|
||||
from services.dataset_service import DocumentService
|
||||
from services.errors.document import DocumentIndexingError
|
||||
from tests.unit_tests.conftest import redis_mock
|
||||
|
||||
|
||||
class DocumentBatchUpdateTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for document batch update tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-456") -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_user_mock(user_id: str = "user-789") -> Mock:
|
||||
"""Create a mock user."""
|
||||
user = Mock()
|
||||
user.id = user_id
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_document_mock(
|
||||
document_id: str = "doc-1",
|
||||
name: str = "test_document.pdf",
|
||||
enabled: bool = True,
|
||||
archived: bool = False,
|
||||
indexing_status: str = "completed",
|
||||
completed_at: datetime.datetime | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock document with specified attributes."""
|
||||
document = Mock(spec=Document)
|
||||
document.id = document_id
|
||||
document.name = name
|
||||
document.enabled = enabled
|
||||
document.archived = archived
|
||||
document.indexing_status = indexing_status
|
||||
document.completed_at = completed_at or datetime.datetime.now()
|
||||
|
||||
# Set default values for optional fields
|
||||
document.disabled_at = None
|
||||
document.disabled_by = None
|
||||
document.archived_at = None
|
||||
document.archived_by = None
|
||||
document.updated_at = None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(document, key, value)
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def create_multiple_documents(
|
||||
document_ids: list[str], enabled: bool = True, archived: bool = False, indexing_status: str = "completed"
|
||||
) -> list[Mock]:
|
||||
"""Create multiple mock documents with specified attributes."""
|
||||
documents = []
|
||||
for doc_id in document_ids:
|
||||
doc = DocumentBatchUpdateTestDataFactory.create_document_mock(
|
||||
document_id=doc_id,
|
||||
name=f"document_{doc_id}.pdf",
|
||||
enabled=enabled,
|
||||
archived=archived,
|
||||
indexing_status=indexing_status,
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
|
||||
class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
"""
|
||||
Comprehensive unit tests for DocumentService.batch_update_document_status method.
|
||||
|
||||
This test suite covers all supported actions (enable, disable, archive, un_archive),
|
||||
error conditions, edge cases, and validates proper interaction with Redis cache,
|
||||
database operations, and async task triggers.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_service_dependencies(self):
|
||||
"""Common mock setup for document service dependencies."""
|
||||
with (
|
||||
patch("services.dataset_service.DocumentService.get_document") as mock_get_doc,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
||||
):
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
mock_naive_utc_now.return_value = current_time
|
||||
|
||||
yield {
|
||||
"get_document": mock_get_doc,
|
||||
"db_session": mock_db,
|
||||
"naive_utc_now": mock_naive_utc_now,
|
||||
"current_time": current_time,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_task_dependencies(self):
|
||||
"""Mock setup for async task dependencies."""
|
||||
with (
|
||||
patch("services.dataset_service.add_document_to_index_task") as mock_add_task,
|
||||
patch("services.dataset_service.remove_document_from_index_task") as mock_remove_task,
|
||||
):
|
||||
yield {"add_task": mock_add_task, "remove_task": mock_remove_task}
|
||||
|
||||
def _assert_document_enabled(self, document: Mock, user_id: str, current_time: datetime.datetime):
|
||||
"""Helper method to verify document was enabled correctly."""
|
||||
assert document.enabled == True
|
||||
assert document.disabled_at is None
|
||||
assert document.disabled_by is None
|
||||
assert document.updated_at == current_time
|
||||
|
||||
def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime):
|
||||
"""Helper method to verify document was disabled correctly."""
|
||||
assert document.enabled == False
|
||||
assert document.disabled_at == current_time
|
||||
assert document.disabled_by == user_id
|
||||
assert document.updated_at == current_time
|
||||
|
||||
def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime):
|
||||
"""Helper method to verify document was archived correctly."""
|
||||
assert document.archived == True
|
||||
assert document.archived_at == current_time
|
||||
assert document.archived_by == user_id
|
||||
assert document.updated_at == current_time
|
||||
|
||||
def _assert_document_unarchived(self, document: Mock):
|
||||
"""Helper method to verify document was unarchived correctly."""
|
||||
assert document.archived == False
|
||||
assert document.archived_at is None
|
||||
assert document.archived_by is None
|
||||
|
||||
def _assert_redis_cache_operations(self, document_ids: list[str], action: str = "setex"):
|
||||
"""Helper method to verify Redis cache operations."""
|
||||
if action == "setex":
|
||||
expected_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids]
|
||||
redis_mock.setex.assert_has_calls(expected_calls)
|
||||
elif action == "get":
|
||||
expected_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids]
|
||||
redis_mock.get.assert_has_calls(expected_calls)
|
||||
|
||||
def _assert_async_task_calls(self, mock_task, document_ids: list[str], task_type: str):
|
||||
"""Helper method to verify async task calls."""
|
||||
expected_calls = [call(doc_id) for doc_id in document_ids]
|
||||
if task_type in {"add", "remove"}:
|
||||
mock_task.delay.assert_has_calls(expected_calls)
|
||||
|
||||
# ==================== Enable Document Tests ====================
|
||||
|
||||
def test_batch_update_enable_documents_success(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test successful enabling of disabled documents."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create disabled documents
|
||||
disabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=False)
|
||||
mock_document_service_dependencies["get_document"].side_effect = disabled_docs
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Call the method to enable documents
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1", "doc-2"], action="enable", user=user
|
||||
)
|
||||
|
||||
# Verify document attributes were updated correctly
|
||||
for doc in disabled_docs:
|
||||
self._assert_document_enabled(doc, user.id, mock_document_service_dependencies["current_time"])
|
||||
|
||||
# Verify Redis cache operations
|
||||
self._assert_redis_cache_operations(["doc-1", "doc-2"], "get")
|
||||
self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex")
|
||||
|
||||
# Verify async tasks were triggered for indexing
|
||||
self._assert_async_task_calls(mock_async_task_dependencies["add_task"], ["doc-1", "doc-2"], "add")
|
||||
|
||||
# Verify database operations
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
assert mock_db.add.call_count == 2
|
||||
assert mock_db.commit.call_count == 1
|
||||
|
||||
def test_batch_update_enable_already_enabled_document_skipped(self, mock_document_service_dependencies):
|
||||
"""Test enabling documents that are already enabled."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create already enabled document
|
||||
enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True)
|
||||
mock_document_service_dependencies["get_document"].return_value = enabled_doc
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Attempt to enable already enabled document
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action="enable", user=user
|
||||
)
|
||||
|
||||
# Verify no database operations occurred (document was skipped)
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
mock_db.commit.assert_not_called()
|
||||
|
||||
# Verify no Redis setex operations occurred (document was skipped)
|
||||
redis_mock.setex.assert_not_called()
|
||||
|
||||
# ==================== Disable Document Tests ====================
|
||||
|
||||
def test_batch_update_disable_documents_success(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test successful disabling of enabled and completed documents."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create enabled documents
|
||||
enabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=True)
|
||||
mock_document_service_dependencies["get_document"].side_effect = enabled_docs
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Call the method to disable documents
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1", "doc-2"], action="disable", user=user
|
||||
)
|
||||
|
||||
# Verify document attributes were updated correctly
|
||||
for doc in enabled_docs:
|
||||
self._assert_document_disabled(doc, user.id, mock_document_service_dependencies["current_time"])
|
||||
|
||||
# Verify Redis cache operations for indexing prevention
|
||||
self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex")
|
||||
|
||||
# Verify async tasks were triggered to remove from index
|
||||
self._assert_async_task_calls(mock_async_task_dependencies["remove_task"], ["doc-1", "doc-2"], "remove")
|
||||
|
||||
# Verify database operations
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
assert mock_db.add.call_count == 2
|
||||
assert mock_db.commit.call_count == 1
|
||||
|
||||
def test_batch_update_disable_already_disabled_document_skipped(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test disabling documents that are already disabled."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create already disabled document
|
||||
disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False)
|
||||
mock_document_service_dependencies["get_document"].return_value = disabled_doc
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Attempt to disable already disabled document
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action="disable", user=user
|
||||
)
|
||||
|
||||
# Verify no database operations occurred (document was skipped)
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
mock_db.commit.assert_not_called()
|
||||
|
||||
# Verify no Redis setex operations occurred (document was skipped)
|
||||
redis_mock.setex.assert_not_called()
|
||||
|
||||
# Verify no async tasks were triggered (document was skipped)
|
||||
mock_async_task_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
def test_batch_update_disable_non_completed_document_error(self, mock_document_service_dependencies):
|
||||
"""Test that DocumentIndexingError is raised when trying to disable non-completed documents."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create a document that's not completed
|
||||
non_completed_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(
|
||||
enabled=True,
|
||||
indexing_status="indexing", # Not completed
|
||||
completed_at=None, # Not completed
|
||||
)
|
||||
mock_document_service_dependencies["get_document"].return_value = non_completed_doc
|
||||
|
||||
# Verify that DocumentIndexingError is raised
|
||||
with pytest.raises(DocumentIndexingError) as exc_info:
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action="disable", user=user
|
||||
)
|
||||
|
||||
# Verify error message indicates document is not completed
|
||||
assert "is not completed" in str(exc_info.value)
|
||||
|
||||
# ==================== Archive Document Tests ====================
|
||||
|
||||
def test_batch_update_archive_documents_success(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test successful archiving of unarchived documents."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create unarchived enabled document
|
||||
unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False)
|
||||
mock_document_service_dependencies["get_document"].return_value = unarchived_doc
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Call the method to archive documents
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action="archive", user=user
|
||||
)
|
||||
|
||||
# Verify document attributes were updated correctly
|
||||
self._assert_document_archived(unarchived_doc, user.id, mock_document_service_dependencies["current_time"])
|
||||
|
||||
# Verify Redis cache was set (because document was enabled)
|
||||
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
|
||||
|
||||
# Verify async task was triggered to remove from index (because enabled)
|
||||
mock_async_task_dependencies["remove_task"].delay.assert_called_once_with("doc-1")
|
||||
|
||||
# Verify database operations
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_batch_update_archive_already_archived_document_skipped(self, mock_document_service_dependencies):
|
||||
"""Test archiving documents that are already archived."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create already archived document
|
||||
archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True)
|
||||
mock_document_service_dependencies["get_document"].return_value = archived_doc
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Attempt to archive already archived document
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-3"], action="archive", user=user
|
||||
)
|
||||
|
||||
# Verify no database operations occurred (document was skipped)
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
mock_db.commit.assert_not_called()
|
||||
|
||||
# Verify no Redis setex operations occurred (document was skipped)
|
||||
redis_mock.setex.assert_not_called()
|
||||
|
||||
def test_batch_update_archive_disabled_document_no_index_removal(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test archiving disabled documents (should not trigger index removal)."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Set up disabled, unarchived document
|
||||
disabled_unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=False)
|
||||
mock_document_service_dependencies["get_document"].return_value = disabled_unarchived_doc
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Archive the disabled document
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action="archive", user=user
|
||||
)
|
||||
|
||||
# Verify document was archived
|
||||
self._assert_document_archived(
|
||||
disabled_unarchived_doc, user.id, mock_document_service_dependencies["current_time"]
|
||||
)
|
||||
|
||||
# Verify no Redis cache was set (document is disabled)
|
||||
redis_mock.setex.assert_not_called()
|
||||
|
||||
# Verify no index removal task was triggered (document is disabled)
|
||||
mock_async_task_dependencies["remove_task"].delay.assert_not_called()
|
||||
|
||||
# Verify database operations still occurred
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
# ==================== Unarchive Document Tests ====================
|
||||
|
||||
def test_batch_update_unarchive_documents_success(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test successful unarchiving of archived documents."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create mock archived document
|
||||
archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True)
|
||||
mock_document_service_dependencies["get_document"].return_value = archived_doc
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Call the method to unarchive documents
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user
|
||||
)
|
||||
|
||||
# Verify document attributes were updated correctly
|
||||
self._assert_document_unarchived(archived_doc)
|
||||
assert archived_doc.updated_at == mock_document_service_dependencies["current_time"]
|
||||
|
||||
# Verify Redis cache was set (because document is enabled)
|
||||
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
|
||||
|
||||
# Verify async task was triggered to add back to index (because enabled)
|
||||
mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1")
|
||||
|
||||
# Verify database operations
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_batch_update_unarchive_already_unarchived_document_skipped(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test unarchiving documents that are already unarchived."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create already unarchived document
|
||||
unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False)
|
||||
mock_document_service_dependencies["get_document"].return_value = unarchived_doc
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Attempt to unarchive already unarchived document
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user
|
||||
)
|
||||
|
||||
# Verify no database operations occurred (document was skipped)
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
mock_db.commit.assert_not_called()
|
||||
|
||||
# Verify no Redis setex operations occurred (document was skipped)
|
||||
redis_mock.setex.assert_not_called()
|
||||
|
||||
# Verify no async tasks were triggered (document was skipped)
|
||||
mock_async_task_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
def test_batch_update_unarchive_disabled_document_no_index_addition(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test unarchiving disabled documents (should not trigger index addition)."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create mock archived but disabled document
|
||||
archived_disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=True)
|
||||
mock_document_service_dependencies["get_document"].return_value = archived_disabled_doc
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Unarchive the disabled document
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user
|
||||
)
|
||||
|
||||
# Verify document was unarchived
|
||||
self._assert_document_unarchived(archived_disabled_doc)
|
||||
assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"]
|
||||
|
||||
# Verify no Redis cache was set (document is disabled)
|
||||
redis_mock.setex.assert_not_called()
|
||||
|
||||
# Verify no index addition task was triggered (document is disabled)
|
||||
mock_async_task_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
# Verify database operations still occurred
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
def test_batch_update_document_indexing_error_redis_cache_hit(self, mock_document_service_dependencies):
|
||||
"""Test that DocumentIndexingError is raised when documents are currently being indexed."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create mock enabled document
|
||||
enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True)
|
||||
mock_document_service_dependencies["get_document"].return_value = enabled_doc
|
||||
|
||||
# Set up mock to indicate document is being indexed
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = "indexing"
|
||||
|
||||
# Verify that DocumentIndexingError is raised
|
||||
with pytest.raises(DocumentIndexingError) as exc_info:
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action="enable", user=user
|
||||
)
|
||||
|
||||
# Verify error message contains document name
|
||||
assert "test_document.pdf" in str(exc_info.value)
|
||||
assert "is being indexed" in str(exc_info.value)
|
||||
|
||||
# Verify Redis cache was checked
|
||||
redis_mock.get.assert_called_once_with("document_doc-1_indexing")
|
||||
|
||||
def test_batch_update_invalid_action_error(self, mock_document_service_dependencies):
|
||||
"""Test that ValueError is raised when an invalid action is provided."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create mock document
|
||||
doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True)
|
||||
mock_document_service_dependencies["get_document"].return_value = doc
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Test with invalid action
|
||||
invalid_action = "invalid_action"
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user
|
||||
)
|
||||
|
||||
# Verify error message contains the invalid action
|
||||
assert invalid_action in str(exc_info.value)
|
||||
assert "Invalid action" in str(exc_info.value)
|
||||
|
||||
# Verify no Redis operations occurred
|
||||
redis_mock.setex.assert_not_called()
|
||||
|
||||
def test_batch_update_async_task_error_handling(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test handling of async task errors during batch operations."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create mock disabled document
|
||||
disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False)
|
||||
mock_document_service_dependencies["get_document"].return_value = disabled_doc
|
||||
|
||||
# Mock async task to raise an exception
|
||||
mock_async_task_dependencies["add_task"].delay.side_effect = Exception("Celery task error")
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Verify that async task error is propagated
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1"], action="enable", user=user
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Celery task error" in str(exc_info.value)
|
||||
|
||||
# Verify database operations completed successfully
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
# Verify Redis cache was set successfully
|
||||
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
|
||||
|
||||
# Verify document was updated
|
||||
self._assert_document_enabled(disabled_doc, user.id, mock_document_service_dependencies["current_time"])
|
||||
|
||||
# ==================== Edge Case Tests ====================
|
||||
|
||||
def test_batch_update_empty_document_list(self, mock_document_service_dependencies):
|
||||
"""Test batch operations with an empty document ID list."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Call method with empty document list
|
||||
result = DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=[], action="enable", user=user
|
||||
)
|
||||
|
||||
# Verify no document lookups were performed
|
||||
mock_document_service_dependencies["get_document"].assert_not_called()
|
||||
|
||||
# Verify method returns None (early return)
|
||||
assert result is None
|
||||
|
||||
def test_batch_update_document_not_found_skipped(self, mock_document_service_dependencies):
|
||||
"""Test behavior when some documents don't exist in the database."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock document service to return None (document not found)
|
||||
mock_document_service_dependencies["get_document"].return_value = None
|
||||
|
||||
# Call method with non-existent document ID
|
||||
# This should not raise an error, just skip the missing document
|
||||
try:
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["non-existent-doc"], action="enable", user=user
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Method should not raise exception for missing documents: {e}")
|
||||
|
||||
# Verify document lookup was attempted
|
||||
mock_document_service_dependencies["get_document"].assert_called_once_with(dataset.id, "non-existent-doc")
|
||||
|
||||
def test_batch_update_mixed_document_states_and_actions(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test batch operations on documents with mixed states and various scenarios."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create documents in various states
|
||||
disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False)
|
||||
enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-2", enabled=True)
|
||||
archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-3", enabled=True, archived=True)
|
||||
|
||||
# Mix of different document states
|
||||
documents = [disabled_doc, enabled_doc, archived_doc]
|
||||
mock_document_service_dependencies["get_document"].side_effect = documents
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Perform enable operation on mixed state documents
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=["doc-1", "doc-2", "doc-3"], action="enable", user=user
|
||||
)
|
||||
|
||||
# Verify only the disabled document was processed
|
||||
# (enabled and archived documents should be skipped for enable action)
|
||||
|
||||
# Only one add should occur (for the disabled document that was enabled)
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
mock_db.add.assert_called_once()
|
||||
# Only one commit should occur
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
# Only one Redis setex should occur (for the document that was enabled)
|
||||
redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1)
|
||||
|
||||
# Only one async task should be triggered (for the document that was enabled)
|
||||
mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1")
|
||||
|
||||
# ==================== Performance Tests ====================
|
||||
|
||||
def test_batch_update_large_document_list_performance(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test batch operations with a large number of documents."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create large list of document IDs
|
||||
document_ids = [f"doc-{i}" for i in range(1, 101)] # 100 documents
|
||||
|
||||
# Create mock documents
|
||||
mock_documents = DocumentBatchUpdateTestDataFactory.create_multiple_documents(
|
||||
document_ids,
|
||||
enabled=False, # All disabled, will be enabled
|
||||
)
|
||||
mock_document_service_dependencies["get_document"].side_effect = mock_documents
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Perform batch enable operation
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset, document_ids=document_ids, action="enable", user=user
|
||||
)
|
||||
|
||||
# Verify all documents were processed
|
||||
assert mock_document_service_dependencies["get_document"].call_count == 100
|
||||
|
||||
# Verify all documents were updated
|
||||
for mock_doc in mock_documents:
|
||||
self._assert_document_enabled(mock_doc, user.id, mock_document_service_dependencies["current_time"])
|
||||
|
||||
# Verify database operations
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
assert mock_db.add.call_count == 100
|
||||
assert mock_db.commit.call_count == 1
|
||||
|
||||
# Verify Redis cache operations occurred for each document
|
||||
assert redis_mock.setex.call_count == 100
|
||||
|
||||
# Verify async tasks were triggered for each document
|
||||
assert mock_async_task_dependencies["add_task"].delay.call_count == 100
|
||||
|
||||
# Verify correct Redis cache keys were set
|
||||
expected_redis_calls = [call(f"document_doc-{i}_indexing", 600, 1) for i in range(1, 101)]
|
||||
redis_mock.setex.assert_has_calls(expected_redis_calls)
|
||||
|
||||
# Verify correct async task calls
|
||||
expected_task_calls = [call(f"doc-{i}") for i in range(1, 101)]
|
||||
mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls)
|
||||
|
||||
def test_batch_update_mixed_document_states_complex_scenario(
|
||||
self, mock_document_service_dependencies, mock_async_task_dependencies
|
||||
):
|
||||
"""Test complex batch operations with documents in various states."""
|
||||
dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock()
|
||||
user = DocumentBatchUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Create documents in various states
|
||||
doc1 = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False) # Will be enabled
|
||||
doc2 = DocumentBatchUpdateTestDataFactory.create_document_mock(
|
||||
"doc-2", enabled=True
|
||||
) # Already enabled, will be skipped
|
||||
doc3 = DocumentBatchUpdateTestDataFactory.create_document_mock(
|
||||
"doc-3", enabled=True
|
||||
) # Already enabled, will be skipped
|
||||
doc4 = DocumentBatchUpdateTestDataFactory.create_document_mock(
|
||||
"doc-4", enabled=True
|
||||
) # Not affected by enable action
|
||||
doc5 = DocumentBatchUpdateTestDataFactory.create_document_mock(
|
||||
"doc-5", enabled=True, archived=True
|
||||
) # Not affected by enable action
|
||||
doc6 = None # Non-existent, will be skipped
|
||||
|
||||
mock_document_service_dependencies["get_document"].side_effect = [doc1, doc2, doc3, doc4, doc5, doc6]
|
||||
|
||||
# Reset module-level Redis mock
|
||||
redis_mock.reset_mock()
|
||||
redis_mock.get.return_value = None
|
||||
|
||||
# Perform mixed batch operations
|
||||
DocumentService.batch_update_document_status(
|
||||
dataset=dataset,
|
||||
document_ids=["doc-1", "doc-2", "doc-3", "doc-4", "doc-5", "doc-6"],
|
||||
action="enable", # This will only affect doc1
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Verify document 1 was enabled
|
||||
self._assert_document_enabled(doc1, user.id, mock_document_service_dependencies["current_time"])
|
||||
|
||||
# Verify other documents were skipped appropriately
|
||||
assert doc2.enabled == True # No change
|
||||
assert doc3.enabled == True # No change
|
||||
assert doc4.enabled == True # No change
|
||||
assert doc5.enabled == True # No change
|
||||
|
||||
# Verify database commits occurred for processed documents
|
||||
# Only doc1 should be added (others were skipped, doc6 doesn't exist)
|
||||
mock_db = mock_document_service_dependencies["db_session"]
|
||||
assert mock_db.add.call_count == 1
|
||||
assert mock_db.commit.call_count == 1
|
||||
|
||||
# Verify Redis cache operations occurred for processed documents
|
||||
# Only doc1 should have Redis operations
|
||||
assert redis_mock.setex.call_count == 1
|
||||
|
||||
# Verify async tasks were triggered for processed documents
|
||||
# Only doc1 should trigger tasks
|
||||
assert mock_async_task_dependencies["add_task"].delay.call_count == 1
|
||||
|
||||
# Verify correct Redis cache keys were set
|
||||
expected_redis_calls = [call("document_doc-1_indexing", 600, 1)]
|
||||
redis_mock.setex.assert_has_calls(expected_redis_calls)
|
||||
|
||||
# Verify correct async task calls
|
||||
expected_task_calls = [call("doc-1")]
|
||||
mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls)
|
||||
@@ -0,0 +1,819 @@
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService creation methods.
|
||||
|
||||
This test suite covers:
|
||||
- create_empty_dataset for internal datasets
|
||||
- create_empty_dataset for external datasets
|
||||
- create_empty_rag_pipeline_dataset
|
||||
- Error conditions and edge cases
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, Pipeline
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
IconInfo,
|
||||
RagPipelineDatasetCreateEntity,
|
||||
)
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
|
||||
|
||||
class DatasetCreateTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset creation tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_account_mock(
|
||||
account_id: str = "account-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock account."""
|
||||
account = create_autospec(Account, instance=True)
|
||||
account.id = account_id
|
||||
account.current_tenant_id = tenant_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(account, key, value)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
||||
"""Create a mock embedding model."""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
@staticmethod
|
||||
def create_retrieval_model_mock() -> Mock:
|
||||
"""Create a mock retrieval model."""
|
||||
retrieval_model = Mock(spec=RetrievalModel)
|
||||
retrieval_model.model_dump.return_value = {
|
||||
"search_method": "semantic_search",
|
||||
"top_k": 2,
|
||||
"score_threshold": 0.0,
|
||||
}
|
||||
retrieval_model.reranking_model = None
|
||||
return retrieval_model
|
||||
|
||||
@staticmethod
|
||||
def create_external_knowledge_api_mock(api_id: str = "api-123", **kwargs) -> Mock:
|
||||
"""Create a mock external knowledge API."""
|
||||
api = Mock()
|
||||
api.id = api_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(api, key, value)
|
||||
return api
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
name: str = "Test Dataset",
|
||||
tenant_id: str = "tenant-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset."""
|
||||
dataset = create_autospec(Dataset, instance=True)
|
||||
dataset.id = dataset_id
|
||||
dataset.name = name
|
||||
dataset.tenant_id = tenant_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_pipeline_mock(
|
||||
pipeline_id: str = "pipeline-123",
|
||||
name: str = "Test Pipeline",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock pipeline."""
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
pipeline.id = pipeline_id
|
||||
pipeline.name = name
|
||||
for key, value in kwargs.items():
|
||||
setattr(pipeline, key, value)
|
||||
return pipeline
|
||||
|
||||
|
||||
class TestDatasetServiceCreateEmptyDataset:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.create_empty_dataset method.
|
||||
|
||||
This test suite covers:
|
||||
- Internal dataset creation (vendor provider)
|
||||
- External dataset creation
|
||||
- High quality indexing technique with embedding models
|
||||
- Economy indexing technique
|
||||
- Retrieval model configuration
|
||||
- Error conditions (duplicate names, missing external knowledge IDs)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""Common mock setup for dataset service dependencies."""
|
||||
with (
|
||||
patch("services.dataset_service.db.session") as mock_db,
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager,
|
||||
patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding,
|
||||
patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking,
|
||||
patch("services.dataset_service.ExternalDatasetService") as mock_external_service,
|
||||
):
|
||||
yield {
|
||||
"db_session": mock_db,
|
||||
"model_manager": mock_model_manager,
|
||||
"check_embedding": mock_check_embedding,
|
||||
"check_reranking": mock_check_reranking,
|
||||
"external_service": mock_external_service,
|
||||
}
|
||||
|
||||
# ==================== Internal Dataset Creation Tests ====================
|
||||
|
||||
def test_create_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
|
||||
"""Test successful creation of basic internal dataset."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Test Dataset"
|
||||
description = "Test description"
|
||||
|
||||
# Mock database query to return None (no duplicate name)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock database session operations
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=description,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.name == name
|
||||
assert result.description == description
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.created_by == account.id
|
||||
assert result.updated_by == account.id
|
||||
assert result.provider == "vendor"
|
||||
assert result.permission == "only_me"
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_economy_indexing(self, mock_dataset_service_dependencies):
|
||||
"""Test successful creation of internal dataset with economy indexing."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Economy Dataset"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique="economy",
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.indexing_technique == "economy"
|
||||
assert result.embedding_model_provider is None
|
||||
assert result.embedding_model is None
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_high_quality_indexing_default_embedding(
|
||||
self, mock_dataset_service_dependencies
|
||||
):
|
||||
"""Test creation with high_quality indexing using default embedding model."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "High Quality Dataset"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock model manager
|
||||
embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock()
|
||||
mock_model_manager_instance = Mock()
|
||||
mock_model_manager_instance.get_default_model_instance.return_value = embedding_model
|
||||
mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_model.provider
|
||||
assert result.embedding_model == embedding_model.model
|
||||
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(
|
||||
self, mock_dataset_service_dependencies
|
||||
):
|
||||
"""Test creation with high_quality indexing using custom embedding model."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Custom Embedding Dataset"
|
||||
embedding_provider = "openai"
|
||||
embedding_model_name = "text-embedding-3-small"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock model manager
|
||||
embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock(
|
||||
model=embedding_model_name, provider=embedding_provider
|
||||
)
|
||||
mock_model_manager_instance = Mock()
|
||||
mock_model_manager_instance.get_model_instance.return_value = embedding_model
|
||||
mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
account=account,
|
||||
embedding_model_provider=embedding_provider,
|
||||
embedding_model_name=embedding_model_name,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_provider
|
||||
assert result.embedding_model == embedding_model_name
|
||||
mock_dataset_service_dependencies["check_embedding"].assert_called_once_with(
|
||||
tenant_id, embedding_provider, embedding_model_name
|
||||
)
|
||||
mock_model_manager_instance.get_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
provider=embedding_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=embedding_model_name,
|
||||
)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_retrieval_model(self, mock_dataset_service_dependencies):
|
||||
"""Test creation with retrieval model configuration."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Retrieval Model Dataset"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock retrieval model
|
||||
retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock()
|
||||
retrieval_model_dict = {"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.retrieval_model == retrieval_model_dict
|
||||
retrieval_model.model_dump.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_retrieval_model_reranking(self, mock_dataset_service_dependencies):
|
||||
"""Test creation with retrieval model that includes reranking."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Reranking Dataset"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock model manager
|
||||
embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock()
|
||||
mock_model_manager_instance = Mock()
|
||||
mock_model_manager_instance.get_default_model_instance.return_value = embedding_model
|
||||
mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance
|
||||
|
||||
# Mock retrieval model with reranking
|
||||
reranking_model = Mock()
|
||||
reranking_model.reranking_provider_name = "cohere"
|
||||
reranking_model.reranking_model_name = "rerank-english-v3.0"
|
||||
|
||||
retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock()
|
||||
retrieval_model.reranking_model = reranking_model
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_dataset_service_dependencies["check_reranking"].assert_called_once_with(
|
||||
tenant_id, "cohere", "rerank-english-v3.0"
|
||||
)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_custom_permission(self, mock_dataset_service_dependencies):
|
||||
"""Test creation with custom permission setting."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Custom Permission Dataset"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
permission="all_team_members",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.permission == "all_team_members"
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
# ==================== External Dataset Creation Tests ====================
|
||||
|
||||
def test_create_external_dataset_success(self, mock_dataset_service_dependencies):
|
||||
"""Test successful creation of external dataset."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "External Dataset"
|
||||
external_api_id = "external-api-123"
|
||||
external_knowledge_id = "external-knowledge-456"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock external knowledge API
|
||||
external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id)
|
||||
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
provider="external",
|
||||
external_knowledge_api_id=external_api_id,
|
||||
external_knowledge_id=external_knowledge_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.provider == "external"
|
||||
assert mock_db.add.call_count == 2 # Dataset + ExternalKnowledgeBindings
|
||||
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.assert_called_once_with(
|
||||
external_api_id
|
||||
)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when external knowledge API is not found."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "External Dataset"
|
||||
external_api_id = "non-existent-api"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock external knowledge API not found
|
||||
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = None
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="External API template not found"):
|
||||
DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
provider="external",
|
||||
external_knowledge_api_id=external_api_id,
|
||||
external_knowledge_id="knowledge-123",
|
||||
)
|
||||
|
||||
def test_create_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when external knowledge ID is missing."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "External Dataset"
|
||||
external_api_id = "external-api-123"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock external knowledge API
|
||||
external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id)
|
||||
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="external_knowledge_id is required"):
|
||||
DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
provider="external",
|
||||
external_knowledge_api_id=external_api_id,
|
||||
external_knowledge_id=None,
|
||||
)
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
def test_create_dataset_duplicate_name_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when dataset name already exists."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Duplicate Dataset"
|
||||
|
||||
# Mock database query to return existing dataset
|
||||
existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = existing_dataset
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"):
|
||||
DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
)
|
||||
|
||||
|
||||
class TestDatasetServiceCreateEmptyRagPipelineDataset:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.create_empty_rag_pipeline_dataset method.
|
||||
|
||||
This test suite covers:
|
||||
- RAG pipeline dataset creation with provided name
|
||||
- RAG pipeline dataset creation with auto-generated name
|
||||
- Pipeline creation
|
||||
- Error conditions (duplicate names, missing current user)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_pipeline_dependencies(self):
|
||||
"""Common mock setup for RAG pipeline dataset creation."""
|
||||
with (
|
||||
patch("services.dataset_service.db.session") as mock_db,
|
||||
patch("services.dataset_service.current_user") as mock_current_user,
|
||||
patch("services.dataset_service.generate_incremental_name") as mock_generate_name,
|
||||
):
|
||||
# Configure mock_current_user to behave like a Flask-Login proxy
|
||||
# Default: no user (falsy)
|
||||
mock_current_user.id = None
|
||||
yield {
|
||||
"db_session": mock_db,
|
||||
"current_user_mock": mock_current_user,
|
||||
"generate_name": mock_generate_name,
|
||||
}
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_name_success(self, mock_rag_pipeline_dependencies):
|
||||
"""Test successful creation of RAG pipeline dataset with provided name."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
name = "RAG Pipeline Dataset"
|
||||
description = "RAG Pipeline Description"
|
||||
|
||||
# Mock current user - set up the mock to have id attribute accessible directly
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
|
||||
|
||||
# Mock database query (no duplicate name)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock database operations
|
||||
mock_db = mock_rag_pipeline_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Create entity
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name=name,
|
||||
description=description,
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.name == name
|
||||
assert result.description == description
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.created_by == user_id
|
||||
assert result.provider == "vendor"
|
||||
assert result.runtime_mode == "rag_pipeline"
|
||||
assert result.permission == "only_me"
|
||||
assert mock_db.add.call_count == 2 # Pipeline + Dataset
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_auto_generated_name(self, mock_rag_pipeline_dependencies):
|
||||
"""Test creation of RAG pipeline dataset with auto-generated name."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
auto_name = "Untitled 1"
|
||||
|
||||
# Mock current user - set up the mock to have id attribute accessible directly
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
|
||||
|
||||
# Mock database query (empty name, need to generate)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock name generation
|
||||
mock_rag_pipeline_dependencies["generate_name"].return_value = auto_name
|
||||
|
||||
# Mock database operations
|
||||
mock_db = mock_rag_pipeline_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Create entity with empty name
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.name == auto_name
|
||||
mock_rag_pipeline_dependencies["generate_name"].assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_rag_pipeline_dataset_duplicate_name_error(self, mock_rag_pipeline_dependencies):
|
||||
"""Test error when RAG pipeline dataset name already exists."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
name = "Duplicate RAG Dataset"
|
||||
|
||||
# Mock current user - set up the mock to have id attribute accessible directly
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
|
||||
|
||||
# Mock database query to return existing dataset
|
||||
existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = existing_dataset
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Create entity
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name=name,
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"):
|
||||
DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
def test_create_rag_pipeline_dataset_missing_current_user_error(self, mock_rag_pipeline_dependencies):
|
||||
"""Test error when current user is not available."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Mock current user as None - set id to None so the check fails
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = None
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Create entity
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name="Test Dataset",
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Current user or current user id not found"):
|
||||
DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_custom_permission(self, mock_rag_pipeline_dependencies):
|
||||
"""Test creation with custom permission setting."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
name = "Custom Permission RAG Dataset"
|
||||
|
||||
# Mock current user - set up the mock to have id attribute accessible directly
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock database operations
|
||||
mock_db = mock_rag_pipeline_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Create entity
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name=name,
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="all_team",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.permission == "all_team"
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_icon_info(self, mock_rag_pipeline_dependencies):
|
||||
"""Test creation with icon info configuration."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
name = "Icon Info RAG Dataset"
|
||||
|
||||
# Mock current user - set up the mock to have id attribute accessible directly
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock database operations
|
||||
mock_db = mock_rag_pipeline_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Create entity with icon info
|
||||
icon_info = IconInfo(
|
||||
icon="📚",
|
||||
icon_background="#E8F5E9",
|
||||
icon_type="emoji",
|
||||
icon_url="https://example.com/icon.png",
|
||||
)
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name=name,
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.icon_info == icon_info.model_dump()
|
||||
mock_db.commit.assert_called_once()
|
||||
@@ -0,0 +1,216 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
||||
class DatasetDeleteTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset delete tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
tenant_id: str = "test-tenant-123",
|
||||
created_by: str = "creator-456",
|
||||
doc_form: str | None = None,
|
||||
indexing_technique: str | None = "high_quality",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.created_by = created_by
|
||||
dataset.doc_form = doc_form
|
||||
dataset.indexing_technique = indexing_technique
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_user_mock(
|
||||
user_id: str = "user-789",
|
||||
tenant_id: str = "test-tenant-123",
|
||||
role: TenantAccountRole = TenantAccountRole.ADMIN,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock user with specified attributes."""
|
||||
user = Mock(spec=Account)
|
||||
user.id = user_id
|
||||
user.current_tenant_id = tenant_id
|
||||
user.current_role = role
|
||||
for key, value in kwargs.items():
|
||||
setattr(user, key, value)
|
||||
return user
|
||||
|
||||
|
||||
class TestDatasetServiceDeleteDataset:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.delete_dataset method.
|
||||
|
||||
This test suite covers all deletion scenarios including:
|
||||
- Normal dataset deletion with documents
|
||||
- Empty dataset deletion (no documents, doc_form is None)
|
||||
- Dataset deletion with missing indexing_technique
|
||||
- Permission checks
|
||||
- Event handling
|
||||
|
||||
This test suite provides regression protection for issue #27073.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""Common mock setup for dataset service dependencies."""
|
||||
with (
|
||||
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
|
||||
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted,
|
||||
):
|
||||
yield {
|
||||
"get_dataset": mock_get_dataset,
|
||||
"check_permission": mock_check_perm,
|
||||
"db_session": mock_db,
|
||||
"dataset_was_deleted": mock_dataset_was_deleted,
|
||||
}
|
||||
|
||||
def test_delete_dataset_with_documents_success(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of a dataset with documents.
|
||||
|
||||
This test verifies:
|
||||
- Dataset is retrieved correctly
|
||||
- Permission check is performed
|
||||
- dataset_was_deleted event is sent
|
||||
- Dataset is deleted from database
|
||||
- Method returns True
|
||||
"""
|
||||
# Arrange
|
||||
dataset = DatasetDeleteTestDataFactory.create_dataset_mock(
|
||||
doc_form="text_model", indexing_technique="high_quality"
|
||||
)
|
||||
user = DatasetDeleteTestDataFactory.create_user_mock()
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
# Act
|
||||
result = DatasetService.delete_dataset(dataset.id, user)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
|
||||
mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
|
||||
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
|
||||
|
||||
def test_delete_empty_dataset_success(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of an empty dataset (no documents, doc_form is None).
|
||||
|
||||
This test verifies that:
|
||||
- Empty datasets can be deleted without errors
|
||||
- dataset_was_deleted event is sent (event handler will skip cleanup if doc_form is None)
|
||||
- Dataset is deleted from database
|
||||
- Method returns True
|
||||
|
||||
This is the primary test for issue #27073 where deleting an empty dataset
|
||||
caused internal server error due to assertion failure in event handlers.
|
||||
"""
|
||||
# Arrange
|
||||
dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique=None)
|
||||
user = DatasetDeleteTestDataFactory.create_user_mock()
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
# Act
|
||||
result = DatasetService.delete_dataset(dataset.id, user)
|
||||
|
||||
# Assert - Verify complete deletion flow
|
||||
assert result is True
|
||||
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
|
||||
mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
|
||||
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
|
||||
|
||||
def test_delete_dataset_with_partial_none_values(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test deletion of dataset with partial None values.
|
||||
|
||||
This test verifies that datasets with partial None values (e.g., doc_form exists
|
||||
but indexing_technique is None) can be deleted successfully. The event handler
|
||||
will skip cleanup if any required field is None.
|
||||
|
||||
Improvement based on Gemini Code Assist suggestion: Added comprehensive assertions
|
||||
to verify all core deletion operations are performed, not just event sending.
|
||||
"""
|
||||
# Arrange
|
||||
dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form="text_model", indexing_technique=None)
|
||||
user = DatasetDeleteTestDataFactory.create_user_mock()
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
# Act
|
||||
result = DatasetService.delete_dataset(dataset.id, user)
|
||||
|
||||
# Assert - Verify complete deletion flow (Gemini suggestion implemented)
|
||||
assert result is True
|
||||
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
|
||||
mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
|
||||
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
|
||||
|
||||
def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test deletion of dataset where doc_form is None but indexing_technique exists.
|
||||
|
||||
This edge case can occur in certain dataset configurations and should be handled
|
||||
gracefully by the event handler's conditional check.
|
||||
"""
|
||||
# Arrange
|
||||
dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique="high_quality")
|
||||
user = DatasetDeleteTestDataFactory.create_user_mock()
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
# Act
|
||||
result = DatasetService.delete_dataset(dataset.id, user)
|
||||
|
||||
# Assert - Verify complete deletion flow
|
||||
assert result is True
|
||||
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id)
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset)
|
||||
mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset)
|
||||
mock_dataset_service_dependencies["db_session"].commit.assert_called_once()
|
||||
|
||||
def test_delete_dataset_not_found(self, mock_dataset_service_dependencies):
|
||||
"""
|
||||
Test deletion attempt when dataset doesn't exist.
|
||||
|
||||
This test verifies that:
|
||||
- Method returns False when dataset is not found
|
||||
- No deletion operations are performed
|
||||
- No events are sent
|
||||
"""
|
||||
# Arrange
|
||||
dataset_id = "non-existent-dataset"
|
||||
user = DatasetDeleteTestDataFactory.create_user_mock()
|
||||
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = None
|
||||
|
||||
# Act
|
||||
result = DatasetService.delete_dataset(dataset_id, user)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id)
|
||||
mock_dataset_service_dependencies["check_permission"].assert_not_called()
|
||||
mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called()
|
||||
mock_dataset_service_dependencies["db_session"].delete.assert_not_called()
|
||||
mock_dataset_service_dependencies["db_session"].commit.assert_not_called()
|
||||
@@ -0,0 +1,746 @@
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService retrieval/list methods.
|
||||
|
||||
This test suite covers:
|
||||
- get_datasets - pagination, search, filtering, permissions
|
||||
- get_dataset - single dataset retrieval
|
||||
- get_datasets_by_ids - bulk retrieval
|
||||
- get_process_rules - dataset processing rules
|
||||
- get_dataset_queries - dataset query history
|
||||
- get_related_apps - apps using the dataset
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
Dataset,
|
||||
DatasetPermission,
|
||||
DatasetPermissionEnum,
|
||||
DatasetProcessRule,
|
||||
DatasetQuery,
|
||||
)
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
|
||||
|
||||
class DatasetRetrievalTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset retrieval tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
name: str = "Test Dataset",
|
||||
tenant_id: str = "tenant-123",
|
||||
created_by: str = "user-123",
|
||||
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.name = name
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.created_by = created_by
|
||||
dataset.permission = permission
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_account_mock(
|
||||
account_id: str = "account-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
role: TenantAccountRole = TenantAccountRole.NORMAL,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock account."""
|
||||
account = create_autospec(Account, instance=True)
|
||||
account.id = account_id
|
||||
account.current_tenant_id = tenant_id
|
||||
account.current_role = role
|
||||
for key, value in kwargs.items():
|
||||
setattr(account, key, value)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_permission_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
account_id: str = "account-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset permission."""
|
||||
permission = Mock(spec=DatasetPermission)
|
||||
permission.dataset_id = dataset_id
|
||||
permission.account_id = account_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(permission, key, value)
|
||||
return permission
|
||||
|
||||
@staticmethod
|
||||
def create_process_rule_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
mode: str = "automatic",
|
||||
rules: dict | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset process rule."""
|
||||
process_rule = Mock(spec=DatasetProcessRule)
|
||||
process_rule.dataset_id = dataset_id
|
||||
process_rule.mode = mode
|
||||
process_rule.rules_dict = rules or {}
|
||||
for key, value in kwargs.items():
|
||||
setattr(process_rule, key, value)
|
||||
return process_rule
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_query_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
query_id: str = "query-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset query."""
|
||||
dataset_query = Mock(spec=DatasetQuery)
|
||||
dataset_query.id = query_id
|
||||
dataset_query.dataset_id = dataset_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset_query, key, value)
|
||||
return dataset_query
|
||||
|
||||
@staticmethod
|
||||
def create_app_dataset_join_mock(
|
||||
app_id: str = "app-123",
|
||||
dataset_id: str = "dataset-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock app-dataset join."""
|
||||
join = Mock(spec=AppDatasetJoin)
|
||||
join.app_id = app_id
|
||||
join.dataset_id = dataset_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(join, key, value)
|
||||
return join
|
||||
|
||||
|
||||
class TestDatasetServiceGetDatasets:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.get_datasets method.
|
||||
|
||||
This test suite covers:
|
||||
- Pagination
|
||||
- Search functionality
|
||||
- Tag filtering
|
||||
- Permission-based filtering (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
|
||||
- Role-based filtering (OWNER, DATASET_OPERATOR, NORMAL)
|
||||
- include_all flag
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_datasets tests."""
|
||||
with (
|
||||
patch("services.dataset_service.db.session") as mock_db,
|
||||
patch("services.dataset_service.db.paginate") as mock_paginate,
|
||||
patch("services.dataset_service.TagService") as mock_tag_service,
|
||||
):
|
||||
yield {
|
||||
"db_session": mock_db,
|
||||
"paginate": mock_paginate,
|
||||
"tag_service": mock_tag_service,
|
||||
}
|
||||
|
||||
# ==================== Basic Retrieval Tests ====================
|
||||
|
||||
def test_get_datasets_basic_pagination(self, mock_dependencies):
|
||||
"""Test basic pagination without user or filters."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id=f"dataset-{i}", name=f"Dataset {i}", tenant_id=tenant_id
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
mock_paginate_result.total = 5
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 5
|
||||
assert total == 5
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_with_search(self, mock_dependencies):
|
||||
"""Test get_datasets with search keyword."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
search = "test"
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1", name="Test Dataset", tenant_id=tenant_id
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, search=search)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_with_tag_filtering(self, mock_dependencies):
|
||||
"""Test get_datasets with tag_ids filtering."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
tag_ids = ["tag-1", "tag-2"]
|
||||
|
||||
# Mock tag service
|
||||
target_ids = ["dataset-1", "dataset-2"]
|
||||
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.return_value = target_ids
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
for dataset_id in target_ids
|
||||
]
|
||||
mock_paginate_result.total = 2
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 2
|
||||
assert total == 2
|
||||
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_called_once_with(
|
||||
"knowledge", tenant_id, tag_ids
|
||||
)
|
||||
|
||||
def test_get_datasets_with_empty_tag_ids(self, mock_dependencies):
|
||||
"""Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
tag_ids = []
|
||||
|
||||
# Mock pagination result - when tag_ids is empty, tag filtering is skipped
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id)
|
||||
for i in range(3)
|
||||
]
|
||||
mock_paginate_result.total = 3
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids)
|
||||
|
||||
# Assert
|
||||
# When tag_ids is empty, tag filtering is skipped, so normal query results are returned
|
||||
assert len(datasets) == 3
|
||||
assert total == 3
|
||||
# Tag service should not be called when tag_ids is empty
|
||||
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_not_called()
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
# ==================== Permission-Based Filtering Tests ====================
|
||||
|
||||
def test_get_datasets_without_user_shows_only_all_team(self, mock_dependencies):
|
||||
"""Test that without user, only ALL_TEAM datasets are shown."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1",
|
||||
tenant_id=tenant_id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, user=None)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_owner_with_include_all(self, mock_dependencies):
|
||||
"""Test that OWNER with include_all=True sees all datasets."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id="owner-123", tenant_id=tenant_id, role=TenantAccountRole.OWNER
|
||||
)
|
||||
|
||||
# Mock dataset permissions query (empty - owner doesn't need explicit permissions)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id)
|
||||
for i in range(3)
|
||||
]
|
||||
mock_paginate_result.total = 3
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, tenant_id=tenant_id, user=user, include_all=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 3
|
||||
assert total == 3
|
||||
|
||||
def test_get_datasets_normal_user_only_me_permission(self, mock_dependencies):
|
||||
"""Test that normal user sees ONLY_ME datasets they created."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "user-123"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock dataset permissions query (no explicit permissions)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1",
|
||||
tenant_id=tenant_id,
|
||||
created_by=user_id,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_normal_user_all_team_permission(self, mock_dependencies):
|
||||
"""Test that normal user sees ALL_TEAM datasets."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id="user-123", tenant_id=tenant_id, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock dataset permissions query (no explicit permissions)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1",
|
||||
tenant_id=tenant_id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_normal_user_partial_team_with_permission(self, mock_dependencies):
|
||||
"""Test that normal user sees PARTIAL_TEAM datasets they have permission for."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "user-123"
|
||||
dataset_id = "dataset-1"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock dataset permissions query - user has permission
|
||||
permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock(
|
||||
dataset_id=dataset_id, account_id=user_id
|
||||
)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = [permission]
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=tenant_id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_dataset_operator_with_permissions(self, mock_dependencies):
|
||||
"""Test that DATASET_OPERATOR only sees datasets they have explicit permission for."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "operator-123"
|
||||
dataset_id = "dataset-1"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR
|
||||
)
|
||||
|
||||
# Mock dataset permissions query - operator has permission
|
||||
permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock(
|
||||
dataset_id=dataset_id, account_id=user_id
|
||||
)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = [permission]
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_dataset_operator_without_permissions(self, mock_dependencies):
|
||||
"""Test that DATASET_OPERATOR without permissions returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "operator-123"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR
|
||||
)
|
||||
|
||||
# Mock dataset permissions query - no permissions
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
class TestDatasetServiceGetDataset:
|
||||
"""Comprehensive unit tests for DatasetService.get_dataset method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_dataset tests."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield {"db_session": mock_db}
|
||||
|
||||
def test_get_dataset_success(self, mock_dependencies):
|
||||
"""Test successful retrieval of a single dataset."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = dataset
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.id == dataset_id
|
||||
mock_query.filter_by.assert_called_once_with(id=dataset_id)
|
||||
|
||||
def test_get_dataset_not_found(self, mock_dependencies):
|
||||
"""Test retrieval when dataset doesn't exist."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock database query returning None
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestDatasetServiceGetDatasetsByIds:
|
||||
"""Comprehensive unit tests for DatasetService.get_datasets_by_ids method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_datasets_by_ids tests."""
|
||||
with patch("services.dataset_service.db.paginate") as mock_paginate:
|
||||
yield {"paginate": mock_paginate}
|
||||
|
||||
def test_get_datasets_by_ids_success(self, mock_dependencies):
|
||||
"""Test successful bulk retrieval of datasets by IDs."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
dataset_ids = [str(uuid4()), str(uuid4()), str(uuid4())]
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
for dataset_id in dataset_ids
|
||||
]
|
||||
mock_paginate_result.total = len(dataset_ids)
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 3
|
||||
assert total == 3
|
||||
assert all(dataset.id in dataset_ids for dataset in datasets)
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_by_ids_empty_list(self, mock_dependencies):
|
||||
"""Test get_datasets_by_ids with empty list returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
dataset_ids = []
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
mock_dependencies["paginate"].assert_not_called()
|
||||
|
||||
def test_get_datasets_by_ids_none_list(self, mock_dependencies):
|
||||
"""Test get_datasets_by_ids with None returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets_by_ids(None, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
mock_dependencies["paginate"].assert_not_called()
|
||||
|
||||
|
||||
class TestDatasetServiceGetProcessRules:
|
||||
"""Comprehensive unit tests for DatasetService.get_process_rules method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_process_rules tests."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield {"db_session": mock_db}
|
||||
|
||||
def test_get_process_rules_with_existing_rule(self, mock_dependencies):
|
||||
"""Test retrieval of process rules when rule exists."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
rules_data = {
|
||||
"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
|
||||
"segmentation": {"delimiter": "\n", "max_tokens": 500},
|
||||
}
|
||||
process_rule = DatasetRetrievalTestDataFactory.create_process_rule_mock(
|
||||
dataset_id=dataset_id, mode="custom", rules=rules_data
|
||||
)
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value = process_rule
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_process_rules(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result["mode"] == "custom"
|
||||
assert result["rules"] == rules_data
|
||||
|
||||
def test_get_process_rules_without_existing_rule(self, mock_dependencies):
|
||||
"""Test retrieval of process rules when no rule exists (returns defaults)."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock database query returning None
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value = None
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_process_rules(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result["mode"] == DocumentService.DEFAULT_RULES["mode"]
|
||||
assert "rules" in result
|
||||
assert result["rules"] == DocumentService.DEFAULT_RULES["rules"]
|
||||
|
||||
|
||||
class TestDatasetServiceGetDatasetQueries:
|
||||
"""Comprehensive unit tests for DatasetService.get_dataset_queries method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_dataset_queries tests."""
|
||||
with patch("services.dataset_service.db.paginate") as mock_paginate:
|
||||
yield {"paginate": mock_paginate}
|
||||
|
||||
def test_get_dataset_queries_success(self, mock_dependencies):
|
||||
"""Test successful retrieval of dataset queries."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_query_mock(dataset_id=dataset_id, query_id=f"query-{i}")
|
||||
for i in range(3)
|
||||
]
|
||||
mock_paginate_result.total = 3
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
queries, total = DatasetService.get_dataset_queries(dataset_id, page, per_page)
|
||||
|
||||
# Assert
|
||||
assert len(queries) == 3
|
||||
assert total == 3
|
||||
assert all(query.dataset_id == dataset_id for query in queries)
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_dataset_queries_empty_result(self, mock_dependencies):
|
||||
"""Test retrieval when no queries exist."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result (empty)
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = []
|
||||
mock_paginate_result.total = 0
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
queries, total = DatasetService.get_dataset_queries(dataset_id, page, per_page)
|
||||
|
||||
# Assert
|
||||
assert queries == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
class TestDatasetServiceGetRelatedApps:
|
||||
"""Comprehensive unit tests for DatasetService.get_related_apps method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_related_apps tests."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield {"db_session": mock_db}
|
||||
|
||||
def test_get_related_apps_success(self, mock_dependencies):
|
||||
"""Test successful retrieval of related apps."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock app-dataset joins
|
||||
app_joins = [
|
||||
DatasetRetrievalTestDataFactory.create_app_dataset_join_mock(app_id=f"app-{i}", dataset_id=dataset_id)
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.all.return_value = app_joins
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_related_apps(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert all(join.dataset_id == dataset_id for join in result)
|
||||
mock_query.where.assert_called_once()
|
||||
mock_query.where.return_value.order_by.assert_called_once()
|
||||
|
||||
def test_get_related_apps_empty_result(self, mock_dependencies):
|
||||
"""Test retrieval when no related apps exist."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock database query returning empty list
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_related_apps(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
@@ -0,0 +1,652 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
# Mock redis_client before importing dataset_service
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, ExternalKnowledgeBindings
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
||||
class DatasetUpdateTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset update tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
provider: str = "vendor",
|
||||
name: str = "old_name",
|
||||
description: str = "old_description",
|
||||
indexing_technique: str = "high_quality",
|
||||
retrieval_model: str = "old_model",
|
||||
embedding_model_provider: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
collection_binding_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.provider = provider
|
||||
dataset.name = name
|
||||
dataset.description = description
|
||||
dataset.indexing_technique = indexing_technique
|
||||
dataset.retrieval_model = retrieval_model
|
||||
dataset.embedding_model_provider = embedding_model_provider
|
||||
dataset.embedding_model = embedding_model
|
||||
dataset.collection_binding_id = collection_binding_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_user_mock(user_id: str = "user-789") -> Mock:
|
||||
"""Create a mock user."""
|
||||
user = Mock()
|
||||
user.id = user_id
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_external_binding_mock(
|
||||
external_knowledge_id: str = "old_knowledge_id", external_knowledge_api_id: str = "old_api_id"
|
||||
) -> Mock:
|
||||
"""Create a mock external knowledge binding."""
|
||||
binding = Mock(spec=ExternalKnowledgeBindings)
|
||||
binding.external_knowledge_id = external_knowledge_id
|
||||
binding.external_knowledge_api_id = external_knowledge_api_id
|
||||
return binding
|
||||
|
||||
@staticmethod
|
||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
||||
"""Create a mock embedding model."""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
@staticmethod
|
||||
def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock:
|
||||
"""Create a mock collection binding."""
|
||||
binding = Mock()
|
||||
binding.id = binding_id
|
||||
return binding
|
||||
|
||||
@staticmethod
|
||||
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
|
||||
"""Create a mock current user."""
|
||||
current_user = create_autospec(Account, instance=True)
|
||||
current_user.current_tenant_id = tenant_id
|
||||
return current_user
|
||||
|
||||
|
||||
class TestDatasetServiceUpdateDataset:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.update_dataset method.
|
||||
|
||||
This test suite covers all supported scenarios including:
|
||||
- External dataset updates
|
||||
- Internal dataset updates with different indexing techniques
|
||||
- Embedding model updates
|
||||
- Permission checks
|
||||
- Error conditions and edge cases
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""Common mock setup for dataset service dependencies."""
|
||||
with (
|
||||
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
|
||||
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
||||
patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
|
||||
):
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
mock_naive_utc_now.return_value = current_time
|
||||
|
||||
yield {
|
||||
"get_dataset": mock_get_dataset,
|
||||
"check_permission": mock_check_perm,
|
||||
"db_session": mock_db,
|
||||
"naive_utc_now": mock_naive_utc_now,
|
||||
"current_time": current_time,
|
||||
"has_dataset_same_name": has_dataset_same_name,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_provider_dependencies(self):
|
||||
"""Mock setup for external provider tests."""
|
||||
with patch("services.dataset_service.Session") as mock_session:
|
||||
from extensions.ext_database import db
|
||||
|
||||
with patch.object(db.__class__, "engine", new_callable=Mock):
|
||||
session_mock = Mock()
|
||||
mock_session.return_value.__enter__.return_value = session_mock
|
||||
yield session_mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_internal_provider_dependencies(self):
|
||||
"""Mock setup for internal provider tests."""
|
||||
with (
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager,
|
||||
patch(
|
||||
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
|
||||
) as mock_get_binding,
|
||||
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
|
||||
patch(
|
||||
"services.dataset_service.current_user", create_autospec(Account, instance=True)
|
||||
) as mock_current_user,
|
||||
):
|
||||
mock_current_user.current_tenant_id = "tenant-123"
|
||||
yield {
|
||||
"model_manager": mock_model_manager,
|
||||
"get_binding": mock_get_binding,
|
||||
"task": mock_task,
|
||||
"current_user": mock_current_user,
|
||||
}
|
||||
|
||||
def _assert_database_update_called(self, mock_db, dataset_id: str, expected_updates: dict[str, Any]):
|
||||
"""Helper method to verify database update calls."""
|
||||
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_updates)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def _assert_external_dataset_update(self, mock_dataset, mock_binding, update_data: dict[str, Any]):
|
||||
"""Helper method to verify external dataset updates."""
|
||||
assert mock_dataset.name == update_data.get("name", mock_dataset.name)
|
||||
assert mock_dataset.description == update_data.get("description", mock_dataset.description)
|
||||
assert mock_dataset.retrieval_model == update_data.get("external_retrieval_model", mock_dataset.retrieval_model)
|
||||
|
||||
if "external_knowledge_id" in update_data:
|
||||
assert mock_binding.external_knowledge_id == update_data["external_knowledge_id"]
|
||||
if "external_knowledge_api_id" in update_data:
|
||||
assert mock_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
|
||||
|
||||
# ==================== External Dataset Tests ====================
|
||||
|
||||
def test_update_external_dataset_success(
|
||||
self, mock_dataset_service_dependencies, mock_external_provider_dependencies
|
||||
):
|
||||
"""Test successful update of external dataset."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="external", name="old_name", description="old_description", retrieval_model="old_model"
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
binding = DatasetUpdateTestDataFactory.create_external_binding_mock()
|
||||
|
||||
# Mock external knowledge binding query
|
||||
mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = binding
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"external_retrieval_model": "new_model",
|
||||
"permission": "only_me",
|
||||
"external_knowledge_id": "new_knowledge_id",
|
||||
"external_knowledge_api_id": "new_api_id",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
|
||||
# Verify dataset and binding updates
|
||||
self._assert_external_dataset_update(dataset, binding, update_data)
|
||||
|
||||
# Verify database operations
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add.assert_any_call(dataset)
|
||||
mock_db.add.assert_any_call(binding)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when external knowledge id is missing."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "External knowledge id is required" in str(context.value)
|
||||
|
||||
def test_update_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when external knowledge api id is missing."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "External knowledge api id is required" in str(context.value)
|
||||
|
||||
def test_update_external_dataset_binding_not_found_error(
|
||||
self, mock_dataset_service_dependencies, mock_external_provider_dependencies
|
||||
):
|
||||
"""Test error when external knowledge binding is not found."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock external knowledge binding query returning None
|
||||
mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"external_knowledge_id": "knowledge_id",
|
||||
"external_knowledge_api_id": "api_id",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "External knowledge binding not found" in str(context.value)
|
||||
|
||||
# ==================== Internal Dataset Basic Tests ====================
|
||||
|
||||
def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
|
||||
"""Test successful update of internal dataset with basic fields."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id="binding-123",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify permission check was called
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
|
||||
# Verify database update was called with correct filtered data
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_filter_none_values(self, mock_dataset_service_dependencies):
|
||||
"""Test that None values are filtered out except for description field."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": None, # Should be included
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": None, # Should be filtered out
|
||||
"embedding_model": None, # Should be filtered out
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with filtered data
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"description": None, # Description should be included even if None
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
actual_call_args = mock_dataset_service_dependencies[
|
||||
"db_session"
|
||||
].query.return_value.filter_by.return_value.update.call_args[0][0]
|
||||
# Remove timestamp for comparison as it's dynamic
|
||||
del actual_call_args["updated_at"]
|
||||
del expected_filtered_data["updated_at"]
|
||||
|
||||
assert actual_call_args == expected_filtered_data
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
# ==================== Indexing Technique Switch Tests ====================
|
||||
|
||||
def test_update_internal_dataset_indexing_technique_to_economy(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test updating internal dataset indexing technique to economy."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with embedding model fields cleared
|
||||
expected_filtered_data = {
|
||||
"indexing_technique": "economy",
|
||||
"embedding_model": None,
|
||||
"embedding_model_provider": None,
|
||||
"collection_binding_id": None,
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_indexing_technique_to_high_quality(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test updating internal dataset indexing technique to high_quality."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock embedding model
|
||||
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock()
|
||||
mock_internal_provider_dependencies[
|
||||
"model_manager"
|
||||
].return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
# Mock collection binding
|
||||
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock()
|
||||
mock_internal_provider_dependencies["get_binding"].return_value = binding
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify embedding model was validated
|
||||
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
|
||||
provider="openai",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
|
||||
# Verify collection binding was retrieved
|
||||
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-ada-002")
|
||||
|
||||
# Verify database update was called with correct data
|
||||
expected_filtered_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"embedding_model_provider": "openai",
|
||||
"collection_binding_id": "binding-456",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify vector index task was triggered
|
||||
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "add")
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
# ==================== Embedding Model Update Tests ====================
|
||||
|
||||
def test_update_internal_dataset_keep_existing_embedding_model(self, mock_dataset_service_dependencies):
|
||||
"""Test updating internal dataset without changing embedding model."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id="binding-123",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with existing embedding model preserved
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"collection_binding_id": "binding-123",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_embedding_model_update(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test updating internal dataset with new embedding model."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock embedding model
|
||||
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock("text-embedding-3-small")
|
||||
mock_internal_provider_dependencies[
|
||||
"model_manager"
|
||||
].return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
# Mock collection binding
|
||||
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock("binding-789")
|
||||
mock_internal_provider_dependencies["get_binding"].return_value = binding
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify embedding model was validated
|
||||
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
|
||||
provider="openai",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model="text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Verify collection binding was retrieved
|
||||
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-3-small")
|
||||
|
||||
# Verify database update was called with correct data
|
||||
expected_filtered_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"embedding_model_provider": "openai",
|
||||
"collection_binding_id": "binding-789",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify vector index task was triggered
|
||||
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update")
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_no_indexing_technique_change(self, mock_dataset_service_dependencies):
|
||||
"""Test updating internal dataset without changing indexing technique."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id="binding-123",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"indexing_technique": "high_quality", # Same as current
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with correct data
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"collection_binding_id": "binding-123",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when dataset is not found."""
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = None
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "Dataset not found" in str(context.value)
|
||||
|
||||
def test_update_dataset_permission_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when user doesn't have permission."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock()
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission")
|
||||
|
||||
update_data = {"name": "new_name"}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
def test_update_internal_dataset_embedding_model_error(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test error when embedding model is not available."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock model manager to raise error
|
||||
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.side_effect = Exception(
|
||||
"No Embedding Model available"
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "invalid_provider",
|
||||
"embedding_model": "invalid_model",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(Exception) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "No Embedding Model available".lower() in str(context.value).lower()
|
||||
@@ -0,0 +1,317 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
|
||||
|
||||
class DocumentIndexingTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
features.billing.enabled = billing_enabled
|
||||
features.billing.subscription = Mock()
|
||||
features.billing.subscription.plan = plan
|
||||
return features
|
||||
|
||||
@staticmethod
|
||||
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
|
||||
"""Create mock TenantIsolatedTaskQueue."""
|
||||
queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||
queue.get_task_key.return_value = "task_key" if has_task_key else None
|
||||
queue.push_tasks = Mock()
|
||||
queue.set_task_waiting_time = Mock()
|
||||
return queue
|
||||
|
||||
@staticmethod
|
||||
def create_document_task_proxy(
|
||||
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
|
||||
) -> DocumentIndexingTaskProxy:
|
||||
"""Create DocumentIndexingTaskProxy instance for testing."""
|
||||
if document_ids is None:
|
||||
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||
return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
|
||||
class TestDocumentIndexingTaskProxy:
|
||||
"""Test cases for DocumentIndexingTaskProxy class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test DocumentIndexingTaskProxy initialization."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1", "doc-2", "doc-3"]
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
|
||||
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
|
||||
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features()
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
|
||||
# Act
|
||||
features1 = proxy.features
|
||||
features2 = proxy.features # Second call should use cached property
|
||||
|
||||
# Assert
|
||||
assert features1 == mock_features
|
||||
assert features2 == mock_features
|
||||
assert features1 is features2 # Should be the same instance due to caching
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
mock_task.delay.assert_called_once_with(
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=True
|
||||
)
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
|
||||
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
|
||||
assert len(pushed_tasks) == 1
|
||||
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
|
||||
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
|
||||
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
|
||||
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=False
|
||||
)
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(mock_task)
|
||||
|
||||
# Assert
|
||||
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
|
||||
mock_task.delay.assert_called_once_with(
|
||||
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
|
||||
)
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_default_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_tenant_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_direct_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_direct_queue()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.TEAM
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing enabled with non sandbox plan, should send to priority tenant queue
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_with_billing_disabled(self, mock_feature_service):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_direct_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
||||
proxy._send_to_priority_direct_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_delay_method(self, mock_feature_service):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
# If billing enabled with sandbox plan, should send to default tenant queue
|
||||
proxy._send_to_default_tenant_queue.assert_called_once()
|
||||
|
||||
def test_document_task_dataclass(self):
|
||||
"""Test DocumentTask dataclass."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1", "doc-2"]
|
||||
|
||||
# Act
|
||||
task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
|
||||
|
||||
# Assert
|
||||
assert task.tenant_id == tenant_id
|
||||
assert task.dataset_id == dataset_id
|
||||
assert task.document_ids == document_ids
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
@patch("services.document_indexing_task_proxy.FeatureService")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once()
|
||||
|
||||
def test_initialization_with_empty_document_ids(self):
|
||||
"""Test initialization with empty document_ids list."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = []
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
|
||||
def test_initialization_with_single_document_id(self):
|
||||
"""Test initialization with single document_id."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
dataset_id = "dataset-456"
|
||||
document_ids = ["doc-1"]
|
||||
|
||||
# Act
|
||||
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == tenant_id
|
||||
assert proxy._dataset_id == dataset_id
|
||||
assert proxy._document_ids == document_ids
|
||||
@@ -0,0 +1,33 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from models.dataset import Document
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
def test_normalize_display_status_alias_mapping():
|
||||
assert DocumentService.normalize_display_status("ACTIVE") == "available"
|
||||
assert DocumentService.normalize_display_status("enabled") == "available"
|
||||
assert DocumentService.normalize_display_status("archived") == "archived"
|
||||
assert DocumentService.normalize_display_status("unknown") is None
|
||||
|
||||
|
||||
def test_build_display_status_filters_available():
|
||||
filters = DocumentService.build_display_status_filters("available")
|
||||
assert len(filters) == 3
|
||||
for condition in filters:
|
||||
assert condition is not None
|
||||
|
||||
|
||||
def test_apply_display_status_filter_applies_when_status_present():
|
||||
query = sa.select(Document)
|
||||
filtered = DocumentService.apply_display_status_filter(query, "queuing")
|
||||
compiled = str(filtered.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "WHERE" in compiled
|
||||
assert "documents.indexing_status = 'waiting'" in compiled
|
||||
|
||||
|
||||
def test_apply_display_status_filter_returns_same_when_invalid():
|
||||
query = sa.select(Document)
|
||||
filtered = DocumentService.apply_display_status_filter(query, "invalid")
|
||||
compiled = str(filtered.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "WHERE" not in compiled
|
||||
203
dify/api/tests/unit_tests/services/test_metadata_bug_complete.py
Normal file
203
dify/api/tests/unit_tests/services/test_metadata_bug_complete.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from models.account import Account
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class TestMetadataBugCompleteValidation:
|
||||
"""Complete test suite to verify the metadata nullable bug and its fix."""
|
||||
|
||||
def test_1_pydantic_layer_validation(self):
|
||||
"""Test Layer 1: Pydantic model validation correctly rejects None values."""
|
||||
# Pydantic should reject None values for required fields
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(type=None, name=None)
|
||||
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(type="string", name=None)
|
||||
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(type=None, name="test")
|
||||
|
||||
# Valid values should work
|
||||
valid_args = MetadataArgs(type="string", name="test_name")
|
||||
assert valid_args.type == "string"
|
||||
assert valid_args.name == "test_name"
|
||||
|
||||
def test_2_business_logic_layer_crashes_on_none(self):
|
||||
"""Test Layer 2: Business logic crashes when None values slip through."""
|
||||
# Create mock that bypasses Pydantic validation
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None
|
||||
mock_metadata_args.type = "string"
|
||||
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
# Should crash with TypeError
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
# Test update method as well
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
|
||||
|
||||
def test_3_database_constraints_verification(self):
|
||||
"""Test Layer 3: Verify database model has nullable=False constraints."""
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from models.dataset import DatasetMetadata
|
||||
|
||||
# Get table info
|
||||
mapper = inspect(DatasetMetadata)
|
||||
|
||||
# Check that type and name columns are not nullable
|
||||
type_column = mapper.columns["type"]
|
||||
name_column = mapper.columns["name"]
|
||||
|
||||
assert type_column.nullable is False, "type column should be nullable=False"
|
||||
assert name_column.nullable is False, "name column should be nullable=False"
|
||||
|
||||
def test_4_fixed_api_layer_rejects_null(self, app):
|
||||
"""Test Layer 4: Fixed API configuration properly rejects null values."""
|
||||
# Test Console API create endpoint (fixed)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
|
||||
# Test with just name being null
|
||||
with app.test_request_context(json={"type": "string", "name": None}, content_type="application/json"):
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
|
||||
# Test with just type being null
|
||||
with app.test_request_context(json={"type": None, "name": "test"}, content_type="application/json"):
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
|
||||
def test_5_fixed_api_accepts_valid_values(self, app):
|
||||
"""Test that fixed API still accepts valid non-null values."""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"):
|
||||
args = parser.parse_args()
|
||||
assert args["type"] == "string"
|
||||
assert args["name"] == "valid_name"
|
||||
|
||||
def test_6_simulated_buggy_behavior(self, app):
|
||||
"""Test simulating the original buggy behavior with nullable=True."""
|
||||
# Simulate the old buggy configuration
|
||||
buggy_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
)
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
# This would pass in the buggy version
|
||||
args = buggy_parser.parse_args()
|
||||
assert args["type"] is None
|
||||
assert args["name"] is None
|
||||
|
||||
# But would crash when trying to create MetadataArgs
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs.model_validate(args)
|
||||
|
||||
def test_7_end_to_end_validation_layers(self):
|
||||
"""Test all validation layers work together correctly."""
|
||||
# Layer 1: API should reject null at parameter level (with fix)
|
||||
# Layer 2: Pydantic should reject null at model level
|
||||
# Layer 3: Business logic expects non-null
|
||||
# Layer 4: Database enforces non-null
|
||||
|
||||
# Test that valid data flows through all layers
|
||||
valid_data = {"type": "string", "name": "test_metadata"}
|
||||
|
||||
# Should create valid Pydantic object
|
||||
metadata_args = MetadataArgs.model_validate(valid_data)
|
||||
assert metadata_args.type == "string"
|
||||
assert metadata_args.name == "test_metadata"
|
||||
|
||||
# Should not crash in business logic length check
|
||||
assert len(metadata_args.name) <= 255 # This should not crash
|
||||
assert len(metadata_args.type) > 0 # This should not crash
|
||||
|
||||
def test_8_verify_specific_fix_locations(self):
|
||||
"""Verify that the specific locations mentioned in bug report are fixed."""
|
||||
# Read the actual files to verify fixes
|
||||
import os
|
||||
|
||||
# Console API create
|
||||
console_create_file = "api/controllers/console/datasets/metadata.py"
|
||||
if os.path.exists(console_create_file):
|
||||
content = Path(console_create_file).read_text()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
|
||||
|
||||
# Service API create
|
||||
service_create_file = "api/controllers/service_api/dataset/metadata.py"
|
||||
if os.path.exists(service_create_file):
|
||||
content = Path(service_create_file).read_text()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
|
||||
assert "nullable=True" not in create_api_section
|
||||
|
||||
|
||||
class TestMetadataValidationSummary:
|
||||
"""Summary tests that demonstrate the complete validation architecture."""
|
||||
|
||||
def test_validation_layer_architecture(self):
|
||||
"""Document and test the 4-layer validation architecture."""
|
||||
# Layer 1: API Parameter Validation (Flask-RESTful reqparse)
|
||||
# - Role: First line of defense, validates HTTP request parameters
|
||||
# - Fixed: nullable=False ensures null values are rejected at API boundary
|
||||
|
||||
# Layer 2: Pydantic Model Validation
|
||||
# - Role: Validates data structure and types before business logic
|
||||
# - Working: Required fields without Optional[] reject None values
|
||||
|
||||
# Layer 3: Business Logic Validation
|
||||
# - Role: Domain-specific validation (length checks, uniqueness, etc.)
|
||||
# - Vulnerable: Direct len() calls crash on None values
|
||||
|
||||
# Layer 4: Database Constraints
|
||||
# - Role: Final data integrity enforcement
|
||||
# - Working: nullable=False prevents None values in database
|
||||
|
||||
# The bug was: Layer 1 allowed None, but Layers 2-4 expected non-None
|
||||
# The fix: Make Layer 1 consistent with Layers 2-4
|
||||
|
||||
assert True # This test documents the architecture
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
127
dify/api/tests/unit_tests/services/test_metadata_nullable_bug.py
Normal file
127
dify/api/tests/unit_tests/services/test_metadata_nullable_bug.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import reqparse
|
||||
|
||||
from models.account import Account
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class TestMetadataNullableBug:
|
||||
"""Test case to reproduce the metadata nullable validation bug."""
|
||||
|
||||
def test_metadata_args_with_none_values_should_fail(self):
|
||||
"""Test that MetadataArgs validation should reject None values."""
|
||||
# This test demonstrates the expected behavior - should fail validation
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
# This should fail because Pydantic expects non-None values
|
||||
MetadataArgs(type=None, name=None)
|
||||
|
||||
def test_metadata_service_create_with_none_name_crashes(self):
|
||||
"""Test that MetadataService.create_metadata crashes when name is None."""
|
||||
# Mock the MetadataArgs to bypass Pydantic validation
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None # This will cause len() to crash
|
||||
mock_metadata_args.type = "string"
|
||||
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
# This should crash with TypeError when calling len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
def test_metadata_service_update_with_none_name_crashes(self):
|
||||
"""Test that MetadataService.update_metadata_name crashes when name is None."""
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
# This should crash with TypeError when calling len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
|
||||
|
||||
def test_api_parser_accepts_null_values(self, app):
|
||||
"""Test that API parser configuration incorrectly accepts null values."""
|
||||
# Simulate the current API parser configuration
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
)
|
||||
|
||||
# Simulate request data with null values
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
# This should parse successfully due to nullable=True
|
||||
args = parser.parse_args()
|
||||
|
||||
# Verify that null values are accepted
|
||||
assert args["type"] is None
|
||||
assert args["name"] is None
|
||||
|
||||
# This demonstrates the bug: API accepts None but business logic will crash
|
||||
|
||||
def test_integration_bug_scenario(self, app):
|
||||
"""Test the complete bug scenario from API to service layer."""
|
||||
# Step 1: API parser accepts null values (current buggy behavior)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
)
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
args = parser.parse_args()
|
||||
|
||||
# Step 2: Try to create MetadataArgs with None values
|
||||
# This should fail at Pydantic validation level
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
|
||||
# Step 3: If we bypass Pydantic (simulating the bug scenario)
|
||||
# Move this outside the request context to avoid Flask-Login issues
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None # From args["name"]
|
||||
mock_metadata_args.type = None # From args["type"]
|
||||
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
# Step 4: Service layer crashes on len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
def test_correct_nullable_false_configuration_works(self, app):
|
||||
"""Test that the correct nullable=False configuration works as expected."""
|
||||
# This tests the FIXED configuration
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
# This should fail with BadRequest due to nullable=False
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -0,0 +1,153 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.dataset import Dataset, Document
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
MetadataDetail,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class TestMetadataPartialUpdate(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.dataset = MagicMock(spec=Dataset)
|
||||
self.dataset.id = "dataset_id"
|
||||
self.dataset.built_in_field_enabled = False
|
||||
|
||||
self.document = MagicMock(spec=Document)
|
||||
self.document.id = "doc_id"
|
||||
self.document.doc_metadata = {"existing_key": "existing_value"}
|
||||
self.document.data_source_type = "upload_file"
|
||||
|
||||
@patch("services.metadata_service.db")
|
||||
@patch("services.metadata_service.DocumentService")
|
||||
@patch("services.metadata_service.current_account_with_tenant")
|
||||
@patch("services.metadata_service.redis_client")
|
||||
def test_partial_update_merges_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db):
|
||||
# Setup mocks
|
||||
mock_redis.get.return_value = None
|
||||
mock_document_service.get_document.return_value = self.document
|
||||
mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
# Mock DB query for existing bindings
|
||||
|
||||
# No existing binding for new key
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Input data
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="doc_id",
|
||||
metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")],
|
||||
partial_update=True,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Execute
|
||||
MetadataService.update_documents_metadata(self.dataset, metadata_args)
|
||||
|
||||
# Verify
|
||||
# 1. Check that doc_metadata contains BOTH existing and new keys
|
||||
expected_metadata = {"existing_key": "existing_value", "new_key": "new_value"}
|
||||
assert self.document.doc_metadata == expected_metadata
|
||||
|
||||
# 2. Check that existing bindings were NOT deleted
|
||||
# The delete call in the original code: db.session.query(...).filter_by(...).delete()
|
||||
# In partial update, this should NOT be called.
|
||||
mock_db.session.query.return_value.filter_by.return_value.delete.assert_not_called()
|
||||
|
||||
@patch("services.metadata_service.db")
|
||||
@patch("services.metadata_service.DocumentService")
|
||||
@patch("services.metadata_service.current_account_with_tenant")
|
||||
@patch("services.metadata_service.redis_client")
|
||||
def test_full_update_replaces_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db):
|
||||
# Setup mocks
|
||||
mock_redis.get.return_value = None
|
||||
mock_document_service.get_document.return_value = self.document
|
||||
mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
# Input data (partial_update=False by default)
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="doc_id",
|
||||
metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")],
|
||||
partial_update=False,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Execute
|
||||
MetadataService.update_documents_metadata(self.dataset, metadata_args)
|
||||
|
||||
# Verify
|
||||
# 1. Check that doc_metadata contains ONLY the new key
|
||||
expected_metadata = {"new_key": "new_value"}
|
||||
assert self.document.doc_metadata == expected_metadata
|
||||
|
||||
# 2. Check that existing bindings WERE deleted
|
||||
# In full update (default), we expect the existing bindings to be cleared.
|
||||
mock_db.session.query.return_value.filter_by.return_value.delete.assert_called()
|
||||
|
||||
@patch("services.metadata_service.db")
|
||||
@patch("services.metadata_service.DocumentService")
|
||||
@patch("services.metadata_service.current_account_with_tenant")
|
||||
@patch("services.metadata_service.redis_client")
|
||||
def test_partial_update_skips_existing_binding(
|
||||
self, mock_redis, mock_current_account, mock_document_service, mock_db
|
||||
):
|
||||
# Setup mocks
|
||||
mock_redis.get.return_value = None
|
||||
mock_document_service.get_document.return_value = self.document
|
||||
mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id")
|
||||
|
||||
# Mock DB query to return an existing binding
|
||||
# This simulates that the document ALREADY has the metadata we are trying to add
|
||||
mock_existing_binding = MagicMock()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_existing_binding
|
||||
|
||||
# Input data
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="doc_id",
|
||||
metadata_list=[MetadataDetail(id="existing_meta_id", name="existing_key", value="existing_value")],
|
||||
partial_update=True,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Execute
|
||||
MetadataService.update_documents_metadata(self.dataset, metadata_args)
|
||||
|
||||
# Verify
|
||||
# We verify that db.session.add was NOT called for DatasetMetadataBinding
|
||||
# Since we can't easily check "not called with specific type" on the generic add method without complex logic,
|
||||
# we can check if the number of add calls is 1 (only for the document update) instead of 2 (document + binding)
|
||||
|
||||
# Expected calls:
|
||||
# 1. db.session.add(document)
|
||||
# 2. NO db.session.add(binding) because it exists
|
||||
|
||||
# Note: In the code, db.session.add is called for document.
|
||||
# Then loop over metadata_list.
|
||||
# If existing_binding found, continue.
|
||||
# So binding add should be skipped.
|
||||
|
||||
# Let's filter the calls to add to see what was added
|
||||
add_calls = mock_db.session.add.call_args_list
|
||||
added_objects = [call.args[0] for call in add_calls]
|
||||
|
||||
# Check that no DatasetMetadataBinding was added
|
||||
from models.dataset import DatasetMetadataBinding
|
||||
|
||||
has_binding_add = any(
|
||||
isinstance(obj, DatasetMetadataBinding)
|
||||
or (isinstance(obj, MagicMock) and getattr(obj, "__class__", None) == DatasetMetadataBinding)
|
||||
for obj in added_objects
|
||||
)
|
||||
|
||||
# Since we mock everything, checking isinstance might be tricky if DatasetMetadataBinding
|
||||
# is not the exact class used in the service (imports match).
|
||||
# But we can check the count.
|
||||
# If it were added, there would be 2 calls. If skipped, 1 call.
|
||||
assert mock_db.session.add.call_count == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,483 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
|
||||
|
||||
|
||||
class RagPipelineTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for RagPipelineTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
features.billing.enabled = billing_enabled
|
||||
features.billing.subscription = Mock()
|
||||
features.billing.subscription.plan = plan
|
||||
return features
|
||||
|
||||
@staticmethod
|
||||
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
|
||||
"""Create mock TenantIsolatedTaskQueue."""
|
||||
queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||
queue.get_task_key.return_value = "task_key" if has_task_key else None
|
||||
queue.push_tasks = Mock()
|
||||
queue.set_task_waiting_time = Mock()
|
||||
return queue
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_invoke_entity(
|
||||
pipeline_id: str = "pipeline-123",
|
||||
user_id: str = "user-456",
|
||||
tenant_id: str = "tenant-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
streaming: bool = True,
|
||||
workflow_execution_id: str | None = None,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> RagPipelineInvokeEntity:
|
||||
"""Create RagPipelineInvokeEntity instance for testing."""
|
||||
return RagPipelineInvokeEntity(
|
||||
pipeline_id=pipeline_id,
|
||||
application_generate_entity={"key": "value"},
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_id=workflow_id,
|
||||
streaming=streaming,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_rag_pipeline_task_proxy(
|
||||
dataset_tenant_id: str = "tenant-123",
|
||||
user_id: str = "user-456",
|
||||
rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None,
|
||||
) -> RagPipelineTaskProxy:
|
||||
"""Create RagPipelineTaskProxy instance for testing."""
|
||||
if rag_pipeline_invoke_entities is None:
|
||||
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
|
||||
return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
@staticmethod
|
||||
def create_mock_upload_file(file_id: str = "file-123") -> Mock:
|
||||
"""Create mock upload file."""
|
||||
upload_file = Mock()
|
||||
upload_file.id = file_id
|
||||
return upload_file
|
||||
|
||||
|
||||
class TestRagPipelineTaskProxy:
|
||||
"""Test cases for RagPipelineTaskProxy class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test RagPipelineTaskProxy initialization."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert proxy._dataset_tenant_id == dataset_tenant_id
|
||||
assert proxy._user_id == user_id
|
||||
assert proxy._rag_pipeline_invoke_entities == rag_pipeline_invoke_entities
|
||||
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
|
||||
assert proxy._tenant_isolated_task_queue._tenant_id == dataset_tenant_id
|
||||
assert proxy._tenant_isolated_task_queue._unique_key == "pipeline"
|
||||
|
||||
def test_initialization_with_empty_entities(self):
|
||||
"""Test initialization with empty rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = []
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert proxy._dataset_tenant_id == dataset_tenant_id
|
||||
assert proxy._user_id == user_id
|
||||
assert proxy._rag_pipeline_invoke_entities == []
|
||||
|
||||
def test_initialization_with_multiple_entities(self):
|
||||
"""Test initialization with multiple rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
dataset_tenant_id = "tenant-123"
|
||||
user_id = "user-456"
|
||||
rag_pipeline_invoke_entities = [
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3"),
|
||||
]
|
||||
|
||||
# Act
|
||||
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
|
||||
|
||||
# Assert
|
||||
assert len(proxy._rag_pipeline_invoke_entities) == 3
|
||||
assert proxy._rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1"
|
||||
assert proxy._rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2"
|
||||
assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
def test_features_property(self, mock_feature_service):
|
||||
"""Test cached_property features."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
|
||||
# Act
|
||||
features1 = proxy.features
|
||||
features2 = proxy.features # Second call should use cached property
|
||||
|
||||
# Assert
|
||||
assert features1 == mock_features
|
||||
assert features2 == mock_features
|
||||
assert features1 is features2 # Should be the same instance due to caching
|
||||
mock_feature_service.get_features.assert_called_once_with("tenant-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
|
||||
"""Test _upload_invoke_entities method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
result = proxy._upload_invoke_entities()
|
||||
|
||||
# Assert
|
||||
assert result == "file-123"
|
||||
mock_file_service_class.assert_called_once_with(mock_db.engine)
|
||||
|
||||
# Verify upload_text was called with correct parameters
|
||||
mock_file_service.upload_text.assert_called_once()
|
||||
call_args = mock_file_service.upload_text.call_args
|
||||
json_text, name, user_id, tenant_id = call_args[0]
|
||||
|
||||
assert name == "rag_pipeline_invoke_entities.json"
|
||||
assert user_id == "user-456"
|
||||
assert tenant_id == "tenant-123"
|
||||
|
||||
# Verify JSON content
|
||||
parsed_json = json.loads(json_text)
|
||||
assert len(parsed_json) == 1
|
||||
assert parsed_json[0]["pipeline_id"] == "pipeline-123"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
|
||||
"""Test _upload_invoke_entities method with multiple entities."""
|
||||
# Arrange
|
||||
entities = [
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
|
||||
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
|
||||
]
|
||||
proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities)
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
result = proxy._upload_invoke_entities()
|
||||
|
||||
# Assert
|
||||
assert result == "file-456"
|
||||
|
||||
# Verify JSON content contains both entities
|
||||
call_args = mock_file_service.upload_text.call_args
|
||||
json_text = call_args[0][0]
|
||||
parsed_json = json.loads(json_text)
|
||||
assert len(parsed_json) == 2
|
||||
assert parsed_json[0]["pipeline_id"] == "pipeline-1"
|
||||
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_direct_queue(self, mock_task):
|
||||
"""Test _send_to_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue()
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(upload_file_id, mock_task)
|
||||
|
||||
# If sent to direct queue, tenant_isolated_task_queue should not be called
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
# Celery should be called directly
|
||||
mock_task.delay.assert_called_once_with(
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
|
||||
)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when task key exists."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=True
|
||||
)
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(upload_file_id, mock_task)
|
||||
|
||||
# If task key exists, should push tasks to the queue
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with([upload_file_id])
|
||||
# Celery should not be called directly
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_tenant_queue_without_task_key(self, mock_task):
|
||||
"""Test _send_to_tenant_queue when no task key exists."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
|
||||
has_task_key=False
|
||||
)
|
||||
upload_file_id = "file-123"
|
||||
mock_task.delay = Mock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(upload_file_id, mock_task)
|
||||
|
||||
# If no task key, should set task waiting time key first
|
||||
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
|
||||
mock_task.delay.assert_called_once_with(
|
||||
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
|
||||
)
|
||||
|
||||
# The first task should be sent to celery directly, so push tasks should not be called
|
||||
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
|
||||
def test_send_to_default_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_default_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_default_tenant_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
|
||||
def test_send_to_priority_tenant_queue(self, mock_task):
|
||||
"""Test _send_to_priority_tenant_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_tenant_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_tenant_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
|
||||
def test_send_to_priority_direct_queue(self, mock_task):
|
||||
"""Test _send_to_priority_direct_queue method."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_direct_queue = Mock()
|
||||
upload_file_id = "file-123"
|
||||
|
||||
# Act
|
||||
proxy._send_to_priority_direct_queue(upload_file_id)
|
||||
|
||||
# Assert
|
||||
proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task)
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method when billing is enabled with sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_default_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is enabled with sandbox plan, should send to default tenant queue
|
||||
proxy._send_to_default_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_enabled_non_sandbox_plan(
|
||||
self, mock_db, mock_file_service_class, mock_feature_service
|
||||
):
|
||||
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.TEAM
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is enabled with non-sandbox plan, should send to priority tenant queue
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method when billing is disabled."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_direct_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# If billing is disabled, for example: self-hosted or enterprise, should send to priority direct queue
|
||||
proxy._send_to_priority_direct_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
|
||||
"""Test _dispatch method when upload_file_id is empty."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = Mock()
|
||||
mock_upload_file.id = "" # Empty file ID
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="upload_file_id is empty"):
|
||||
proxy._dispatch()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method with empty plan string."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test _dispatch method with None plan."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._send_to_priority_tenant_queue = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
|
||||
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
|
||||
"""Test delay method integration."""
|
||||
# Arrange
|
||||
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
|
||||
billing_enabled=True, plan=CloudPlan.SANDBOX
|
||||
)
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
|
||||
proxy._dispatch = Mock()
|
||||
|
||||
mock_file_service = Mock()
|
||||
mock_file_service_class.return_value = mock_file_service
|
||||
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
|
||||
mock_file_service.upload_text.return_value = mock_upload_file
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
proxy._dispatch.assert_called_once()
|
||||
|
||||
@patch("services.rag_pipeline.rag_pipeline_task_proxy.logger")
|
||||
def test_delay_method_with_empty_entities(self, mock_logger):
|
||||
"""Test delay method with empty rag_pipeline_invoke_entities."""
|
||||
# Arrange
|
||||
proxy = RagPipelineTaskProxy("tenant-123", "user-456", [])
|
||||
|
||||
# Act
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s", "tenant-123", "user-456"
|
||||
)
|
||||
779
dify/api/tests/unit_tests/services/test_schedule_service.py
Normal file
779
dify/api/tests/unit_tests/services/test_schedule_service.py
Normal file
@@ -0,0 +1,779 @@
|
||||
import unittest
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError
|
||||
from events.event_handlers.sync_workflow_schedule_when_app_published import (
|
||||
sync_schedule_from_workflow,
|
||||
)
|
||||
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.trigger import WorkflowSchedulePlan
|
||||
from models.workflow import Workflow
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
|
||||
|
||||
class TestScheduleService(unittest.TestCase):
|
||||
"""Test cases for ScheduleService class."""
|
||||
|
||||
def test_calculate_next_run_at_valid_cron(self):
|
||||
"""Test calculating next run time with valid cron expression."""
|
||||
# Test daily cron at 10:30 AM
|
||||
cron_expr = "30 10 * * *"
|
||||
timezone = "UTC"
|
||||
base_time = datetime(2025, 8, 29, 9, 0, 0, tzinfo=UTC)
|
||||
|
||||
next_run = calculate_next_run_at(cron_expr, timezone, base_time)
|
||||
|
||||
assert next_run is not None
|
||||
assert next_run.hour == 10
|
||||
assert next_run.minute == 30
|
||||
assert next_run.day == 29
|
||||
|
||||
def test_calculate_next_run_at_with_timezone(self):
|
||||
"""Test calculating next run time with different timezone."""
|
||||
cron_expr = "0 9 * * *" # 9:00 AM
|
||||
timezone = "America/New_York"
|
||||
base_time = datetime(2025, 8, 29, 12, 0, 0, tzinfo=UTC) # 8:00 AM EDT
|
||||
|
||||
next_run = calculate_next_run_at(cron_expr, timezone, base_time)
|
||||
|
||||
assert next_run is not None
|
||||
# 9:00 AM EDT = 13:00 UTC (during EDT)
|
||||
expected_utc_hour = 13
|
||||
assert next_run.hour == expected_utc_hour
|
||||
|
||||
def test_calculate_next_run_at_with_last_day_of_month(self):
|
||||
"""Test calculating next run time with 'L' (last day) syntax."""
|
||||
cron_expr = "0 10 L * *" # 10:00 AM on last day of month
|
||||
timezone = "UTC"
|
||||
base_time = datetime(2025, 2, 15, 9, 0, 0, tzinfo=UTC)
|
||||
|
||||
next_run = calculate_next_run_at(cron_expr, timezone, base_time)
|
||||
|
||||
assert next_run is not None
|
||||
# February 2025 has 28 days
|
||||
assert next_run.day == 28
|
||||
assert next_run.month == 2
|
||||
|
||||
def test_calculate_next_run_at_invalid_cron(self):
|
||||
"""Test calculating next run time with invalid cron expression."""
|
||||
cron_expr = "invalid cron"
|
||||
timezone = "UTC"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
calculate_next_run_at(cron_expr, timezone)
|
||||
|
||||
def test_calculate_next_run_at_invalid_timezone(self):
|
||||
"""Test calculating next run time with invalid timezone."""
|
||||
from pytz import UnknownTimeZoneError
|
||||
|
||||
cron_expr = "30 10 * * *"
|
||||
timezone = "Invalid/Timezone"
|
||||
|
||||
with pytest.raises(UnknownTimeZoneError):
|
||||
calculate_next_run_at(cron_expr, timezone)
|
||||
|
||||
@patch("libs.schedule_utils.calculate_next_run_at")
|
||||
def test_create_schedule(self, mock_calculate_next_run):
|
||||
"""Test creating a new schedule."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_calculate_next_run.return_value = datetime(2025, 8, 30, 10, 30, 0, tzinfo=UTC)
|
||||
|
||||
config = ScheduleConfig(
|
||||
node_id="start",
|
||||
cron_expression="30 10 * * *",
|
||||
timezone="UTC",
|
||||
)
|
||||
|
||||
schedule = ScheduleService.create_schedule(
|
||||
session=mock_session,
|
||||
tenant_id="test-tenant",
|
||||
app_id="test-app",
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert schedule is not None
|
||||
assert schedule.tenant_id == "test-tenant"
|
||||
assert schedule.app_id == "test-app"
|
||||
assert schedule.node_id == "start"
|
||||
assert schedule.cron_expression == "30 10 * * *"
|
||||
assert schedule.timezone == "UTC"
|
||||
assert schedule.next_run_at is not None
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
@patch("services.trigger.schedule_service.calculate_next_run_at")
|
||||
def test_update_schedule(self, mock_calculate_next_run):
|
||||
"""Test updating an existing schedule."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_schedule = Mock(spec=WorkflowSchedulePlan)
|
||||
mock_schedule.cron_expression = "0 12 * * *"
|
||||
mock_schedule.timezone = "America/New_York"
|
||||
mock_session.get.return_value = mock_schedule
|
||||
mock_calculate_next_run.return_value = datetime(2025, 8, 30, 12, 0, 0, tzinfo=UTC)
|
||||
|
||||
updates = SchedulePlanUpdate(
|
||||
cron_expression="0 12 * * *",
|
||||
timezone="America/New_York",
|
||||
)
|
||||
|
||||
result = ScheduleService.update_schedule(
|
||||
session=mock_session,
|
||||
schedule_id="test-schedule-id",
|
||||
updates=updates,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.cron_expression == "0 12 * * *"
|
||||
assert result.timezone == "America/New_York"
|
||||
mock_calculate_next_run.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
def test_update_schedule_not_found(self):
|
||||
"""Test updating a non-existent schedule raises exception."""
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.get.return_value = None
|
||||
|
||||
updates = SchedulePlanUpdate(
|
||||
cron_expression="0 12 * * *",
|
||||
)
|
||||
|
||||
with pytest.raises(ScheduleNotFoundError) as context:
|
||||
ScheduleService.update_schedule(
|
||||
session=mock_session,
|
||||
schedule_id="non-existent-id",
|
||||
updates=updates,
|
||||
)
|
||||
|
||||
assert "Schedule not found: non-existent-id" in str(context.value)
|
||||
mock_session.flush.assert_not_called()
|
||||
|
||||
def test_delete_schedule(self):
|
||||
"""Test deleting a schedule."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_schedule = Mock(spec=WorkflowSchedulePlan)
|
||||
mock_session.get.return_value = mock_schedule
|
||||
|
||||
# Should not raise exception and complete successfully
|
||||
ScheduleService.delete_schedule(
|
||||
session=mock_session,
|
||||
schedule_id="test-schedule-id",
|
||||
)
|
||||
|
||||
mock_session.delete.assert_called_once_with(mock_schedule)
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
def test_delete_schedule_not_found(self):
|
||||
"""Test deleting a non-existent schedule raises exception."""
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.get.return_value = None
|
||||
|
||||
# Should raise ScheduleNotFoundError
|
||||
with pytest.raises(ScheduleNotFoundError) as context:
|
||||
ScheduleService.delete_schedule(
|
||||
session=mock_session,
|
||||
schedule_id="non-existent-id",
|
||||
)
|
||||
|
||||
assert "Schedule not found: non-existent-id" in str(context.value)
|
||||
mock_session.delete.assert_not_called()
|
||||
|
||||
@patch("services.trigger.schedule_service.select")
|
||||
def test_get_tenant_owner(self, mock_select):
|
||||
"""Test getting tenant owner account."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_account = Mock(spec=Account)
|
||||
mock_account.id = "owner-account-id"
|
||||
|
||||
# Mock owner query
|
||||
mock_owner_result = Mock(spec=TenantAccountJoin)
|
||||
mock_owner_result.account_id = "owner-account-id"
|
||||
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = mock_owner_result
|
||||
mock_session.get.return_value = mock_account
|
||||
|
||||
result = ScheduleService.get_tenant_owner(
|
||||
session=mock_session,
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == "owner-account-id"
|
||||
|
||||
@patch("services.trigger.schedule_service.select")
|
||||
def test_get_tenant_owner_fallback_to_admin(self, mock_select):
|
||||
"""Test getting tenant owner falls back to admin if no owner."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_account = Mock(spec=Account)
|
||||
mock_account.id = "admin-account-id"
|
||||
|
||||
# Mock admin query (owner returns None)
|
||||
mock_admin_result = Mock(spec=TenantAccountJoin)
|
||||
mock_admin_result.account_id = "admin-account-id"
|
||||
|
||||
mock_session.execute.return_value.scalar_one_or_none.side_effect = [None, mock_admin_result]
|
||||
mock_session.get.return_value = mock_account
|
||||
|
||||
result = ScheduleService.get_tenant_owner(
|
||||
session=mock_session,
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == "admin-account-id"
|
||||
|
||||
@patch("services.trigger.schedule_service.calculate_next_run_at")
|
||||
def test_update_next_run_at(self, mock_calculate_next_run):
|
||||
"""Test updating next run time after schedule triggered."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_schedule = Mock(spec=WorkflowSchedulePlan)
|
||||
mock_schedule.cron_expression = "30 10 * * *"
|
||||
mock_schedule.timezone = "UTC"
|
||||
mock_session.get.return_value = mock_schedule
|
||||
|
||||
next_time = datetime(2025, 8, 31, 10, 30, 0, tzinfo=UTC)
|
||||
mock_calculate_next_run.return_value = next_time
|
||||
|
||||
result = ScheduleService.update_next_run_at(
|
||||
session=mock_session,
|
||||
schedule_id="test-schedule-id",
|
||||
)
|
||||
|
||||
assert result == next_time
|
||||
assert mock_schedule.next_run_at == next_time
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
|
||||
class TestVisualToCron(unittest.TestCase):
|
||||
"""Test cases for visual configuration to cron conversion."""
|
||||
|
||||
def test_visual_to_cron_hourly(self):
|
||||
"""Test converting hourly visual config to cron."""
|
||||
visual_config = VisualConfig(on_minute=15)
|
||||
result = ScheduleService.visual_to_cron("hourly", visual_config)
|
||||
assert result == "15 * * * *"
|
||||
|
||||
def test_visual_to_cron_daily(self):
|
||||
"""Test converting daily visual config to cron."""
|
||||
visual_config = VisualConfig(time="2:30 PM")
|
||||
result = ScheduleService.visual_to_cron("daily", visual_config)
|
||||
assert result == "30 14 * * *"
|
||||
|
||||
def test_visual_to_cron_weekly(self):
|
||||
"""Test converting weekly visual config to cron."""
|
||||
visual_config = VisualConfig(
|
||||
time="10:00 AM",
|
||||
weekdays=["mon", "wed", "fri"],
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("weekly", visual_config)
|
||||
assert result == "0 10 * * 1,3,5"
|
||||
|
||||
def test_visual_to_cron_monthly_with_specific_days(self):
|
||||
"""Test converting monthly visual config with specific days."""
|
||||
visual_config = VisualConfig(
|
||||
time="11:30 AM",
|
||||
monthly_days=[1, 15],
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
assert result == "30 11 1,15 * *"
|
||||
|
||||
def test_visual_to_cron_monthly_with_last_day(self):
|
||||
"""Test converting monthly visual config with last day using 'L' syntax."""
|
||||
visual_config = VisualConfig(
|
||||
time="11:30 AM",
|
||||
monthly_days=[1, "last"],
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
assert result == "30 11 1,L * *"
|
||||
|
||||
def test_visual_to_cron_monthly_only_last_day(self):
|
||||
"""Test converting monthly visual config with only last day."""
|
||||
visual_config = VisualConfig(
|
||||
time="9:00 PM",
|
||||
monthly_days=["last"],
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
assert result == "0 21 L * *"
|
||||
|
||||
def test_visual_to_cron_monthly_with_end_days_and_last(self):
|
||||
"""Test converting monthly visual config with days 29, 30, 31 and 'last'."""
|
||||
visual_config = VisualConfig(
|
||||
time="3:45 PM",
|
||||
monthly_days=[29, 30, 31, "last"],
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
# Should have 29,30,31,L - the L handles all possible last days
|
||||
assert result == "45 15 29,30,31,L * *"
|
||||
|
||||
def test_visual_to_cron_invalid_frequency(self):
|
||||
"""Test converting with invalid frequency."""
|
||||
with pytest.raises(ScheduleConfigError, match="Unsupported frequency: invalid"):
|
||||
ScheduleService.visual_to_cron("invalid", VisualConfig())
|
||||
|
||||
def test_visual_to_cron_weekly_no_weekdays(self):
|
||||
"""Test converting weekly with no weekdays specified."""
|
||||
visual_config = VisualConfig(time="10:00 AM")
|
||||
with pytest.raises(ScheduleConfigError, match="Weekdays are required for weekly schedules"):
|
||||
ScheduleService.visual_to_cron("weekly", visual_config)
|
||||
|
||||
def test_visual_to_cron_hourly_no_minute(self):
|
||||
"""Test converting hourly with no on_minute specified."""
|
||||
visual_config = VisualConfig() # on_minute defaults to 0
|
||||
result = ScheduleService.visual_to_cron("hourly", visual_config)
|
||||
assert result == "0 * * * *" # Should use default value 0
|
||||
|
||||
def test_visual_to_cron_daily_no_time(self):
|
||||
"""Test converting daily with no time specified."""
|
||||
visual_config = VisualConfig(time=None)
|
||||
with pytest.raises(ScheduleConfigError, match="time is required for daily schedules"):
|
||||
ScheduleService.visual_to_cron("daily", visual_config)
|
||||
|
||||
def test_visual_to_cron_weekly_no_time(self):
|
||||
"""Test converting weekly with no time specified."""
|
||||
visual_config = VisualConfig(weekdays=["mon"])
|
||||
visual_config.time = None # Override default
|
||||
with pytest.raises(ScheduleConfigError, match="time is required for weekly schedules"):
|
||||
ScheduleService.visual_to_cron("weekly", visual_config)
|
||||
|
||||
def test_visual_to_cron_monthly_no_time(self):
|
||||
"""Test converting monthly with no time specified."""
|
||||
visual_config = VisualConfig(monthly_days=[1])
|
||||
visual_config.time = None # Override default
|
||||
with pytest.raises(ScheduleConfigError, match="time is required for monthly schedules"):
|
||||
ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
|
||||
def test_visual_to_cron_monthly_duplicate_days(self):
|
||||
"""Test monthly with duplicate days should be deduplicated."""
|
||||
visual_config = VisualConfig(
|
||||
time="10:00 AM",
|
||||
monthly_days=[1, 15, 1, 15, 31], # Duplicates
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
assert result == "0 10 1,15,31 * *" # Should be deduplicated
|
||||
|
||||
def test_visual_to_cron_monthly_unsorted_days(self):
|
||||
"""Test monthly with unsorted days should be sorted."""
|
||||
visual_config = VisualConfig(
|
||||
time="2:30 PM",
|
||||
monthly_days=[20, 5, 15, 1, 10], # Unsorted
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
assert result == "30 14 1,5,10,15,20 * *" # Should be sorted
|
||||
|
||||
def test_visual_to_cron_weekly_all_weekdays(self):
|
||||
"""Test weekly with all weekdays."""
|
||||
visual_config = VisualConfig(
|
||||
time="8:00 AM",
|
||||
weekdays=["sun", "mon", "tue", "wed", "thu", "fri", "sat"],
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("weekly", visual_config)
|
||||
assert result == "0 8 * * 0,1,2,3,4,5,6"
|
||||
|
||||
def test_visual_to_cron_hourly_boundary_values(self):
|
||||
"""Test hourly with boundary minute values."""
|
||||
# Minimum value
|
||||
visual_config = VisualConfig(on_minute=0)
|
||||
result = ScheduleService.visual_to_cron("hourly", visual_config)
|
||||
assert result == "0 * * * *"
|
||||
|
||||
# Maximum value
|
||||
visual_config = VisualConfig(on_minute=59)
|
||||
result = ScheduleService.visual_to_cron("hourly", visual_config)
|
||||
assert result == "59 * * * *"
|
||||
|
||||
def test_visual_to_cron_daily_midnight_noon(self):
|
||||
"""Test daily at special times (midnight and noon)."""
|
||||
# Midnight
|
||||
visual_config = VisualConfig(time="12:00 AM")
|
||||
result = ScheduleService.visual_to_cron("daily", visual_config)
|
||||
assert result == "0 0 * * *"
|
||||
|
||||
# Noon
|
||||
visual_config = VisualConfig(time="12:00 PM")
|
||||
result = ScheduleService.visual_to_cron("daily", visual_config)
|
||||
assert result == "0 12 * * *"
|
||||
|
||||
def test_visual_to_cron_monthly_mixed_with_last_and_duplicates(self):
|
||||
"""Test monthly with mixed days, 'last', and duplicates."""
|
||||
visual_config = VisualConfig(
|
||||
time="11:45 PM",
|
||||
monthly_days=[15, 1, "last", 15, 30, 1, "last"], # Mixed with duplicates
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
assert result == "45 23 1,15,30,L * *" # Deduplicated and sorted with L at end
|
||||
|
||||
def test_visual_to_cron_weekly_single_day(self):
|
||||
"""Test weekly with single weekday."""
|
||||
visual_config = VisualConfig(
|
||||
time="6:30 PM",
|
||||
weekdays=["sun"],
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("weekly", visual_config)
|
||||
assert result == "30 18 * * 0"
|
||||
|
||||
def test_visual_to_cron_monthly_all_possible_days(self):
|
||||
"""Test monthly with all 31 days plus 'last'."""
|
||||
all_days = list(range(1, 32)) + ["last"]
|
||||
visual_config = VisualConfig(
|
||||
time="12:01 AM",
|
||||
monthly_days=all_days,
|
||||
)
|
||||
result = ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
expected_days = ",".join([str(i) for i in range(1, 32)]) + ",L"
|
||||
assert result == f"1 0 {expected_days} * *"
|
||||
|
||||
def test_visual_to_cron_monthly_no_days(self):
|
||||
"""Test monthly without any days specified should raise error."""
|
||||
visual_config = VisualConfig(time="10:00 AM", monthly_days=[])
|
||||
with pytest.raises(ScheduleConfigError, match="Monthly days are required for monthly schedules"):
|
||||
ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
|
||||
def test_visual_to_cron_weekly_empty_weekdays_list(self):
|
||||
"""Test weekly with empty weekdays list should raise error."""
|
||||
visual_config = VisualConfig(time="10:00 AM", weekdays=[])
|
||||
with pytest.raises(ScheduleConfigError, match="Weekdays are required for weekly schedules"):
|
||||
ScheduleService.visual_to_cron("weekly", visual_config)
|
||||
|
||||
|
||||
class TestParseTime(unittest.TestCase):
|
||||
"""Test cases for time parsing function."""
|
||||
|
||||
def test_parse_time_am(self):
|
||||
"""Test parsing AM time."""
|
||||
hour, minute = convert_12h_to_24h("9:30 AM")
|
||||
assert hour == 9
|
||||
assert minute == 30
|
||||
|
||||
def test_parse_time_pm(self):
|
||||
"""Test parsing PM time."""
|
||||
hour, minute = convert_12h_to_24h("2:45 PM")
|
||||
assert hour == 14
|
||||
assert minute == 45
|
||||
|
||||
def test_parse_time_noon(self):
|
||||
"""Test parsing 12:00 PM (noon)."""
|
||||
hour, minute = convert_12h_to_24h("12:00 PM")
|
||||
assert hour == 12
|
||||
assert minute == 0
|
||||
|
||||
def test_parse_time_midnight(self):
|
||||
"""Test parsing 12:00 AM (midnight)."""
|
||||
hour, minute = convert_12h_to_24h("12:00 AM")
|
||||
assert hour == 0
|
||||
assert minute == 0
|
||||
|
||||
def test_parse_time_invalid_format(self):
|
||||
"""Test parsing invalid time format."""
|
||||
with pytest.raises(ValueError, match="Invalid time format"):
|
||||
convert_12h_to_24h("25:00")
|
||||
|
||||
def test_parse_time_invalid_hour(self):
|
||||
"""Test parsing invalid hour."""
|
||||
with pytest.raises(ValueError, match="Invalid hour: 13"):
|
||||
convert_12h_to_24h("13:00 PM")
|
||||
|
||||
def test_parse_time_invalid_minute(self):
|
||||
"""Test parsing invalid minute."""
|
||||
with pytest.raises(ValueError, match="Invalid minute: 60"):
|
||||
convert_12h_to_24h("10:60 AM")
|
||||
|
||||
def test_parse_time_empty_string(self):
|
||||
"""Test parsing empty string."""
|
||||
with pytest.raises(ValueError, match="Time string cannot be empty"):
|
||||
convert_12h_to_24h("")
|
||||
|
||||
def test_parse_time_invalid_period(self):
|
||||
"""Test parsing invalid period."""
|
||||
with pytest.raises(ValueError, match="Invalid period"):
|
||||
convert_12h_to_24h("10:30 XM")
|
||||
|
||||
|
||||
class TestExtractScheduleConfig(unittest.TestCase):
|
||||
"""Test cases for extracting schedule configuration from workflow."""
|
||||
|
||||
def test_extract_schedule_config_with_cron_mode(self):
|
||||
"""Test extracting schedule config in cron mode."""
|
||||
workflow = Mock(spec=Workflow)
|
||||
workflow.graph_dict = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "schedule-node",
|
||||
"data": {
|
||||
"type": "trigger-schedule",
|
||||
"mode": "cron",
|
||||
"cron_expression": "0 10 * * *",
|
||||
"timezone": "America/New_York",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
config = ScheduleService.extract_schedule_config(workflow)
|
||||
|
||||
assert config is not None
|
||||
assert config.node_id == "schedule-node"
|
||||
assert config.cron_expression == "0 10 * * *"
|
||||
assert config.timezone == "America/New_York"
|
||||
|
||||
def test_extract_schedule_config_with_visual_mode(self):
|
||||
"""Test extracting schedule config in visual mode."""
|
||||
workflow = Mock(spec=Workflow)
|
||||
workflow.graph_dict = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "schedule-node",
|
||||
"data": {
|
||||
"type": "trigger-schedule",
|
||||
"mode": "visual",
|
||||
"frequency": "daily",
|
||||
"visual_config": {"time": "10:30 AM"},
|
||||
"timezone": "UTC",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
config = ScheduleService.extract_schedule_config(workflow)
|
||||
|
||||
assert config is not None
|
||||
assert config.node_id == "schedule-node"
|
||||
assert config.cron_expression == "30 10 * * *"
|
||||
assert config.timezone == "UTC"
|
||||
|
||||
def test_extract_schedule_config_no_schedule_node(self):
|
||||
"""Test extracting config when no schedule node exists."""
|
||||
workflow = Mock(spec=Workflow)
|
||||
workflow.graph_dict = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "other-node",
|
||||
"data": {"type": "llm"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
config = ScheduleService.extract_schedule_config(workflow)
|
||||
assert config is None
|
||||
|
||||
def test_extract_schedule_config_invalid_graph(self):
|
||||
"""Test extracting config with invalid graph data."""
|
||||
workflow = Mock(spec=Workflow)
|
||||
workflow.graph_dict = None
|
||||
|
||||
with pytest.raises(ScheduleConfigError, match="Workflow graph is empty"):
|
||||
ScheduleService.extract_schedule_config(workflow)
|
||||
|
||||
|
||||
class TestScheduleWithTimezone(unittest.TestCase):
|
||||
"""Test cases for schedule with timezone handling."""
|
||||
|
||||
def test_visual_schedule_with_timezone_integration(self):
|
||||
"""Test complete flow: visual config → cron → execution in different timezones.
|
||||
|
||||
This test verifies that when a user in Shanghai sets a schedule for 10:30 AM,
|
||||
it runs at 10:30 AM Shanghai time, not 10:30 AM UTC.
|
||||
"""
|
||||
# User in Shanghai wants to run a task at 10:30 AM local time
|
||||
visual_config = VisualConfig(
|
||||
time="10:30 AM", # This is Shanghai time
|
||||
monthly_days=[1],
|
||||
)
|
||||
|
||||
# Convert to cron expression
|
||||
cron_expr = ScheduleService.visual_to_cron("monthly", visual_config)
|
||||
assert cron_expr is not None
|
||||
|
||||
assert cron_expr == "30 10 1 * *" # Direct conversion
|
||||
|
||||
# Now test execution with Shanghai timezone
|
||||
shanghai_tz = "Asia/Shanghai"
|
||||
# Base time: 2025-01-01 00:00:00 UTC (08:00:00 Shanghai)
|
||||
base_time = datetime(2025, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
|
||||
next_run = calculate_next_run_at(cron_expr, shanghai_tz, base_time)
|
||||
|
||||
assert next_run is not None
|
||||
|
||||
# Should run at 10:30 AM Shanghai time on Jan 1
|
||||
# 10:30 AM Shanghai = 02:30 AM UTC (Shanghai is UTC+8)
|
||||
assert next_run.year == 2025
|
||||
assert next_run.month == 1
|
||||
assert next_run.day == 1
|
||||
assert next_run.hour == 2 # 02:30 UTC
|
||||
assert next_run.minute == 30
|
||||
|
||||
def test_visual_schedule_different_timezones_same_local_time(self):
|
||||
"""Test that same visual config in different timezones runs at different UTC times.
|
||||
|
||||
This verifies that a schedule set for "9:00 AM" runs at 9 AM local time
|
||||
regardless of the timezone.
|
||||
"""
|
||||
visual_config = VisualConfig(
|
||||
time="9:00 AM",
|
||||
weekdays=["mon"],
|
||||
)
|
||||
|
||||
cron_expr = ScheduleService.visual_to_cron("weekly", visual_config)
|
||||
assert cron_expr is not None
|
||||
assert cron_expr == "0 9 * * 1"
|
||||
|
||||
# Base time: Sunday 2025-01-05 12:00:00 UTC
|
||||
base_time = datetime(2025, 1, 5, 12, 0, 0, tzinfo=UTC)
|
||||
|
||||
# Test New York (UTC-5 in January)
|
||||
ny_next = calculate_next_run_at(cron_expr, "America/New_York", base_time)
|
||||
assert ny_next is not None
|
||||
# Monday 9 AM EST = Monday 14:00 UTC
|
||||
assert ny_next.day == 6
|
||||
assert ny_next.hour == 14 # 9 AM EST = 2 PM UTC
|
||||
|
||||
# Test Tokyo (UTC+9)
|
||||
tokyo_next = calculate_next_run_at(cron_expr, "Asia/Tokyo", base_time)
|
||||
assert tokyo_next is not None
|
||||
# Monday 9 AM JST = Monday 00:00 UTC
|
||||
assert tokyo_next.day == 6
|
||||
assert tokyo_next.hour == 0 # 9 AM JST = 0 AM UTC
|
||||
|
||||
def test_visual_schedule_daily_across_dst_change(self):
|
||||
"""Test that daily schedules adjust correctly during DST changes.
|
||||
|
||||
A schedule set for "10:00 AM" should always run at 10 AM local time,
|
||||
even when DST changes.
|
||||
"""
|
||||
visual_config = VisualConfig(
|
||||
time="10:00 AM",
|
||||
)
|
||||
|
||||
cron_expr = ScheduleService.visual_to_cron("daily", visual_config)
|
||||
assert cron_expr is not None
|
||||
|
||||
assert cron_expr == "0 10 * * *"
|
||||
|
||||
# Test before DST (EST - UTC-5)
|
||||
winter_base = datetime(2025, 2, 1, 0, 0, 0, tzinfo=UTC)
|
||||
winter_next = calculate_next_run_at(cron_expr, "America/New_York", winter_base)
|
||||
assert winter_next is not None
|
||||
# 10 AM EST = 15:00 UTC
|
||||
assert winter_next.hour == 15
|
||||
|
||||
# Test during DST (EDT - UTC-4)
|
||||
summer_base = datetime(2025, 6, 1, 0, 0, 0, tzinfo=UTC)
|
||||
summer_next = calculate_next_run_at(cron_expr, "America/New_York", summer_base)
|
||||
assert summer_next is not None
|
||||
# 10 AM EDT = 14:00 UTC
|
||||
assert summer_next.hour == 14
|
||||
|
||||
|
||||
class TestSyncScheduleFromWorkflow(unittest.TestCase):
|
||||
"""Test cases for syncing schedule from workflow."""
|
||||
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
|
||||
def test_sync_schedule_create_new(self, mock_select, mock_service, mock_db):
|
||||
"""Test creating new schedule when none exists."""
|
||||
mock_session = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=None)
|
||||
Session = MagicMock(return_value=mock_session)
|
||||
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
|
||||
mock_session.scalar.return_value = None # No existing plan
|
||||
|
||||
# Mock extract_schedule_config to return a ScheduleConfig object
|
||||
mock_config = Mock(spec=ScheduleConfig)
|
||||
mock_config.node_id = "start"
|
||||
mock_config.cron_expression = "30 10 * * *"
|
||||
mock_config.timezone = "UTC"
|
||||
mock_service.extract_schedule_config.return_value = mock_config
|
||||
|
||||
mock_new_plan = Mock(spec=WorkflowSchedulePlan)
|
||||
mock_service.create_schedule.return_value = mock_new_plan
|
||||
|
||||
workflow = Mock(spec=Workflow)
|
||||
result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)
|
||||
|
||||
assert result == mock_new_plan
|
||||
mock_service.create_schedule.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
|
||||
def test_sync_schedule_update_existing(self, mock_select, mock_service, mock_db):
|
||||
"""Test updating existing schedule."""
|
||||
mock_session = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=None)
|
||||
Session = MagicMock(return_value=mock_session)
|
||||
|
||||
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
|
||||
mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
|
||||
mock_existing_plan.id = "existing-plan-id"
|
||||
mock_session.scalar.return_value = mock_existing_plan
|
||||
|
||||
# Mock extract_schedule_config to return a ScheduleConfig object
|
||||
mock_config = Mock(spec=ScheduleConfig)
|
||||
mock_config.node_id = "start"
|
||||
mock_config.cron_expression = "0 12 * * *"
|
||||
mock_config.timezone = "America/New_York"
|
||||
mock_service.extract_schedule_config.return_value = mock_config
|
||||
|
||||
mock_updated_plan = Mock(spec=WorkflowSchedulePlan)
|
||||
mock_service.update_schedule.return_value = mock_updated_plan
|
||||
|
||||
workflow = Mock(spec=Workflow)
|
||||
result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)
|
||||
|
||||
assert result == mock_updated_plan
|
||||
mock_service.update_schedule.assert_called_once()
|
||||
# Verify the arguments passed to update_schedule
|
||||
call_args = mock_service.update_schedule.call_args
|
||||
assert call_args.kwargs["session"] == mock_session
|
||||
assert call_args.kwargs["schedule_id"] == "existing-plan-id"
|
||||
updates_obj = call_args.kwargs["updates"]
|
||||
assert isinstance(updates_obj, SchedulePlanUpdate)
|
||||
assert updates_obj.node_id == "start"
|
||||
assert updates_obj.cron_expression == "0 12 * * *"
|
||||
assert updates_obj.timezone == "America/New_York"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
|
||||
@patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
|
||||
def test_sync_schedule_remove_when_no_config(self, mock_select, mock_service, mock_db):
|
||||
"""Test removing schedule when no schedule config in workflow."""
|
||||
mock_session = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=None)
|
||||
Session = MagicMock(return_value=mock_session)
|
||||
|
||||
with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
|
||||
mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
|
||||
mock_existing_plan.id = "existing-plan-id"
|
||||
mock_session.scalar.return_value = mock_existing_plan
|
||||
|
||||
mock_service.extract_schedule_config.return_value = None # No schedule config
|
||||
|
||||
workflow = Mock(spec=Workflow)
|
||||
result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)
|
||||
|
||||
assert result is None
|
||||
# Now using ScheduleService.delete_schedule instead of session.delete
|
||||
mock_service.delete_schedule.assert_called_once_with(session=mock_session, schedule_id="existing-plan-id")
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
629
dify/api/tests/unit_tests/services/test_variable_truncator.py
Normal file
629
dify/api/tests/unit_tests/services/test_variable_truncator.py
Normal file
@@ -0,0 +1,629 @@
|
||||
"""
|
||||
Comprehensive unit tests for VariableTruncator class based on current implementation.
|
||||
|
||||
This test suite covers all functionality of the current VariableTruncator including:
|
||||
- JSON size calculation for different data types
|
||||
- String, array, and object truncation logic
|
||||
- Segment-based truncation interface
|
||||
- Helper methods for budget-based truncation
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from services.variable_truncator import (
|
||||
DummyVariableTruncator,
|
||||
MaxDepthExceededError,
|
||||
TruncationResult,
|
||||
UnknownTypeError,
|
||||
VariableTruncator,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file() -> File:
|
||||
return File(
|
||||
id=str(uuid4()), # Generate new UUID for File.id
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id=str(uuid.uuid4()),
|
||||
filename="test_file.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=1024,
|
||||
storage_key="initial_key",
|
||||
)
|
||||
|
||||
|
||||
_compact_json_dumps = functools.partial(json.dumps, separators=(",", ":"))
|
||||
|
||||
|
||||
class TestCalculateJsonSize:
|
||||
"""Test calculate_json_size method with different data types."""
|
||||
|
||||
@pytest.fixture
|
||||
def truncator(self):
|
||||
return VariableTruncator()
|
||||
|
||||
def test_string_size_calculation(self):
|
||||
"""Test JSON size calculation for strings."""
|
||||
# Simple ASCII string
|
||||
assert VariableTruncator.calculate_json_size("hello") == 7 # "hello" + 2 quotes
|
||||
|
||||
# Empty string
|
||||
assert VariableTruncator.calculate_json_size("") == 2 # Just quotes
|
||||
|
||||
# Unicode string
|
||||
assert VariableTruncator.calculate_json_size("你好") == 4
|
||||
|
||||
def test_number_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for numbers."""
|
||||
assert truncator.calculate_json_size(123) == 3
|
||||
assert truncator.calculate_json_size(12.34) == 5
|
||||
assert truncator.calculate_json_size(-456) == 4
|
||||
assert truncator.calculate_json_size(0) == 1
|
||||
|
||||
def test_boolean_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for booleans."""
|
||||
assert truncator.calculate_json_size(True) == 4 # "true"
|
||||
assert truncator.calculate_json_size(False) == 5 # "false"
|
||||
|
||||
def test_null_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for None/null."""
|
||||
assert truncator.calculate_json_size(None) == 4 # "null"
|
||||
|
||||
def test_array_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for arrays."""
|
||||
# Empty array
|
||||
assert truncator.calculate_json_size([]) == 2 # "[]"
|
||||
|
||||
# Simple array
|
||||
simple_array = [1, 2, 3]
|
||||
# [1,2,3] = 1 + 1 + 1 + 1 + 1 + 2 = 7 (numbers + commas + brackets)
|
||||
assert truncator.calculate_json_size(simple_array) == 7
|
||||
|
||||
# Array with strings
|
||||
string_array = ["a", "b"]
|
||||
# ["a","b"] = 3 + 3 + 1 + 2 = 9 (quoted strings + comma + brackets)
|
||||
assert truncator.calculate_json_size(string_array) == 9
|
||||
|
||||
def test_object_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for objects."""
|
||||
# Empty object
|
||||
assert truncator.calculate_json_size({}) == 2 # "{}"
|
||||
|
||||
# Simple object
|
||||
simple_obj = {"a": 1}
|
||||
# {"a":1} = 3 + 1 + 1 + 2 = 7 (key + colon + value + brackets)
|
||||
assert truncator.calculate_json_size(simple_obj) == 7
|
||||
|
||||
# Multiple keys
|
||||
multi_obj = {"a": 1, "b": 2}
|
||||
# {"a":1,"b":2} = 3 + 1 + 1 + 1 + 3 + 1 + 1 + 2 = 13
|
||||
assert truncator.calculate_json_size(multi_obj) == 13
|
||||
|
||||
def test_nested_structure_size_calculation(self, truncator):
|
||||
"""Test JSON size calculation for nested structures."""
|
||||
nested = {"items": [1, 2, {"nested": "value"}]}
|
||||
size = truncator.calculate_json_size(nested)
|
||||
assert size > 0 # Should calculate without error
|
||||
|
||||
# Verify it matches actual JSON length roughly
|
||||
|
||||
actual_json = _compact_json_dumps(nested)
|
||||
# Should be close but not exact due to UTF-8 encoding considerations
|
||||
assert abs(size - len(actual_json.encode())) <= 5
|
||||
|
||||
def test_calculate_json_size_max_depth_exceeded(self, truncator):
|
||||
"""Test that calculate_json_size handles deep nesting gracefully."""
|
||||
# Create deeply nested structure
|
||||
nested: dict[str, Any] = {"level": 0}
|
||||
current = nested
|
||||
for i in range(105): # Create deep nesting
|
||||
current["next"] = {"level": i + 1}
|
||||
current = current["next"]
|
||||
|
||||
# Should either raise an error or handle gracefully
|
||||
with pytest.raises(MaxDepthExceededError):
|
||||
truncator.calculate_json_size(nested)
|
||||
|
||||
def test_calculate_json_size_unknown_type(self, truncator):
|
||||
"""Test that calculate_json_size raises error for unknown types."""
|
||||
|
||||
class CustomType:
|
||||
pass
|
||||
|
||||
with pytest.raises(UnknownTypeError):
|
||||
truncator.calculate_json_size(CustomType())
|
||||
|
||||
|
||||
class TestStringTruncation:
|
||||
LENGTH_LIMIT = 10
|
||||
"""Test string truncation functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def small_truncator(self):
|
||||
return VariableTruncator(string_length_limit=10)
|
||||
|
||||
def test_short_string_no_truncation(self, small_truncator):
|
||||
"""Test that short strings are not truncated."""
|
||||
short_str = "hello"
|
||||
result = small_truncator._truncate_string(short_str, self.LENGTH_LIMIT)
|
||||
assert result.value == short_str
|
||||
assert result.truncated is False
|
||||
assert result.value_size == VariableTruncator.calculate_json_size(short_str)
|
||||
|
||||
def test_long_string_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test that long strings are truncated with ellipsis."""
|
||||
long_str = "this is a very long string that exceeds the limit"
|
||||
result = small_truncator._truncate_string(long_str, self.LENGTH_LIMIT)
|
||||
|
||||
assert result.truncated is True
|
||||
assert result.value == long_str[:5] + "..."
|
||||
assert result.value_size == 10 # 10 chars + "..."
|
||||
|
||||
def test_exact_limit_string(self, small_truncator: VariableTruncator):
|
||||
"""Test string exactly at limit."""
|
||||
exact_str = "1234567890" # Exactly 10 chars
|
||||
result = small_truncator._truncate_string(exact_str, self.LENGTH_LIMIT)
|
||||
assert result.value == "12345..."
|
||||
assert result.truncated is True
|
||||
assert result.value_size == 10
|
||||
|
||||
|
||||
class TestArrayTruncation:
|
||||
"""Test array truncation functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def small_truncator(self):
|
||||
return VariableTruncator(array_element_limit=3, max_size_bytes=100)
|
||||
|
||||
def test_small_array_no_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test that small arrays are not truncated."""
|
||||
small_array = [1, 2]
|
||||
result = small_truncator._truncate_array(small_array, 1000)
|
||||
assert result.value == small_array
|
||||
assert result.truncated is False
|
||||
|
||||
def test_array_element_limit_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test that arrays over element limit are truncated."""
|
||||
large_array = [1, 2, 3, 4, 5, 6] # Exceeds limit of 3
|
||||
result = small_truncator._truncate_array(large_array, 1000)
|
||||
|
||||
assert result.truncated is True
|
||||
assert result.value == [1, 2, 3]
|
||||
|
||||
def test_array_size_budget_truncation(self, small_truncator: VariableTruncator):
|
||||
"""Test array truncation due to size budget constraints."""
|
||||
# Create array with strings that will exceed size budget
|
||||
large_strings = ["very long string " * 5, "another long string " * 5]
|
||||
result = small_truncator._truncate_array(large_strings, 50)
|
||||
|
||||
assert result.truncated is True
|
||||
# Should have truncated the strings within the array
|
||||
for item in result.value:
|
||||
assert isinstance(item, str)
|
||||
assert VariableTruncator.calculate_json_size(result.value) <= 50
|
||||
|
||||
def test_array_with_nested_objects(self, small_truncator):
|
||||
"""Test array truncation with nested objects."""
|
||||
nested_array = [
|
||||
{"name": "item1", "data": "some data"},
|
||||
{"name": "item2", "data": "more data"},
|
||||
{"name": "item3", "data": "even more data"},
|
||||
]
|
||||
result = small_truncator._truncate_array(nested_array, 30)
|
||||
|
||||
assert isinstance(result.value, list)
|
||||
assert len(result.value) <= 3
|
||||
for item in result.value:
|
||||
assert isinstance(item, dict)
|
||||
|
||||
|
||||
class TestObjectTruncation:
|
||||
"""Test object truncation functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def small_truncator(self):
|
||||
return VariableTruncator(max_size_bytes=100)
|
||||
|
||||
def test_small_object_no_truncation(self, small_truncator):
|
||||
"""Test that small objects are not truncated."""
|
||||
small_obj = {"a": 1, "b": 2}
|
||||
result = small_truncator._truncate_object(small_obj, 1000)
|
||||
assert result.value == small_obj
|
||||
assert result.truncated is False
|
||||
|
||||
def test_empty_object_no_truncation(self, small_truncator):
|
||||
"""Test that empty objects are not truncated."""
|
||||
empty_obj = {}
|
||||
result = small_truncator._truncate_object(empty_obj, 100)
|
||||
assert result.value == empty_obj
|
||||
assert result.truncated is False
|
||||
|
||||
def test_object_value_truncation(self, small_truncator):
|
||||
"""Test object truncation where values are truncated to fit budget."""
|
||||
obj_with_long_values = {
|
||||
"key1": "very long string " * 10,
|
||||
"key2": "another long string " * 10,
|
||||
"key3": "third long string " * 10,
|
||||
}
|
||||
result = small_truncator._truncate_object(obj_with_long_values, 80)
|
||||
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.value, dict)
|
||||
|
||||
assert set(result.value.keys()).issubset(obj_with_long_values.keys())
|
||||
|
||||
# Values should be truncated if they exist
|
||||
for key, value in result.value.items():
|
||||
if isinstance(value, str):
|
||||
original_value = obj_with_long_values[key]
|
||||
# Value should be same or smaller
|
||||
assert len(value) <= len(original_value)
|
||||
|
||||
def test_object_key_dropping(self, small_truncator):
|
||||
"""Test object truncation where keys are dropped due to size constraints."""
|
||||
large_obj = {f"key{i:02d}": f"value{i}" for i in range(20)}
|
||||
result = small_truncator._truncate_object(large_obj, 50)
|
||||
|
||||
assert result.truncated is True
|
||||
assert len(result.value) < len(large_obj)
|
||||
|
||||
# Should maintain sorted key order
|
||||
result_keys = list(result.value.keys())
|
||||
assert result_keys == sorted(result_keys)
|
||||
|
||||
def test_object_with_nested_structures(self, small_truncator):
|
||||
"""Test object truncation with nested arrays and objects."""
|
||||
nested_obj = {"simple": "value", "array": [1, 2, 3, 4, 5], "nested": {"inner": "data", "more": ["a", "b", "c"]}}
|
||||
result = small_truncator._truncate_object(nested_obj, 60)
|
||||
|
||||
assert isinstance(result.value, dict)
|
||||
|
||||
|
||||
class TestSegmentBasedTruncation:
|
||||
"""Test the main truncate method that works with Segments."""
|
||||
|
||||
@pytest.fixture
|
||||
def truncator(self):
|
||||
return VariableTruncator()
|
||||
|
||||
@pytest.fixture
|
||||
def small_truncator(self):
|
||||
return VariableTruncator(string_length_limit=20, array_element_limit=3, max_size_bytes=200)
|
||||
|
||||
def test_integer_segment_no_truncation(self, truncator):
|
||||
"""Test that integer segments are never truncated."""
|
||||
segment = IntegerSegment(value=12345)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_boolean_as_integer_segment(self, truncator):
|
||||
"""Test boolean values in IntegerSegment are converted to int."""
|
||||
segment = IntegerSegment(value=True)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert isinstance(result.result, IntegerSegment)
|
||||
assert result.result.value == 1 # True converted to 1
|
||||
|
||||
def test_float_segment_no_truncation(self, truncator):
|
||||
"""Test that float segments are never truncated."""
|
||||
segment = FloatSegment(value=123.456)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_none_segment_no_truncation(self, truncator):
|
||||
"""Test that None segments are never truncated."""
|
||||
segment = NoneSegment()
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_file_segment_no_truncation(self, truncator, file):
|
||||
"""Test that file segments are never truncated."""
|
||||
file_segment = FileSegment(value=file)
|
||||
result = truncator.truncate(file_segment)
|
||||
assert result.result == file_segment
|
||||
assert result.truncated is False
|
||||
|
||||
def test_array_file_segment_no_truncation(self, truncator, file):
|
||||
"""Test that array file segments are never truncated."""
|
||||
|
||||
array_file_segment = ArrayFileSegment(value=[file] * 20)
|
||||
result = truncator.truncate(array_file_segment)
|
||||
assert result.result == array_file_segment
|
||||
assert result.truncated is False
|
||||
|
||||
def test_string_segment_small_no_truncation(self, truncator):
|
||||
"""Test small string segments are not truncated."""
|
||||
segment = StringSegment(value="hello world")
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_string_segment_large_truncation(self, small_truncator):
|
||||
"""Test large string segments are truncated."""
|
||||
long_text = "this is a very long string that will definitely exceed the limit"
|
||||
segment = StringSegment(value=long_text)
|
||||
result = small_truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
assert len(result.result.value) < len(long_text)
|
||||
assert result.result.value.endswith("...")
|
||||
|
||||
def test_array_segment_small_no_truncation(self, truncator):
|
||||
"""Test small array segments are not truncated."""
|
||||
from factories.variable_factory import build_segment
|
||||
|
||||
segment = build_segment([1, 2, 3])
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_array_segment_large_truncation(self, small_truncator):
|
||||
"""Test large array segments are truncated."""
|
||||
from factories.variable_factory import build_segment
|
||||
|
||||
large_array = list(range(10)) # Exceeds element limit of 3
|
||||
segment = build_segment(large_array)
|
||||
result = small_truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, ArraySegment)
|
||||
assert len(result.result.value) <= 3
|
||||
|
||||
def test_object_segment_small_no_truncation(self, truncator):
|
||||
"""Test small object segments are not truncated."""
|
||||
segment = ObjectSegment(value={"key": "value"})
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is False
|
||||
assert result.result == segment
|
||||
|
||||
def test_object_segment_large_truncation(self, small_truncator):
|
||||
"""Test large object segments are truncated."""
|
||||
large_obj = {f"key{i}": f"very long value {i}" * 5 for i in range(5)}
|
||||
segment = ObjectSegment(value=large_obj)
|
||||
result = small_truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, ObjectSegment)
|
||||
# Object should be smaller or equal than original
|
||||
original_size = small_truncator.calculate_json_size(large_obj)
|
||||
result_size = small_truncator.calculate_json_size(result.result.value)
|
||||
assert result_size <= original_size
|
||||
|
||||
def test_final_size_fallback_to_json_string(self, small_truncator):
|
||||
"""Test final fallback when truncated result still exceeds size limit."""
|
||||
# Create data that will still be large after initial truncation
|
||||
large_nested_data = {"data": ["very long string " * 5] * 5, "more": {"nested": "content " * 20}}
|
||||
segment = ObjectSegment(value=large_nested_data)
|
||||
|
||||
# Use very small limit to force JSON string fallback
|
||||
tiny_truncator = VariableTruncator(max_size_bytes=50)
|
||||
result = tiny_truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
# Should be JSON string with possible truncation
|
||||
assert len(result.result.value) <= 53 # 50 + "..." = 53
|
||||
|
||||
def test_final_size_fallback_string_truncation(self, small_truncator):
|
||||
"""Test final fallback for string that still exceeds limit."""
|
||||
# Create very long string that exceeds string length limit
|
||||
very_long_string = "x" * 6000 # Exceeds default string_length_limit of 5000
|
||||
segment = StringSegment(value=very_long_string)
|
||||
|
||||
# Use small limit to test string fallback path
|
||||
tiny_truncator = VariableTruncator(string_length_limit=100, max_size_bytes=50)
|
||||
result = tiny_truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
# Should be truncated due to string limit or final size limit
|
||||
assert len(result.result.value) <= 1000 # Much smaller than original
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error conditions."""
|
||||
|
||||
def test_empty_inputs(self):
|
||||
"""Test truncator with empty inputs."""
|
||||
truncator = VariableTruncator()
|
||||
|
||||
# Empty string
|
||||
result = truncator.truncate(StringSegment(value=""))
|
||||
assert not result.truncated
|
||||
assert result.result.value == ""
|
||||
|
||||
# Empty array
|
||||
from factories.variable_factory import build_segment
|
||||
|
||||
result = truncator.truncate(build_segment([]))
|
||||
assert not result.truncated
|
||||
assert result.result.value == []
|
||||
|
||||
# Empty object
|
||||
result = truncator.truncate(ObjectSegment(value={}))
|
||||
assert not result.truncated
|
||||
assert result.result.value == {}
|
||||
|
||||
def test_zero_and_negative_limits(self):
|
||||
"""Test truncator behavior with zero or very small limits."""
|
||||
# Zero string limit
|
||||
with pytest.raises(ValueError):
|
||||
truncator = VariableTruncator(string_length_limit=3)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
truncator = VariableTruncator(array_element_limit=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
truncator = VariableTruncator(max_size_bytes=0)
|
||||
|
||||
def test_unicode_and_special_characters(self):
|
||||
"""Test truncator with unicode and special characters."""
|
||||
truncator = VariableTruncator(string_length_limit=10)
|
||||
|
||||
# Unicode characters
|
||||
unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀" # Each emoji counts as 1 character
|
||||
result = truncator.truncate(StringSegment(value=unicode_text))
|
||||
if len(unicode_text) > 10:
|
||||
assert result.truncated is True
|
||||
|
||||
# Special JSON characters
|
||||
special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}'
|
||||
result = truncator.truncate(StringSegment(value=special_chars))
|
||||
assert isinstance(result.result, StringSegment)
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Test realistic integration scenarios."""
|
||||
|
||||
def test_workflow_output_scenario(self):
|
||||
"""Test truncation of typical workflow output data."""
|
||||
truncator = VariableTruncator()
|
||||
|
||||
workflow_data = {
|
||||
"result": "success",
|
||||
"data": {
|
||||
"users": [
|
||||
{"id": 1, "name": "Alice", "email": "alice@example.com"},
|
||||
{"id": 2, "name": "Bob", "email": "bob@example.com"},
|
||||
]
|
||||
* 3, # Multiply to make it larger
|
||||
"metadata": {
|
||||
"count": 6,
|
||||
"processing_time": "1.23s",
|
||||
"details": "x" * 200, # Long string but not too long
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
segment = ObjectSegment(value=workflow_data)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert isinstance(result.result, (ObjectSegment, StringSegment))
|
||||
# Should handle complex nested structure appropriately
|
||||
|
||||
def test_large_text_processing_scenario(self):
|
||||
"""Test truncation of large text data."""
|
||||
truncator = VariableTruncator(string_length_limit=100)
|
||||
|
||||
large_text = "This is a very long text document. " * 20 # Make it larger than limit
|
||||
|
||||
segment = StringSegment(value=large_text)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.truncated is True
|
||||
assert isinstance(result.result, StringSegment)
|
||||
assert len(result.result.value) <= 103 # 100 + "..."
|
||||
assert result.result.value.endswith("...")
|
||||
|
||||
def test_mixed_data_types_scenario(self):
|
||||
"""Test truncation with mixed data types in complex structure."""
|
||||
truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)
|
||||
|
||||
mixed_data = {
|
||||
"strings": ["short", "medium length", "very long string " * 3],
|
||||
"numbers": [1, 2.5, 999999],
|
||||
"booleans": [True, False, True],
|
||||
"nested": {
|
||||
"more_strings": ["nested string " * 2],
|
||||
"more_numbers": list(range(5)),
|
||||
"deep": {"level": 3, "content": "deep content " * 3},
|
||||
},
|
||||
"nulls": [None, None],
|
||||
}
|
||||
|
||||
segment = ObjectSegment(value=mixed_data)
|
||||
result = truncator.truncate(segment)
|
||||
|
||||
assert isinstance(result, TruncationResult)
|
||||
# Should handle all data types appropriately
|
||||
if result.truncated:
|
||||
# Verify the result is smaller or equal than original
|
||||
original_size = truncator.calculate_json_size(mixed_data)
|
||||
if isinstance(result.result, ObjectSegment):
|
||||
result_size = truncator.calculate_json_size(result.result.value)
|
||||
assert result_size <= original_size
|
||||
|
||||
def test_file_and_array_file_variable_mapping(self, file):
|
||||
truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)
|
||||
|
||||
mapping = {"array_file": [file]}
|
||||
truncated_mapping, truncated = truncator.truncate_variable_mapping(mapping)
|
||||
assert truncated is False
|
||||
assert truncated_mapping == mapping
|
||||
|
||||
|
||||
def test_dummy_variable_truncator_methods():
|
||||
"""Test DummyVariableTruncator methods work correctly."""
|
||||
truncator = DummyVariableTruncator()
|
||||
|
||||
# Test truncate_variable_mapping
|
||||
test_data: dict[str, Any] = {
|
||||
"key1": "value1",
|
||||
"key2": ["item1", "item2"],
|
||||
"large_array": list(range(2000)),
|
||||
}
|
||||
result, is_truncated = truncator.truncate_variable_mapping(test_data)
|
||||
|
||||
assert result == test_data
|
||||
assert not is_truncated
|
||||
|
||||
# Test truncate method
|
||||
segment = StringSegment(value="test string")
|
||||
result = truncator.truncate(segment)
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.result == segment
|
||||
assert result.truncated is False
|
||||
|
||||
segment = ArrayNumberSegment(value=list(range(2000)))
|
||||
result = truncator.truncate(segment)
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.result == segment
|
||||
assert result.truncated is False
|
||||
501
dify/api/tests/unit_tests/services/test_webhook_service.py
Normal file
501
dify/api/tests/unit_tests/services/test_webhook_service.py
Normal file
@@ -0,0 +1,501 @@
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class TestWebhookServiceUnit:
|
||||
"""Unit tests for WebhookService focusing on business logic without database dependencies."""
|
||||
|
||||
def test_extract_webhook_data_json(self):
|
||||
"""Test webhook data extraction from JSON request."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json", "Authorization": "Bearer token"},
|
||||
query_string="version=1&format=json",
|
||||
json={"message": "hello", "count": 42},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
assert webhook_data["method"] == "POST"
|
||||
assert webhook_data["headers"]["Authorization"] == "Bearer token"
|
||||
# Query params are now extracted as raw strings
|
||||
assert webhook_data["query_params"]["version"] == "1"
|
||||
assert webhook_data["query_params"]["format"] == "json"
|
||||
assert webhook_data["body"]["message"] == "hello"
|
||||
assert webhook_data["body"]["count"] == 42
|
||||
assert webhook_data["files"] == {}
|
||||
|
||||
def test_extract_webhook_data_query_params_remain_strings(self):
|
||||
"""Query parameters should be extracted as raw strings without automatic conversion."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="GET",
|
||||
headers={"Content-Type": "application/json"},
|
||||
query_string="count=42&threshold=3.14&enabled=true¬e=text",
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# After refactoring, raw extraction keeps query params as strings
|
||||
assert webhook_data["query_params"]["count"] == "42"
|
||||
assert webhook_data["query_params"]["threshold"] == "3.14"
|
||||
assert webhook_data["query_params"]["enabled"] == "true"
|
||||
assert webhook_data["query_params"]["note"] == "text"
|
||||
|
||||
def test_extract_webhook_data_form_urlencoded(self):
|
||||
"""Test webhook data extraction from form URL encoded request."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={"username": "test", "password": "secret"},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
assert webhook_data["method"] == "POST"
|
||||
assert webhook_data["body"]["username"] == "test"
|
||||
assert webhook_data["body"]["password"] == "secret"
|
||||
|
||||
def test_extract_webhook_data_multipart_with_files(self):
|
||||
"""Test webhook data extraction from multipart form with files."""
|
||||
app = Flask(__name__)
|
||||
|
||||
# Create a mock file
|
||||
file_content = b"test file content"
|
||||
file_storage = FileStorage(stream=BytesIO(file_content), filename="test.txt", content_type="text/plain")
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "multipart/form-data"},
|
||||
data={"message": "test", "upload": file_storage},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
|
||||
mock_process_files.return_value = {"upload": "mocked_file_obj"}
|
||||
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
assert webhook_data["method"] == "POST"
|
||||
assert webhook_data["body"]["message"] == "test"
|
||||
assert webhook_data["files"]["upload"] == "mocked_file_obj"
|
||||
mock_process_files.assert_called_once()
|
||||
|
||||
def test_extract_webhook_data_raw_text(self):
|
||||
"""Test webhook data extraction from raw text request."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook", method="POST", headers={"Content-Type": "text/plain"}, data="raw text content"
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
assert webhook_data["method"] == "POST"
|
||||
assert webhook_data["body"]["raw"] == "raw text content"
|
||||
|
||||
def test_extract_webhook_data_invalid_json(self):
|
||||
"""Test webhook data extraction with invalid JSON."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook", method="POST", headers={"Content-Type": "application/json"}, data="invalid json"
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
with pytest.raises(ValueError, match="Invalid JSON body"):
|
||||
WebhookService.extract_webhook_data(webhook_trigger)
|
||||
|
||||
def test_generate_webhook_response_default(self):
|
||||
"""Test webhook response generation with default values."""
|
||||
node_config = {"data": {}}
|
||||
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status_code == 200
|
||||
assert response_data["status"] == "success"
|
||||
assert "Webhook processed successfully" in response_data["message"]
|
||||
|
||||
def test_generate_webhook_response_custom_json(self):
|
||||
"""Test webhook response generation with custom JSON response."""
|
||||
node_config = {"data": {"status_code": 201, "response_body": '{"result": "created", "id": 123}'}}
|
||||
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status_code == 201
|
||||
assert response_data["result"] == "created"
|
||||
assert response_data["id"] == 123
|
||||
|
||||
def test_generate_webhook_response_custom_text(self):
|
||||
"""Test webhook response generation with custom text response."""
|
||||
node_config = {"data": {"status_code": 202, "response_body": "Request accepted for processing"}}
|
||||
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status_code == 202
|
||||
assert response_data["message"] == "Request accepted for processing"
|
||||
|
||||
def test_generate_webhook_response_invalid_json(self):
|
||||
"""Test webhook response generation with invalid JSON response."""
|
||||
node_config = {"data": {"status_code": 400, "response_body": '{"invalid": json}'}}
|
||||
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status_code == 400
|
||||
assert response_data["message"] == '{"invalid": json}'
|
||||
|
||||
def test_generate_webhook_response_empty_response_body(self):
|
||||
"""Test webhook response generation with empty response body."""
|
||||
node_config = {"data": {"status_code": 204, "response_body": ""}}
|
||||
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status_code == 204
|
||||
assert response_data["status"] == "success"
|
||||
assert "Webhook processed successfully" in response_data["message"]
|
||||
|
||||
def test_generate_webhook_response_array_json(self):
|
||||
"""Test webhook response generation with JSON array response."""
|
||||
node_config = {"data": {"status_code": 200, "response_body": '[{"id": 1}, {"id": 2}]'}}
|
||||
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
assert status_code == 200
|
||||
assert isinstance(response_data, list)
|
||||
assert len(response_data) == 2
|
||||
assert response_data[0]["id"] == 1
|
||||
assert response_data[1]["id"] == 2
|
||||
|
||||
@patch("services.trigger.webhook_service.ToolFileManager")
|
||||
@patch("services.trigger.webhook_service.file_factory")
|
||||
def test_process_file_uploads_success(self, mock_file_factory, mock_tool_file_manager):
|
||||
"""Test successful file upload processing."""
|
||||
# Mock ToolFileManager
|
||||
mock_tool_file_instance = MagicMock()
|
||||
mock_tool_file_manager.return_value = mock_tool_file_instance
|
||||
|
||||
# Mock file creation
|
||||
mock_tool_file = MagicMock()
|
||||
mock_tool_file.id = "test_file_id"
|
||||
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
|
||||
|
||||
# Mock file factory
|
||||
mock_file_obj = MagicMock()
|
||||
mock_file_factory.build_from_mapping.return_value = mock_file_obj
|
||||
|
||||
# Create mock files
|
||||
files = {
|
||||
"file1": MagicMock(filename="test1.txt", content_type="text/plain"),
|
||||
"file2": MagicMock(filename="test2.jpg", content_type="image/jpeg"),
|
||||
}
|
||||
|
||||
# Mock file reads
|
||||
files["file1"].read.return_value = b"content1"
|
||||
files["file2"].read.return_value = b"content2"
|
||||
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
result = WebhookService._process_file_uploads(files, webhook_trigger)
|
||||
|
||||
assert len(result) == 2
|
||||
assert "file1" in result
|
||||
assert "file2" in result
|
||||
|
||||
# Verify file processing was called for each file
|
||||
assert mock_tool_file_manager.call_count == 2
|
||||
assert mock_file_factory.build_from_mapping.call_count == 2
|
||||
|
||||
@patch("services.trigger.webhook_service.ToolFileManager")
|
||||
@patch("services.trigger.webhook_service.file_factory")
|
||||
def test_process_file_uploads_with_errors(self, mock_file_factory, mock_tool_file_manager):
|
||||
"""Test file upload processing with errors."""
|
||||
# Mock ToolFileManager
|
||||
mock_tool_file_instance = MagicMock()
|
||||
mock_tool_file_manager.return_value = mock_tool_file_instance
|
||||
|
||||
# Mock file creation
|
||||
mock_tool_file = MagicMock()
|
||||
mock_tool_file.id = "test_file_id"
|
||||
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
|
||||
|
||||
# Mock file factory
|
||||
mock_file_obj = MagicMock()
|
||||
mock_file_factory.build_from_mapping.return_value = mock_file_obj
|
||||
|
||||
# Create mock files, one will fail
|
||||
files = {
|
||||
"good_file": MagicMock(filename="test.txt", content_type="text/plain"),
|
||||
"bad_file": MagicMock(filename="test.bad", content_type="text/plain"),
|
||||
}
|
||||
|
||||
files["good_file"].read.return_value = b"content"
|
||||
files["bad_file"].read.side_effect = Exception("Read error")
|
||||
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
result = WebhookService._process_file_uploads(files, webhook_trigger)
|
||||
|
||||
# Should process the good file and skip the bad one
|
||||
assert len(result) == 1
|
||||
assert "good_file" in result
|
||||
assert "bad_file" not in result
|
||||
|
||||
def test_process_file_uploads_empty_filename(self):
|
||||
"""Test file upload processing with empty filename."""
|
||||
files = {
|
||||
"no_filename": MagicMock(filename="", content_type="text/plain"),
|
||||
"none_filename": MagicMock(filename=None, content_type="text/plain"),
|
||||
}
|
||||
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
result = WebhookService._process_file_uploads(files, webhook_trigger)
|
||||
|
||||
# Should skip files without filenames
|
||||
assert len(result) == 0
|
||||
|
||||
def test_validate_json_value_string(self):
|
||||
"""Test JSON value validation for string type."""
|
||||
# Valid string
|
||||
result = WebhookService._validate_json_value("name", "hello", "string")
|
||||
assert result == "hello"
|
||||
|
||||
# Invalid string (number) - should raise ValueError
|
||||
with pytest.raises(ValueError, match="Expected string, got int"):
|
||||
WebhookService._validate_json_value("name", 123, "string")
|
||||
|
||||
def test_validate_json_value_number(self):
|
||||
"""Test JSON value validation for number type."""
|
||||
# Valid integer
|
||||
result = WebhookService._validate_json_value("count", 42, "number")
|
||||
assert result == 42
|
||||
|
||||
# Valid float
|
||||
result = WebhookService._validate_json_value("price", 19.99, "number")
|
||||
assert result == 19.99
|
||||
|
||||
# Invalid number (string) - should raise ValueError
|
||||
with pytest.raises(ValueError, match="Expected number, got str"):
|
||||
WebhookService._validate_json_value("count", "42", "number")
|
||||
|
||||
def test_validate_json_value_bool(self):
|
||||
"""Test JSON value validation for boolean type."""
|
||||
# Valid boolean
|
||||
result = WebhookService._validate_json_value("enabled", True, "boolean")
|
||||
assert result is True
|
||||
|
||||
result = WebhookService._validate_json_value("enabled", False, "boolean")
|
||||
assert result is False
|
||||
|
||||
# Invalid boolean (string) - should raise ValueError
|
||||
with pytest.raises(ValueError, match="Expected boolean, got str"):
|
||||
WebhookService._validate_json_value("enabled", "true", "boolean")
|
||||
|
||||
def test_validate_json_value_object(self):
|
||||
"""Test JSON value validation for object type."""
|
||||
# Valid object
|
||||
result = WebhookService._validate_json_value("user", {"name": "John", "age": 30}, "object")
|
||||
assert result == {"name": "John", "age": 30}
|
||||
|
||||
# Invalid object (string) - should raise ValueError
|
||||
with pytest.raises(ValueError, match="Expected object, got str"):
|
||||
WebhookService._validate_json_value("user", "not_an_object", "object")
|
||||
|
||||
def test_validate_json_value_array_string(self):
|
||||
"""Test JSON value validation for array[string] type."""
|
||||
# Valid array of strings
|
||||
result = WebhookService._validate_json_value("tags", ["tag1", "tag2", "tag3"], "array[string]")
|
||||
assert result == ["tag1", "tag2", "tag3"]
|
||||
|
||||
# Invalid - not an array
|
||||
with pytest.raises(ValueError, match="Expected array of strings, got str"):
|
||||
WebhookService._validate_json_value("tags", "not_an_array", "array[string]")
|
||||
|
||||
# Invalid - array with non-strings
|
||||
with pytest.raises(ValueError, match="Expected array of strings, got list"):
|
||||
WebhookService._validate_json_value("tags", ["tag1", 123, "tag3"], "array[string]")
|
||||
|
||||
def test_validate_json_value_array_number(self):
|
||||
"""Test JSON value validation for array[number] type."""
|
||||
# Valid array of numbers
|
||||
result = WebhookService._validate_json_value("scores", [1, 2.5, 3, 4.7], "array[number]")
|
||||
assert result == [1, 2.5, 3, 4.7]
|
||||
|
||||
# Invalid - array with non-numbers
|
||||
with pytest.raises(ValueError, match="Expected array of numbers, got list"):
|
||||
WebhookService._validate_json_value("scores", [1, "2", 3], "array[number]")
|
||||
|
||||
def test_validate_json_value_array_bool(self):
|
||||
"""Test JSON value validation for array[boolean] type."""
|
||||
# Valid array of booleans
|
||||
result = WebhookService._validate_json_value("flags", [True, False, True], "array[boolean]")
|
||||
assert result == [True, False, True]
|
||||
|
||||
# Invalid - array with non-booleans
|
||||
with pytest.raises(ValueError, match="Expected array of booleans, got list"):
|
||||
WebhookService._validate_json_value("flags", [True, "false", True], "array[boolean]")
|
||||
|
||||
def test_validate_json_value_array_object(self):
|
||||
"""Test JSON value validation for array[object] type."""
|
||||
# Valid array of objects
|
||||
result = WebhookService._validate_json_value("users", [{"name": "John"}, {"name": "Jane"}], "array[object]")
|
||||
assert result == [{"name": "John"}, {"name": "Jane"}]
|
||||
|
||||
# Invalid - array with non-objects
|
||||
with pytest.raises(ValueError, match="Expected array of objects, got list"):
|
||||
WebhookService._validate_json_value("users", [{"name": "John"}, "not_object"], "array[object]")
|
||||
|
||||
def test_convert_form_value_string(self):
|
||||
"""Test form value conversion for string type."""
|
||||
result = WebhookService._convert_form_value("test", "hello", "string")
|
||||
assert result == "hello"
|
||||
|
||||
def test_convert_form_value_number(self):
|
||||
"""Test form value conversion for number type."""
|
||||
# Integer
|
||||
result = WebhookService._convert_form_value("count", "42", "number")
|
||||
assert result == 42
|
||||
|
||||
# Float
|
||||
result = WebhookService._convert_form_value("price", "19.99", "number")
|
||||
assert result == 19.99
|
||||
|
||||
# Invalid number
|
||||
with pytest.raises(ValueError, match="Cannot convert 'not_a_number' to number"):
|
||||
WebhookService._convert_form_value("count", "not_a_number", "number")
|
||||
|
||||
def test_convert_form_value_boolean(self):
|
||||
"""Test form value conversion for boolean type."""
|
||||
# True values
|
||||
assert WebhookService._convert_form_value("flag", "true", "boolean") is True
|
||||
assert WebhookService._convert_form_value("flag", "1", "boolean") is True
|
||||
assert WebhookService._convert_form_value("flag", "yes", "boolean") is True
|
||||
|
||||
# False values
|
||||
assert WebhookService._convert_form_value("flag", "false", "boolean") is False
|
||||
assert WebhookService._convert_form_value("flag", "0", "boolean") is False
|
||||
assert WebhookService._convert_form_value("flag", "no", "boolean") is False
|
||||
|
||||
# Invalid boolean
|
||||
with pytest.raises(ValueError, match="Cannot convert 'maybe' to boolean"):
|
||||
WebhookService._convert_form_value("flag", "maybe", "boolean")
|
||||
|
||||
def test_extract_and_validate_webhook_data_success(self):
|
||||
"""Test successful unified data extraction and validation."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
query_string="count=42&enabled=true",
|
||||
json={"message": "hello", "age": 25},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
node_config = {
|
||||
"data": {
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
"params": [
|
||||
{"name": "count", "type": "number", "required": True},
|
||||
{"name": "enabled", "type": "boolean", "required": True},
|
||||
],
|
||||
"body": [
|
||||
{"name": "message", "type": "string", "required": True},
|
||||
{"name": "age", "type": "number", "required": True},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
result = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
# Check that types are correctly converted
|
||||
assert result["query_params"]["count"] == 42 # Converted to int
|
||||
assert result["query_params"]["enabled"] is True # Converted to bool
|
||||
assert result["body"]["message"] == "hello" # Already string
|
||||
assert result["body"]["age"] == 25 # Already number
|
||||
|
||||
def test_extract_and_validate_webhook_data_invalid_json_error(self):
|
||||
"""Invalid JSON should bubble up as a ValueError with details."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data='{"invalid": }',
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
node_config = {
|
||||
"data": {
|
||||
"method": "post",
|
||||
"content_type": "application/json",
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid JSON body"):
|
||||
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
def test_extract_and_validate_webhook_data_validation_error(self):
|
||||
"""Test unified data extraction with validation error."""
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(
|
||||
"/webhook",
|
||||
method="GET", # Wrong method
|
||||
headers={"Content-Type": "application/json"},
|
||||
):
|
||||
webhook_trigger = MagicMock()
|
||||
node_config = {
|
||||
"data": {
|
||||
"method": "post", # Expects POST
|
||||
"content_type": "application/json",
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="HTTP method mismatch"):
|
||||
WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
def test_debug_mode_parameter_handling(self):
|
||||
"""Test that the debug mode parameter is properly handled in _prepare_webhook_execution."""
|
||||
from controllers.trigger.webhook import _prepare_webhook_execution
|
||||
|
||||
# Mock the WebhookService methods
|
||||
with (
|
||||
patch.object(WebhookService, "get_webhook_trigger_and_workflow") as mock_get_trigger,
|
||||
patch.object(WebhookService, "extract_and_validate_webhook_data") as mock_extract,
|
||||
):
|
||||
mock_trigger = MagicMock()
|
||||
mock_workflow = MagicMock()
|
||||
mock_config = {"data": {"test": "config"}}
|
||||
mock_data = {"test": "data"}
|
||||
|
||||
mock_get_trigger.return_value = (mock_trigger, mock_workflow, mock_config)
|
||||
mock_extract.return_value = mock_data
|
||||
|
||||
result = _prepare_webhook_execution("test_webhook", is_debug=False)
|
||||
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
|
||||
|
||||
# Reset mock
|
||||
mock_get_trigger.reset_mock()
|
||||
|
||||
result = _prepare_webhook_execution("test_webhook", is_debug=True)
|
||||
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
|
||||
@@ -0,0 +1,200 @@
|
||||
"""Comprehensive unit tests for WorkflowRunService class.
|
||||
|
||||
This test suite covers all pause state management operations including:
|
||||
- Retrieving pause state for workflow runs
|
||||
- Saving pause state with file uploads
|
||||
- Marking paused workflows as resumed
|
||||
- Error handling and edge cases
|
||||
- Database transaction management
|
||||
- Repository-based approach testing
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||
from services.workflow_run_service import (
|
||||
WorkflowRunService,
|
||||
)
|
||||
|
||||
|
||||
class TestDataFactory:
|
||||
"""Factory class for creating test data objects."""
|
||||
|
||||
@staticmethod
|
||||
def create_workflow_run_mock(
|
||||
id: str = "workflow-run-123",
|
||||
tenant_id: str = "tenant-456",
|
||||
app_id: str = "app-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
status: str | WorkflowExecutionStatus = "paused",
|
||||
pause_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowRun object."""
|
||||
mock_run = MagicMock()
|
||||
mock_run.id = id
|
||||
mock_run.tenant_id = tenant_id
|
||||
mock_run.app_id = app_id
|
||||
mock_run.workflow_id = workflow_id
|
||||
mock_run.status = status
|
||||
mock_run.pause_id = pause_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_run, key, value)
|
||||
|
||||
return mock_run
|
||||
|
||||
@staticmethod
|
||||
def create_workflow_pause_mock(
|
||||
id: str = "pause-123",
|
||||
tenant_id: str = "tenant-456",
|
||||
app_id: str = "app-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
workflow_execution_id: str = "workflow-execution-123",
|
||||
state_file_id: str = "file-456",
|
||||
resumed_at: datetime | None = None,
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowPauseModel object."""
|
||||
mock_pause = MagicMock()
|
||||
mock_pause.id = id
|
||||
mock_pause.tenant_id = tenant_id
|
||||
mock_pause.app_id = app_id
|
||||
mock_pause.workflow_id = workflow_id
|
||||
mock_pause.workflow_execution_id = workflow_execution_id
|
||||
mock_pause.state_file_id = state_file_id
|
||||
mock_pause.resumed_at = resumed_at
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_pause, key, value)
|
||||
|
||||
return mock_pause
|
||||
|
||||
@staticmethod
|
||||
def create_upload_file_mock(
|
||||
id: str = "file-456",
|
||||
key: str = "upload_files/test/state.json",
|
||||
name: str = "state.json",
|
||||
tenant_id: str = "tenant-456",
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock UploadFile object."""
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = id
|
||||
mock_file.key = key
|
||||
mock_file.name = name
|
||||
mock_file.tenant_id = tenant_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_file, key, value)
|
||||
|
||||
return mock_file
|
||||
|
||||
@staticmethod
|
||||
def create_pause_entity_mock(
|
||||
pause_model: MagicMock | None = None,
|
||||
upload_file: MagicMock | None = None,
|
||||
) -> _PrivateWorkflowPauseEntity:
|
||||
"""Create a mock _PrivateWorkflowPauseEntity object."""
|
||||
if pause_model is None:
|
||||
pause_model = TestDataFactory.create_workflow_pause_mock()
|
||||
if upload_file is None:
|
||||
upload_file = TestDataFactory.create_upload_file_mock()
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)
|
||||
|
||||
|
||||
class TestWorkflowRunService:
|
||||
"""Comprehensive unit tests for WorkflowRunService class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory(self):
|
||||
"""Create a mock session factory with proper session management."""
|
||||
mock_session = create_autospec(Session)
|
||||
|
||||
# Create a mock context manager for the session
|
||||
mock_session_cm = MagicMock()
|
||||
mock_session_cm.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session_cm.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Create a mock context manager for the transaction
|
||||
mock_transaction_cm = MagicMock()
|
||||
mock_transaction_cm.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_transaction_cm.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
mock_session.begin = MagicMock(return_value=mock_transaction_cm)
|
||||
|
||||
# Create mock factory that returns the context manager
|
||||
mock_factory = MagicMock(spec=sessionmaker)
|
||||
mock_factory.return_value = mock_session_cm
|
||||
|
||||
return mock_factory, mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_run_repository(self):
|
||||
"""Create a mock APIWorkflowRunRepository."""
|
||||
mock_repo = create_autospec(APIWorkflowRunRepository)
|
||||
return mock_repo
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_run_service(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Create WorkflowRunService instance with mocked dependencies."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(session_factory)
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_run_service_with_engine(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Create WorkflowRunService instance with Engine input."""
|
||||
mock_engine = create_autospec(Engine)
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(mock_engine)
|
||||
return service
|
||||
|
||||
# ==================== Initialization Tests ====================
|
||||
|
||||
def test_init_with_session_factory(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Test WorkflowRunService initialization with session_factory."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
assert service._session_factory == session_factory
|
||||
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||
|
||||
def test_init_with_engine(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Test WorkflowRunService initialization with Engine (should convert to sessionmaker)."""
|
||||
mock_engine = create_autospec(Engine)
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
|
||||
service = WorkflowRunService(mock_engine)
|
||||
|
||||
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
|
||||
assert service._session_factory == session_factory
|
||||
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||
|
||||
def test_init_with_default_dependencies(self, mock_session_factory):
|
||||
"""Test WorkflowRunService initialization with default dependencies."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
assert service._session_factory == session_factory
|
||||
@@ -0,0 +1,253 @@
|
||||
"""Test cases for MCP tool transformation functionality."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Provides a mock user object."""
|
||||
user = Mock()
|
||||
user.name = "Test User"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider(mock_user):
|
||||
"""Provides a mock MCPToolProvider with a loaded user."""
|
||||
provider = Mock(spec=MCPToolProvider)
|
||||
provider.load_user.return_value = mock_user
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_no_user():
|
||||
"""Provides a mock MCPToolProvider with no user."""
|
||||
provider = Mock(spec=MCPToolProvider)
|
||||
provider.load_user.return_value = None
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_full(mock_user):
|
||||
"""Provides a fully configured mock MCPToolProvider for detailed tests."""
|
||||
provider = Mock(spec=MCPToolProvider)
|
||||
provider.id = "provider-id-123"
|
||||
provider.server_identifier = "server-identifier-456"
|
||||
provider.name = "Test MCP Provider"
|
||||
provider.provider_icon = "icon.png"
|
||||
provider.authed = True
|
||||
provider.masked_server_url = "https://*****.com/mcp"
|
||||
provider.timeout = 30
|
||||
provider.sse_read_timeout = 300
|
||||
provider.masked_headers = {"Authorization": "Bearer *****"}
|
||||
provider.decrypted_headers = {"Authorization": "Bearer secret-token"}
|
||||
|
||||
# Mock timestamp
|
||||
mock_updated_at = Mock()
|
||||
mock_updated_at.timestamp.return_value = 1234567890
|
||||
provider.updated_at = mock_updated_at
|
||||
|
||||
provider.load_user.return_value = mock_user
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_mcp_tools():
|
||||
"""Provides sample MCP tools for testing."""
|
||||
return {
|
||||
"simple": MCPTool(
|
||||
name="simple_tool", description="A simple test tool", inputSchema={"type": "object", "properties": {}}
|
||||
),
|
||||
"none_desc": MCPTool(name="tool_none_desc", description=None, inputSchema={"type": "object", "properties": {}}),
|
||||
"complex": MCPTool(
|
||||
name="complex_tool",
|
||||
description="A tool with complex parameters",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "Input text"},
|
||||
"count": {"type": "integer", "description": "Number of items", "minimum": 1, "maximum": 100},
|
||||
"options": {"type": "array", "items": {"type": "string"}, "description": "List of options"},
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TestMCPToolTransform:
|
||||
"""Test cases for MCP tool transformation methods."""
|
||||
|
||||
def test_mcp_tool_to_user_tool_with_none_description(self, mock_provider):
|
||||
"""Test that mcp_tool_to_user_tool handles None description correctly."""
|
||||
# Create MCP tools with None description
|
||||
tools = [
|
||||
MCPTool(
|
||||
name="tool1",
|
||||
description=None, # This is the case that caused the error
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
),
|
||||
MCPTool(
|
||||
name="tool2",
|
||||
description=None,
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"param1": {"type": "string", "description": "A parameter"}},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider, tools)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(tool, ToolApiEntity) for tool in result)
|
||||
|
||||
# Check first tool
|
||||
assert result[0].name == "tool1"
|
||||
assert result[0].author == "Test User"
|
||||
assert isinstance(result[0].label, I18nObject)
|
||||
assert result[0].label.en_US == "tool1"
|
||||
assert isinstance(result[0].description, I18nObject)
|
||||
assert result[0].description.en_US == "" # Should be empty string, not None
|
||||
assert result[0].description.zh_Hans == ""
|
||||
|
||||
# Check second tool
|
||||
assert result[1].name == "tool2"
|
||||
assert result[1].description.en_US == ""
|
||||
assert result[1].description.zh_Hans == ""
|
||||
|
||||
def test_mcp_tool_to_user_tool_with_description(self, mock_provider):
|
||||
"""Test that mcp_tool_to_user_tool handles normal description correctly."""
|
||||
# Create MCP tools with description
|
||||
tools = [
|
||||
MCPTool(
|
||||
name="tool_with_desc",
|
||||
description="This is a test tool that does something useful",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
]
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider, tools)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ToolApiEntity)
|
||||
assert result[0].name == "tool_with_desc"
|
||||
assert result[0].description.en_US == "This is a test tool that does something useful"
|
||||
assert result[0].description.zh_Hans == "This is a test tool that does something useful"
|
||||
|
||||
def test_mcp_tool_to_user_tool_with_no_user(self, mock_provider_no_user):
|
||||
"""Test that mcp_tool_to_user_tool handles None user correctly."""
|
||||
# Create MCP tool
|
||||
tools = [MCPTool(name="tool1", description="Test tool", inputSchema={"type": "object", "properties": {}})]
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider_no_user, tools)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 1
|
||||
assert result[0].author == "Anonymous"
|
||||
|
||||
def test_mcp_tool_to_user_tool_with_complex_schema(self, mock_provider, sample_mcp_tools):
|
||||
"""Test that mcp_tool_to_user_tool correctly converts complex input schemas."""
|
||||
# Use complex tool from fixtures
|
||||
tools = [sample_mcp_tools["complex"]]
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.mcp_tool_to_user_tool(mock_provider, tools)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "complex_tool"
|
||||
assert result[0].parameters is not None
|
||||
# The actual parameter conversion is handled by convert_mcp_schema_to_parameter
|
||||
# which should be tested separately
|
||||
|
||||
def test_mcp_provider_to_user_provider_for_list(self, mock_provider_full):
|
||||
"""Test mcp_provider_to_user_provider with for_list=True."""
|
||||
# Set tools data with null description
|
||||
mock_provider_full.tools = '[{"name": "tool1", "description": null, "inputSchema": {}}]'
|
||||
|
||||
# Mock the to_entity and to_api_response methods
|
||||
mock_entity = Mock()
|
||||
mock_entity.to_api_response.return_value = {
|
||||
"name": "Test MCP Provider",
|
||||
"type": ToolProviderType.MCP,
|
||||
"is_team_authorization": True,
|
||||
"server_url": "https://*****.com/mcp",
|
||||
"provider_icon": "icon.png",
|
||||
"masked_headers": {"Authorization": "Bearer *****"},
|
||||
"updated_at": 1234567890,
|
||||
"labels": [],
|
||||
"author": "Test User",
|
||||
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
|
||||
"icon": "icon.png",
|
||||
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
|
||||
"masked_credentials": {},
|
||||
}
|
||||
mock_provider_full.to_entity.return_value = mock_entity
|
||||
|
||||
# Call the method with for_list=True
|
||||
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=True)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == "provider-id-123" # Should use provider.id when for_list=True
|
||||
assert result.name == "Test MCP Provider"
|
||||
assert result.type == ToolProviderType.MCP
|
||||
assert result.is_team_authorization is True
|
||||
assert result.server_url == "https://*****.com/mcp"
|
||||
assert len(result.tools) == 1
|
||||
assert result.tools[0].description.en_US == "" # Should handle None description
|
||||
|
||||
def test_mcp_provider_to_user_provider_not_for_list(self, mock_provider_full):
|
||||
"""Test mcp_provider_to_user_provider with for_list=False."""
|
||||
# Set tools data with description
|
||||
mock_provider_full.tools = '[{"name": "tool1", "description": "Tool description", "inputSchema": {}}]'
|
||||
|
||||
# Mock the to_entity and to_api_response methods
|
||||
mock_entity = Mock()
|
||||
mock_entity.to_api_response.return_value = {
|
||||
"name": "Test MCP Provider",
|
||||
"type": ToolProviderType.MCP,
|
||||
"is_team_authorization": True,
|
||||
"server_url": "https://*****.com/mcp",
|
||||
"provider_icon": "icon.png",
|
||||
"masked_headers": {"Authorization": "Bearer *****"},
|
||||
"updated_at": 1234567890,
|
||||
"labels": [],
|
||||
"configuration": {"timeout": "30", "sse_read_timeout": "300"},
|
||||
"original_headers": {"Authorization": "Bearer secret-token"},
|
||||
"author": "Test User",
|
||||
"description": I18nObject(en_US="Test MCP Provider Description", zh_Hans="Test MCP Provider Description"),
|
||||
"icon": "icon.png",
|
||||
"label": I18nObject(en_US="Test MCP Provider", zh_Hans="Test MCP Provider"),
|
||||
"masked_credentials": {},
|
||||
}
|
||||
mock_provider_full.to_entity.return_value = mock_entity
|
||||
|
||||
# Call the method with for_list=False
|
||||
result = ToolTransformService.mcp_provider_to_user_provider(mock_provider_full, for_list=False)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == "server-identifier-456" # Should use server_identifier when for_list=False
|
||||
assert result.server_identifier == "server-identifier-456"
|
||||
assert result.configuration is not None
|
||||
assert result.configuration.timeout == 30
|
||||
assert result.configuration.sse_read_timeout == 300
|
||||
assert result.original_headers == {"Authorization": "Bearer secret-token"}
|
||||
assert len(result.tools) == 1
|
||||
assert result.tools[0].description.en_US == "Tool description"
|
||||
@@ -0,0 +1,301 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.api_entities import ToolApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class TestToolTransformService:
|
||||
"""Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""
|
||||
|
||||
def test_convert_tool_with_parameter_override(self):
|
||||
"""Test that runtime parameters correctly override base parameters"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
base_param2 = Mock(spec=ToolParameter)
|
||||
base_param2.name = "param2"
|
||||
base_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param2.type = "string"
|
||||
base_param2.label = "Base Param 2"
|
||||
|
||||
# Create mock runtime parameters that override base parameters
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1" # Different label to verify override
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1, base_param2]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.author == "test_author"
|
||||
assert result.name == "test_tool"
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 2
|
||||
|
||||
# Find the overridden parameter
|
||||
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
|
||||
assert overridden_param is not None
|
||||
assert overridden_param.label == "Runtime Param 1" # Should be runtime version
|
||||
|
||||
# Find the non-overridden parameter
|
||||
original_param = next((p for p in result.parameters if p.name == "param2"), None)
|
||||
assert original_param is not None
|
||||
assert original_param.label == "Base Param 2" # Should be base version
|
||||
|
||||
def test_convert_tool_with_additional_runtime_parameters(self):
|
||||
"""Test that additional runtime parameters are added to the final list"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
# Create mock runtime parameters - one that overrides and one that's new
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1"
|
||||
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "runtime_only"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "Runtime Only Param"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 2
|
||||
|
||||
# Check that both parameters are present
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert "param1" in param_names
|
||||
assert "runtime_only" in param_names
|
||||
|
||||
# Verify the overridden parameter has runtime version
|
||||
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
|
||||
assert overridden_param is not None
|
||||
assert overridden_param.label == "Runtime Param 1"
|
||||
|
||||
# Verify the new runtime parameter is included
|
||||
new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
|
||||
assert new_param is not None
|
||||
assert new_param.label == "Runtime Only Param"
|
||||
|
||||
def test_convert_tool_with_non_form_runtime_parameters(self):
|
||||
"""Test that non-FORM runtime parameters are not added as new parameters"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
# Create mock runtime parameters with different forms
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1"
|
||||
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "llm_param"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.LLM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "LLM Param"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 1 # Only the FORM parameter should be present
|
||||
|
||||
# Check that only the FORM parameter is present
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert "param1" in param_names
|
||||
assert "llm_param" not in param_names
|
||||
|
||||
def test_convert_tool_with_empty_parameters(self):
|
||||
"""Test conversion with empty base and runtime parameters"""
|
||||
# Create mock tool with no parameters
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = []
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = []
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 0
|
||||
|
||||
def test_convert_tool_with_none_parameters(self):
|
||||
"""Test conversion when base parameters is None"""
|
||||
# Create mock tool with None parameters
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = None
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = []
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 0
|
||||
|
||||
def test_convert_tool_parameter_order_preserved(self):
|
||||
"""Test that parameter order is preserved correctly"""
|
||||
# Create mock base parameters in specific order
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
base_param2 = Mock(spec=ToolParameter)
|
||||
base_param2.name = "param2"
|
||||
base_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param2.type = "string"
|
||||
base_param2.label = "Base Param 2"
|
||||
|
||||
base_param3 = Mock(spec=ToolParameter)
|
||||
base_param3.name = "param3"
|
||||
base_param3.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param3.type = "string"
|
||||
base_param3.label = "Base Param 3"
|
||||
|
||||
# Create runtime parameter that overrides middle parameter
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "param2"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "Runtime Param 2"
|
||||
|
||||
# Create new runtime parameter
|
||||
runtime_param4 = Mock(spec=ToolParameter)
|
||||
runtime_param4.name = "param4"
|
||||
runtime_param4.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param4.type = "string"
|
||||
runtime_param4.label = "Runtime Param 4"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 4
|
||||
|
||||
# Check that order is maintained: base parameters first, then new runtime parameters
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert param_names == ["param1", "param2", "param3", "param4"]
|
||||
|
||||
# Verify that param2 was overridden with runtime version
|
||||
param2 = result.parameters[1]
|
||||
assert param2.name == "param2"
|
||||
assert param2.label == "Runtime Param 2"
|
||||
@@ -0,0 +1,377 @@
|
||||
"""Simplified unit tests for DraftVarLoader focusing on core functionality."""
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.variables.segments import ObjectSegment, StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from models.model import UploadFile
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from services.workflow_draft_variable_service import DraftVarLoader
|
||||
|
||||
|
||||
class TestDraftVarLoaderSimple:
|
||||
"""Simplified unit tests for DraftVarLoader core methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine(self) -> Engine:
|
||||
return Mock(spec=Engine)
|
||||
|
||||
@pytest.fixture
|
||||
def draft_var_loader(self, mock_engine):
|
||||
"""Create DraftVarLoader instance for testing."""
|
||||
return DraftVarLoader(
|
||||
engine=mock_engine, app_id="test-app-id", tenant_id="test-tenant-id", fallback_variables=[]
|
||||
)
|
||||
|
||||
def test_load_offloaded_variable_string_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with string type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test.txt"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.STRING
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_variable"
|
||||
draft_var.description = "test description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_variable"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_content = "This is the full string content"
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_content.encode()
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_variable"
|
||||
mock_variable.value = StringSegment(value=test_content)
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_variable")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_variable"
|
||||
assert variable.description == "test description"
|
||||
assert variable.value == test_content
|
||||
|
||||
# Verify storage was called correctly
|
||||
mock_storage.load.assert_called_once_with("storage/key/test.txt")
|
||||
|
||||
def test_load_offloaded_variable_object_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with object type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test.json"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.OBJECT
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_object"
|
||||
draft_var.description = "test description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_object"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_object = {"key1": "value1", "key2": 42}
|
||||
test_json_content = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
mock_segment = ObjectSegment(value=test_object)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_object"
|
||||
mock_variable.value = mock_segment
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_object")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_object"
|
||||
assert variable.description == "test description"
|
||||
assert variable.value == test_object
|
||||
|
||||
# Verify method calls
|
||||
mock_storage.load.assert_called_once_with("storage/key/test.json")
|
||||
mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object)
|
||||
|
||||
def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader):
|
||||
"""Test that assertion error is raised when variable_file is None."""
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.variable_file = None
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
def test_load_offloaded_variable_missing_upload_file_unit(self, draft_var_loader):
|
||||
"""Test that assertion error is raised when upload_file is None."""
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.upload_file = None
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
def test_load_variables_empty_selectors_unit(self, draft_var_loader):
|
||||
"""Test load_variables returns empty list for empty selectors."""
|
||||
result = draft_var_loader.load_variables([])
|
||||
assert result == []
|
||||
|
||||
def test_selector_to_tuple_unit(self, draft_var_loader):
|
||||
"""Test _selector_to_tuple method."""
|
||||
selector = ["node_id", "var_name", "extra_field"]
|
||||
result = draft_var_loader._selector_to_tuple(selector)
|
||||
assert result == ("node_id", "var_name")
|
||||
|
||||
def test_load_offloaded_variable_number_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with number type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test_number.json"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.NUMBER
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_number"
|
||||
draft_var.description = "test number description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_number"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_number = 123.45
|
||||
test_json_content = json.dumps(test_number)
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
from core.variables.segments import FloatSegment
|
||||
|
||||
mock_segment = FloatSegment(value=test_number)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_number"
|
||||
mock_variable.value = mock_segment
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_number")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_number"
|
||||
assert variable.description == "test number description"
|
||||
|
||||
# Verify method calls
|
||||
mock_storage.load.assert_called_once_with("storage/key/test_number.json")
|
||||
mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number)
|
||||
|
||||
def test_load_offloaded_variable_array_type_unit(self, draft_var_loader):
|
||||
"""Test _load_offloaded_variable with array type - isolated unit test."""
|
||||
# Create mock objects
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/test_array.json"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.ARRAY_ANY
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
draft_var.id = "draft-var-id"
|
||||
draft_var.node_id = "test-node-id"
|
||||
draft_var.name = "test_array"
|
||||
draft_var.description = "test array description"
|
||||
draft_var.get_selector.return_value = ["test-node-id", "test_array"]
|
||||
draft_var.variable_file = variable_file
|
||||
|
||||
test_array = ["item1", "item2", "item3"]
|
||||
test_json_content = json.dumps(test_array)
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = test_json_content.encode()
|
||||
|
||||
with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
|
||||
mock_segment = ArrayAnySegment(value=test_array)
|
||||
mock_build_segment.return_value = mock_segment
|
||||
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
mock_variable = Mock()
|
||||
mock_variable.id = "draft-var-id"
|
||||
mock_variable.name = "test_array"
|
||||
mock_variable.value = mock_segment
|
||||
mock_segment_to_variable.return_value = mock_variable
|
||||
|
||||
# Execute the method
|
||||
selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)
|
||||
|
||||
# Verify results
|
||||
assert selector_tuple == ("test-node-id", "test_array")
|
||||
assert variable.id == "draft-var-id"
|
||||
assert variable.name == "test_array"
|
||||
assert variable.description == "test array description"
|
||||
|
||||
# Verify method calls
|
||||
mock_storage.load.assert_called_once_with("storage/key/test_array.json")
|
||||
mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array)
|
||||
|
||||
def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader):
|
||||
"""Test load_variables method with mix of regular and offloaded variables."""
|
||||
selectors = [["node1", "regular_var"], ["node2", "offloaded_var"]]
|
||||
|
||||
# Mock regular variable
|
||||
regular_draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
regular_draft_var.is_truncated.return_value = False
|
||||
regular_draft_var.node_id = "node1"
|
||||
regular_draft_var.name = "regular_var"
|
||||
regular_draft_var.get_value.return_value = StringSegment(value="regular_value")
|
||||
regular_draft_var.get_selector.return_value = ["node1", "regular_var"]
|
||||
regular_draft_var.id = "regular-var-id"
|
||||
regular_draft_var.description = "regular description"
|
||||
|
||||
# Mock offloaded variable
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.key = "storage/key/offloaded.txt"
|
||||
|
||||
variable_file = Mock(spec=WorkflowDraftVariableFile)
|
||||
variable_file.value_type = SegmentType.STRING
|
||||
variable_file.upload_file = upload_file
|
||||
|
||||
offloaded_draft_var = Mock(spec=WorkflowDraftVariable)
|
||||
offloaded_draft_var.is_truncated.return_value = True
|
||||
offloaded_draft_var.node_id = "node2"
|
||||
offloaded_draft_var.name = "offloaded_var"
|
||||
offloaded_draft_var.get_selector.return_value = ["node2", "offloaded_var"]
|
||||
offloaded_draft_var.variable_file = variable_file
|
||||
offloaded_draft_var.id = "offloaded-var-id"
|
||||
offloaded_draft_var.description = "offloaded description"
|
||||
|
||||
draft_vars = [regular_draft_var, offloaded_draft_var]
|
||||
|
||||
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
|
||||
mock_session = Mock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
|
||||
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
|
||||
):
|
||||
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
|
||||
with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
|
||||
# Mock regular variable creation
|
||||
regular_variable = Mock()
|
||||
regular_variable.selector = ["node1", "regular_var"]
|
||||
|
||||
# Mock offloaded variable creation
|
||||
offloaded_variable = Mock()
|
||||
offloaded_variable.selector = ["node2", "offloaded_var"]
|
||||
|
||||
mock_segment_to_variable.return_value = regular_variable
|
||||
|
||||
with patch("services.workflow_draft_variable_service.storage") as mock_storage:
|
||||
mock_storage.load.return_value = b"offloaded_content"
|
||||
|
||||
with patch.object(draft_var_loader, "_load_offloaded_variable") as mock_load_offloaded:
|
||||
mock_load_offloaded.return_value = (("node2", "offloaded_var"), offloaded_variable)
|
||||
|
||||
with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor_cls:
|
||||
mock_executor = Mock()
|
||||
mock_executor_cls.return_value.__enter__.return_value = mock_executor
|
||||
mock_executor.map.return_value = [(("node2", "offloaded_var"), offloaded_variable)]
|
||||
|
||||
# Execute the method
|
||||
result = draft_var_loader.load_variables(selectors)
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify service method was called
|
||||
mock_service.get_draft_variables_by_selectors.assert_called_once_with(
|
||||
draft_var_loader._app_id, selectors
|
||||
)
|
||||
|
||||
# Verify offloaded variable loading was called
|
||||
mock_load_offloaded.assert_called_once_with(offloaded_draft_var)
|
||||
|
||||
def test_load_variables_all_offloaded_variables_unit(self, draft_var_loader):
|
||||
"""Test load_variables method with only offloaded variables."""
|
||||
selectors = [["node1", "offloaded_var1"], ["node2", "offloaded_var2"]]
|
||||
|
||||
# Mock first offloaded variable
|
||||
offloaded_var1 = Mock(spec=WorkflowDraftVariable)
|
||||
offloaded_var1.is_truncated.return_value = True
|
||||
offloaded_var1.node_id = "node1"
|
||||
offloaded_var1.name = "offloaded_var1"
|
||||
|
||||
# Mock second offloaded variable
|
||||
offloaded_var2 = Mock(spec=WorkflowDraftVariable)
|
||||
offloaded_var2.is_truncated.return_value = True
|
||||
offloaded_var2.node_id = "node2"
|
||||
offloaded_var2.name = "offloaded_var2"
|
||||
|
||||
draft_vars = [offloaded_var1, offloaded_var2]
|
||||
|
||||
with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
|
||||
mock_session = Mock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.get_draft_variables_by_selectors.return_value = draft_vars
|
||||
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service
|
||||
):
|
||||
with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
|
||||
with patch("services.workflow_draft_variable_service.ThreadPoolExecutor") as mock_executor_cls:
|
||||
mock_executor = Mock()
|
||||
mock_executor_cls.return_value.__enter__.return_value = mock_executor
|
||||
mock_executor.map.return_value = [
|
||||
(("node1", "offloaded_var1"), Mock()),
|
||||
(("node2", "offloaded_var2"), Mock()),
|
||||
]
|
||||
|
||||
# Execute the method
|
||||
result = draft_var_loader.load_variables(selectors)
|
||||
|
||||
# Verify results - since we have only offloaded variables, should have 2 results
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify ThreadPoolExecutor was used
|
||||
mock_executor_cls.assert_called_once_with(max_workers=10)
|
||||
mock_executor.map.assert_called_once()
|
||||
@@ -0,0 +1,432 @@
|
||||
# test for api/services/workflow/workflow_converter.py
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
AdvancedCompletionPromptTemplateEntity,
|
||||
DatasetEntity,
|
||||
DatasetRetrieveConfigEntity,
|
||||
ExternalDataVariableEntity,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
VariableEntityType,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from models.model import AppMode
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_variables():
|
||||
value = [
|
||||
VariableEntity(
|
||||
variable="text_input",
|
||||
label="text-input",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
),
|
||||
VariableEntity(
|
||||
variable="paragraph",
|
||||
label="paragraph",
|
||||
type=VariableEntityType.PARAGRAPH,
|
||||
),
|
||||
VariableEntity(
|
||||
variable="select",
|
||||
label="select",
|
||||
type=VariableEntityType.SELECT,
|
||||
),
|
||||
]
|
||||
return value
|
||||
|
||||
|
||||
def test__convert_to_start_node(default_variables):
|
||||
# act
|
||||
result = WorkflowConverter()._convert_to_start_node(default_variables)
|
||||
|
||||
# assert
|
||||
assert isinstance(result["data"]["variables"][0]["type"], str)
|
||||
assert result["data"]["variables"][0]["type"] == "text-input"
|
||||
assert result["data"]["variables"][0]["variable"] == "text_input"
|
||||
assert result["data"]["variables"][1]["variable"] == "paragraph"
|
||||
assert result["data"]["variables"][2]["variable"] == "select"
|
||||
|
||||
|
||||
def test__convert_to_http_request_node_for_chatbot(default_variables):
|
||||
"""
|
||||
Test convert to http request nodes for chatbot
|
||||
:return:
|
||||
"""
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.CHAT
|
||||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
tenant_id="tenant_id",
|
||||
name="api-1",
|
||||
api_key="encrypted_api_key",
|
||||
api_endpoint="https://dify.ai",
|
||||
)
|
||||
|
||||
mock_api_based_extension.id = api_based_extension_id
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
|
||||
|
||||
encrypter.decrypt_token = MagicMock(return_value="api_key")
|
||||
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(
|
||||
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
|
||||
)
|
||||
]
|
||||
|
||||
nodes, _ = workflow_converter._convert_to_http_request_node(
|
||||
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
|
||||
)
|
||||
|
||||
assert len(nodes) == 2
|
||||
assert nodes[0]["data"]["type"] == "http-request"
|
||||
|
||||
http_request_node = nodes[0]
|
||||
|
||||
assert http_request_node["data"]["method"] == "post"
|
||||
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
|
||||
assert http_request_node["data"]["authorization"]["type"] == "api-key"
|
||||
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
|
||||
assert http_request_node["data"]["body"]["type"] == "json"
|
||||
|
||||
body_data = http_request_node["data"]["body"]["data"]
|
||||
|
||||
assert body_data
|
||||
|
||||
body_data_json = json.loads(body_data)
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
|
||||
|
||||
body_params = body_data_json["params"]
|
||||
assert body_params["app_id"] == app_model.id
|
||||
assert body_params["tool_variable"] == external_data_variables[0].variable
|
||||
assert len(body_params["inputs"]) == 3
|
||||
assert body_params["query"] == "{{#sys.query#}}" # for chatbot
|
||||
|
||||
code_node = nodes[1]
|
||||
assert code_node["data"]["type"] == "code"
|
||||
|
||||
|
||||
def test__convert_to_http_request_node_for_workflow_app(default_variables):
|
||||
"""
|
||||
Test convert to http request nodes for workflow app
|
||||
:return:
|
||||
"""
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app_id"
|
||||
app_model.tenant_id = "tenant_id"
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
|
||||
api_based_extension_id = "api_based_extension_id"
|
||||
mock_api_based_extension = APIBasedExtension(
|
||||
tenant_id="tenant_id",
|
||||
name="api-1",
|
||||
api_key="encrypted_api_key",
|
||||
api_endpoint="https://dify.ai",
|
||||
)
|
||||
mock_api_based_extension.id = api_based_extension_id
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension)
|
||||
|
||||
encrypter.decrypt_token = MagicMock(return_value="api_key")
|
||||
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(
|
||||
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
|
||||
)
|
||||
]
|
||||
|
||||
nodes, _ = workflow_converter._convert_to_http_request_node(
|
||||
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
|
||||
)
|
||||
|
||||
assert len(nodes) == 2
|
||||
assert nodes[0]["data"]["type"] == "http-request"
|
||||
|
||||
http_request_node = nodes[0]
|
||||
|
||||
assert http_request_node["data"]["method"] == "post"
|
||||
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
|
||||
assert http_request_node["data"]["authorization"]["type"] == "api-key"
|
||||
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
|
||||
assert http_request_node["data"]["body"]["type"] == "json"
|
||||
|
||||
body_data = http_request_node["data"]["body"]["data"]
|
||||
|
||||
assert body_data
|
||||
|
||||
body_data_json = json.loads(body_data)
|
||||
assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY
|
||||
|
||||
body_params = body_data_json["params"]
|
||||
assert body_params["app_id"] == app_model.id
|
||||
assert body_params["tool_variable"] == external_data_variables[0].variable
|
||||
assert len(body_params["inputs"]) == 3
|
||||
assert body_params["query"] == ""
|
||||
|
||||
code_node = nodes[1]
|
||||
assert code_node["data"]["type"] == "code"
|
||||
|
||||
|
||||
def test__convert_to_knowledge_retrieval_node_for_chatbot():
|
||||
new_app_mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
dataset_config = DatasetEntity(
|
||||
dataset_ids=["dataset_id_1", "dataset_id_2"],
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=5,
|
||||
score_threshold=0.8,
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
|
||||
|
||||
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
|
||||
)
|
||||
assert node is not None
|
||||
|
||||
assert node["data"]["type"] == "knowledge-retrieval"
|
||||
assert node["data"]["query_variable_selector"] == ["sys", "query"]
|
||||
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
|
||||
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
|
||||
assert node["data"]["multiple_retrieval_config"] == {
|
||||
"top_k": dataset_config.retrieve_config.top_k,
|
||||
"score_threshold": dataset_config.retrieve_config.score_threshold,
|
||||
"reranking_model": dataset_config.retrieve_config.reranking_model,
|
||||
}
|
||||
|
||||
|
||||
def test__convert_to_knowledge_retrieval_node_for_workflow_app():
|
||||
new_app_mode = AppMode.WORKFLOW
|
||||
|
||||
dataset_config = DatasetEntity(
|
||||
dataset_ids=["dataset_id_1", "dataset_id_2"],
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable="query",
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=5,
|
||||
score_threshold=0.8,
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
|
||||
|
||||
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
|
||||
)
|
||||
assert node is not None
|
||||
|
||||
assert node["data"]["type"] == "knowledge-retrieval"
|
||||
assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable]
|
||||
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
|
||||
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
|
||||
assert node["data"]["multiple_retrieval_config"] == {
|
||||
"top_k": dataset_config.retrieve_config.top_k,
|
||||
"score_threshold": dataset_config.retrieve_config.score_threshold,
|
||||
"reranking_model": dataset_config.retrieve_config.reranking_model,
|
||||
}
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
|
||||
new_app_mode = AppMode.ADVANCED_CHAT
|
||||
model = "gpt-4"
|
||||
model_mode = LLMMode.CHAT
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
model_config_mock.stop = []
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
template = prompt_template.simple_prompt_template
|
||||
assert template is not None
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n"
|
||||
assert llm_node["data"]["context"]["enabled"] is False
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables):
|
||||
new_app_mode = AppMode.ADVANCED_CHAT
|
||||
model = "gpt-3.5-turbo-instruct"
|
||||
model_mode = LLMMode.COMPLETION
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
model_config_mock.stop = []
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
template = prompt_template.simple_prompt_template
|
||||
assert template is not None
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"]["text"] == template + "\n"
|
||||
assert llm_node["data"]["context"]["enabled"] is False
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables):
|
||||
new_app_mode = AppMode.ADVANCED_CHAT
|
||||
model = "gpt-4"
|
||||
model_mode = LLMMode.CHAT
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
model_config_mock.stop = []
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
|
||||
messages=[
|
||||
AdvancedChatMessageEntity(
|
||||
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
),
|
||||
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
|
||||
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
assert isinstance(llm_node["data"]["prompt_template"], list)
|
||||
assert prompt_template.advanced_chat_prompt_template is not None
|
||||
assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages)
|
||||
template = prompt_template.advanced_chat_prompt_template.messages[0].text
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"][0]["text"] == template
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables):
|
||||
new_app_mode = AppMode.ADVANCED_CHAT
|
||||
model = "gpt-3.5-turbo-instruct"
|
||||
model_mode = LLMMode.COMPLETION
|
||||
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
model_config_mock.stop = []
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
|
||||
prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ",
|
||||
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
|
||||
),
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
original_app_mode=AppMode.CHAT,
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
assert isinstance(llm_node["data"]["prompt_template"], dict)
|
||||
assert prompt_template.advanced_completion_prompt_template is not None
|
||||
template = prompt_template.advanced_completion_prompt_template.prompt
|
||||
for v in default_variables:
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"]["text"] == template
|
||||
@@ -0,0 +1,127 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_setup():
|
||||
mock_session_maker = MagicMock()
|
||||
workflow_service = WorkflowService(mock_session_maker)
|
||||
session = MagicMock(spec=Session)
|
||||
tenant_id = "test-tenant-id"
|
||||
workflow_id = "test-workflow-id"
|
||||
|
||||
# Mock workflow
|
||||
workflow = MagicMock(spec=Workflow)
|
||||
workflow.id = workflow_id
|
||||
workflow.tenant_id = tenant_id
|
||||
workflow.version = "1.0" # Not a draft
|
||||
workflow.tool_published = False # Not published as a tool by default
|
||||
|
||||
# Mock app
|
||||
app = MagicMock(spec=App)
|
||||
app.id = "test-app-id"
|
||||
app.name = "Test App"
|
||||
app.workflow_id = None # Not used by an app by default
|
||||
|
||||
return {
|
||||
"workflow_service": workflow_service,
|
||||
"session": session,
|
||||
"tenant_id": tenant_id,
|
||||
"workflow_id": workflow_id,
|
||||
"workflow": workflow,
|
||||
"app": app,
|
||||
}
|
||||
|
||||
|
||||
def test_delete_workflow_success(workflow_setup):
|
||||
# Setup mocks
|
||||
|
||||
# Mock the tool provider query to return None (not published as a tool)
|
||||
workflow_setup["session"].query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
workflow_setup["session"].scalar = MagicMock(
|
||||
side_effect=[workflow_setup["workflow"], None]
|
||||
) # Return workflow first, then None for app
|
||||
|
||||
# Call the method
|
||||
result = workflow_setup["workflow_service"].delete_workflow(
|
||||
session=workflow_setup["session"],
|
||||
workflow_id=workflow_setup["workflow_id"],
|
||||
tenant_id=workflow_setup["tenant_id"],
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
workflow_setup["session"].delete.assert_called_once_with(workflow_setup["workflow"])
|
||||
|
||||
|
||||
def test_delete_workflow_draft_error(workflow_setup):
|
||||
# Setup mocks
|
||||
workflow_setup["workflow"].version = "draft"
|
||||
workflow_setup["session"].scalar = MagicMock(return_value=workflow_setup["workflow"])
|
||||
|
||||
# Call the method and verify exception
|
||||
with pytest.raises(DraftWorkflowDeletionError):
|
||||
workflow_setup["workflow_service"].delete_workflow(
|
||||
session=workflow_setup["session"],
|
||||
workflow_id=workflow_setup["workflow_id"],
|
||||
tenant_id=workflow_setup["tenant_id"],
|
||||
)
|
||||
|
||||
# Verify
|
||||
workflow_setup["session"].delete.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_workflow_in_use_by_app_error(workflow_setup):
|
||||
# Setup mocks
|
||||
workflow_setup["app"].workflow_id = workflow_setup["workflow_id"]
|
||||
workflow_setup["session"].scalar = MagicMock(
|
||||
side_effect=[workflow_setup["workflow"], workflow_setup["app"]]
|
||||
) # Return workflow first, then app
|
||||
|
||||
# Call the method and verify exception
|
||||
with pytest.raises(WorkflowInUseError) as excinfo:
|
||||
workflow_setup["workflow_service"].delete_workflow(
|
||||
session=workflow_setup["session"],
|
||||
workflow_id=workflow_setup["workflow_id"],
|
||||
tenant_id=workflow_setup["tenant_id"],
|
||||
)
|
||||
|
||||
# Verify error message contains app name
|
||||
assert "Cannot delete workflow that is currently in use by app" in str(excinfo.value)
|
||||
|
||||
# Verify
|
||||
workflow_setup["session"].delete.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_workflow_published_as_tool_error(workflow_setup):
|
||||
# Setup mocks
|
||||
from models.tools import WorkflowToolProvider
|
||||
|
||||
# Mock the tool provider query
|
||||
mock_tool_provider = MagicMock(spec=WorkflowToolProvider)
|
||||
workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider
|
||||
|
||||
workflow_setup["session"].scalar = MagicMock(
|
||||
side_effect=[workflow_setup["workflow"], None]
|
||||
) # Return workflow first, then None for app
|
||||
|
||||
# Call the method and verify exception
|
||||
with pytest.raises(WorkflowInUseError) as excinfo:
|
||||
workflow_setup["workflow_service"].delete_workflow(
|
||||
session=workflow_setup["session"],
|
||||
workflow_id=workflow_setup["workflow_id"],
|
||||
tenant_id=workflow_setup["tenant_id"],
|
||||
)
|
||||
|
||||
# Verify error message
|
||||
assert "Cannot delete workflow that is published as a tool" in str(excinfo.value)
|
||||
|
||||
# Verify
|
||||
workflow_setup["session"].delete.assert_not_called()
|
||||
@@ -0,0 +1,477 @@
|
||||
import dataclasses
|
||||
import secrets
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.segments import StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeType
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.account import Account
|
||||
from models.enums import DraftVariableType
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowDraftVariable,
|
||||
WorkflowDraftVariableFile,
|
||||
WorkflowNodeExecutionModel,
|
||||
is_system_variable_editable,
|
||||
)
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
VariableResetError,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine() -> Engine:
|
||||
return Mock(spec=Engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(mock_engine) -> Session:
|
||||
mock_session = Mock(spec=Session)
|
||||
mock_session.get_bind.return_value = mock_engine
|
||||
return mock_session
|
||||
|
||||
|
||||
class TestDraftVariableSaver:
|
||||
def _get_test_app_id(self):
|
||||
suffix = secrets.token_hex(6)
|
||||
return f"test_app_id_{suffix}"
|
||||
|
||||
def test__should_variable_be_visible(self):
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = Account(name="test", email="test@example.com")
|
||||
mock_user.id = str(uuid.uuid4())
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id=test_app_id,
|
||||
node_id="test_node_id",
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="test_execution_id",
|
||||
user=mock_user,
|
||||
)
|
||||
assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
|
||||
assert saver._should_variable_be_visible("123", NodeType.START, "output") == True
|
||||
|
||||
def test__normalize_variable_for_start_node(self):
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TestCase:
|
||||
name: str
|
||||
input_node_id: str
|
||||
input_name: str
|
||||
expected_node_id: str
|
||||
expected_name: str
|
||||
|
||||
_NODE_ID = "1747228642872"
|
||||
cases = [
|
||||
TestCase(
|
||||
name="name with `sys.` prefix should return the system node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="sys.workflow_id",
|
||||
expected_node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||
expected_name="workflow_id",
|
||||
),
|
||||
TestCase(
|
||||
name="name without `sys.` prefix should return the original input node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="start_input",
|
||||
expected_node_id=_NODE_ID,
|
||||
expected_name="start_input",
|
||||
),
|
||||
TestCase(
|
||||
name="dummy_variable should return the original input node_id",
|
||||
input_node_id=_NODE_ID,
|
||||
input_name="__dummy__",
|
||||
expected_node_id=_NODE_ID,
|
||||
expected_name="__dummy__",
|
||||
),
|
||||
]
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = MagicMock()
|
||||
test_app_id = self._get_test_app_id()
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id=test_app_id,
|
||||
node_id=_NODE_ID,
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="test_execution_id",
|
||||
user=mock_user,
|
||||
)
|
||||
for idx, c in enumerate(cases, 1):
|
||||
fail_msg = f"Test case {c.name} failed, index={idx}"
|
||||
node_id, name = saver._normalize_variable_for_start_node(c.input_name)
|
||||
assert node_id == c.expected_node_id, fail_msg
|
||||
assert name == c.expected_name, fail_msg
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Mock SQLAlchemy session."""
|
||||
from sqlalchemy import Engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_engine = MagicMock(spec=Engine)
|
||||
mock_session.get_bind.return_value = mock_engine
|
||||
return mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def draft_saver(self, mock_session):
|
||||
"""Create DraftVariableSaver instance with user context."""
|
||||
# Create a mock user
|
||||
mock_user = MagicMock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.tenant_id = "test-tenant-id"
|
||||
|
||||
return DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id="test-app-id",
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
node_execution_id="test-execution-id",
|
||||
user=mock_user,
|
||||
)
|
||||
|
||||
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
) as _mock_try_offload:
|
||||
_mock_try_offload.return_value = None
|
||||
mock_segment = StringSegment(value="small value")
|
||||
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
|
||||
|
||||
# Should not have large variable metadata
|
||||
assert draft_var.file_id is None
|
||||
_mock_try_offload.return_value = None
|
||||
|
||||
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
) as _mock_try_offload:
|
||||
mock_segment = StringSegment(value="small value")
|
||||
mock_draft_var_file = WorkflowDraftVariableFile(
|
||||
id=str(uuidv7()),
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.ARRAY_STRING,
|
||||
upload_file_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
_mock_try_offload.return_value = mock_segment, mock_draft_var_file
|
||||
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
|
||||
|
||||
# Should not have large variable metadata
|
||||
assert draft_var.file_id == mock_draft_var_file.id
|
||||
|
||||
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
|
||||
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
|
||||
"""Test complete save workflow."""
|
||||
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
|
||||
|
||||
draft_saver.save(outputs=outputs)
|
||||
|
||||
# Should batch upsert draft variables
|
||||
mock_batch_upsert.assert_called_once()
|
||||
draft_vars = mock_batch_upsert.call_args[0][1]
|
||||
assert len(draft_vars) == 2
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableService:
|
||||
def _get_test_app_id(self):
|
||||
suffix = secrets.token_hex(6)
|
||||
return f"test_app_id_{suffix}"
|
||||
|
||||
def _create_test_workflow(self, app_id: str) -> Workflow:
|
||||
"""Create a real Workflow instance for testing"""
|
||||
return Workflow.new(
|
||||
tenant_id="test_tenant_id",
|
||||
app_id=app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features="{}",
|
||||
created_by="test_user_id",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
|
||||
def test_reset_conversation_variable(self, mock_session):
|
||||
"""Test resetting a conversation variable"""
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real conversation variable
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id, name="test_var", value=test_value, description="Test conversation variable"
|
||||
)
|
||||
|
||||
# Mock the _reset_conv_var method
|
||||
expected_result = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id,
|
||||
name="test_var",
|
||||
value=StringSegment(value="reset_value"),
|
||||
)
|
||||
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
|
||||
result = service.reset_variable(workflow, variable)
|
||||
|
||||
mock_reset_conv.assert_called_once_with(workflow, variable)
|
||||
assert result == expected_result
|
||||
|
||||
def test_reset_node_variable_with_no_execution_id(self, mock_session):
|
||||
"""Test resetting a node variable with no execution ID - should delete variable"""
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real node variable with no execution ID
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id,
|
||||
node_id="test_node_id",
|
||||
name="test_var",
|
||||
value=test_value,
|
||||
node_execution_id="exec-id", # Set initially
|
||||
)
|
||||
# Manually set to None to simulate the test condition
|
||||
variable.node_execution_id = None
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should delete the variable and return None
|
||||
mock_session.delete.assert_called_once_with(instance=variable)
|
||||
mock_session.flush.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
def test_reset_node_variable_with_missing_execution_record(
|
||||
self,
|
||||
mock_engine,
|
||||
mock_session,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Test resetting a node variable when execution record doesn't exist"""
|
||||
mock_repo_session = Mock(spec=Session)
|
||||
|
||||
mock_session_maker = MagicMock()
|
||||
# Mock the context manager protocol for sessionmaker
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
# Mock the repository to return None (no execution record found)
|
||||
service._api_node_execution_repo = Mock()
|
||||
service._api_node_execution_repo.get_execution_by_id.return_value = None
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real node variable with execution ID
|
||||
test_value = StringSegment(value="test_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
|
||||
)
|
||||
# Variable is editable by default from factory method
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
|
||||
# Should delete the variable and return None
|
||||
mock_session.delete.assert_called_once_with(instance=variable)
|
||||
mock_session.flush.assert_called_once()
|
||||
assert result is None
|
||||
|
||||
def test_reset_node_variable_with_valid_execution_record(
|
||||
self,
|
||||
mock_session,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Test resetting a node variable with valid execution record - should restore from execution"""
|
||||
mock_repo_session = Mock(spec=Session)
|
||||
|
||||
mock_session_maker = MagicMock()
|
||||
# Mock the context manager protocol for sessionmaker
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
mock_session_maker = monkeypatch.setattr(
|
||||
"services.workflow_draft_variable_service.sessionmaker", mock_session_maker
|
||||
)
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.load_full_outputs.return_value = {"test_var": "output_value"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create real node variable with execution ID
|
||||
test_value = StringSegment(value="original_value")
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
|
||||
)
|
||||
# Variable is editable by default from factory method
|
||||
|
||||
# Mock workflow methods
|
||||
mock_node_config = {"type": "test_node"}
|
||||
with (
|
||||
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
|
||||
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
|
||||
):
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Verify last_edited_at was reset
|
||||
assert variable.last_edited_at is None
|
||||
# Verify session.flush was called
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
# Should return the updated variable
|
||||
assert result == variable
|
||||
|
||||
def test_reset_non_editable_system_variable_raises_error(self, mock_session):
|
||||
"""Test that resetting a non-editable system variable raises an error"""
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create a non-editable system variable (workflow_id is not editable)
|
||||
test_value = StringSegment(value="test_workflow_id")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="workflow_id", # This is not in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=False, # Non-editable system variable
|
||||
)
|
||||
|
||||
with pytest.raises(VariableResetError) as exc_info:
|
||||
service.reset_variable(workflow, variable)
|
||||
assert "cannot reset system variable" in str(exc_info.value)
|
||||
assert f"variable_id={variable.id}" in str(exc_info.value)
|
||||
|
||||
def test_reset_editable_system_variable_succeeds(self, mock_session):
|
||||
"""Test that resetting an editable system variable succeeds"""
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create an editable system variable (files is editable)
|
||||
test_value = StringSegment(value="[]")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="files", # This is in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=True, # Editable system variable
|
||||
)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.load_full_outputs.return_value = {"sys.files": "[]"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should succeed and return the variable
|
||||
assert result == variable
|
||||
assert variable.last_edited_at is None
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
def test_reset_query_system_variable_succeeds(self, mock_session):
|
||||
"""Test that resetting query system variable (another editable one) succeeds"""
|
||||
service = WorkflowDraftVariableService(mock_session)
|
||||
|
||||
test_app_id = self._get_test_app_id()
|
||||
workflow = self._create_test_workflow(test_app_id)
|
||||
|
||||
# Create an editable system variable (query is editable)
|
||||
test_value = StringSegment(value="original query")
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id,
|
||||
name="query", # This is in _EDITABLE_SYSTEM_VARIABLE
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
editable=True, # Editable system variable
|
||||
)
|
||||
|
||||
# Create mock execution record
|
||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
mock_execution.load_full_outputs.return_value = {"sys.query": "reset query"}
|
||||
|
||||
# Mock the repository to return the execution record
|
||||
service._api_node_execution_repo = Mock()
|
||||
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
|
||||
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
# Should succeed and return the variable
|
||||
assert result == variable
|
||||
assert variable.last_edited_at is None
|
||||
mock_session.flush.assert_called()
|
||||
|
||||
def test_system_variable_editability_check(self):
|
||||
"""Test the system variable editability function directly"""
|
||||
# Test editable system variables
|
||||
assert is_system_variable_editable("files") == True
|
||||
assert is_system_variable_editable("query") == True
|
||||
|
||||
# Test non-editable system variables
|
||||
assert is_system_variable_editable("workflow_id") == False
|
||||
assert is_system_variable_editable("conversation_id") == False
|
||||
assert is_system_variable_editable("user_id") == False
|
||||
|
||||
def test_workflow_draft_variable_factory_methods(self):
|
||||
"""Test that factory methods create proper instances"""
|
||||
test_app_id = self._get_test_app_id()
|
||||
test_value = StringSegment(value="test_value")
|
||||
|
||||
# Test conversation variable factory
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=test_app_id, name="conv_var", value=test_value, description="Test conversation variable"
|
||||
)
|
||||
assert conv_var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
assert conv_var.editable == True
|
||||
assert conv_var.node_execution_id is None
|
||||
|
||||
# Test system variable factory
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=test_app_id, name="workflow_id", value=test_value, node_execution_id="exec-id", editable=False
|
||||
)
|
||||
assert sys_var.get_variable_type() == DraftVariableType.SYS
|
||||
assert sys_var.editable == False
|
||||
assert sys_var.node_execution_id == "exec-id"
|
||||
|
||||
# Test node variable factory
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=test_app_id,
|
||||
node_id="node-id",
|
||||
name="node_var",
|
||||
value=test_value,
|
||||
node_execution_id="exec-id",
|
||||
visible=True,
|
||||
editable=True,
|
||||
)
|
||||
assert node_var.get_variable_type() == DraftVariableType.NODE
|
||||
assert node_var.visible == True
|
||||
assert node_var.editable == True
|
||||
assert node_var.node_execution_id == "exec-id"
|
||||
@@ -0,0 +1,288 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
|
||||
|
||||
class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
@pytest.fixture
|
||||
def repository(self):
|
||||
mock_session_maker = MagicMock()
|
||||
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_execution(self):
|
||||
execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
execution.id = str(uuid4())
|
||||
execution.tenant_id = "tenant-123"
|
||||
execution.app_id = "app-456"
|
||||
execution.workflow_id = "workflow-789"
|
||||
execution.workflow_run_id = "run-101"
|
||||
execution.node_id = "node-202"
|
||||
execution.index = 1
|
||||
execution.created_at = "2023-01-01T00:00:00Z"
|
||||
return execution
|
||||
|
||||
def test_get_node_last_execution_found(self, repository, mock_execution):
|
||||
"""Test getting the last execution for a node when it exists."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = mock_execution
|
||||
|
||||
# Act
|
||||
result = repository.get_node_last_execution(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_id="workflow-789",
|
||||
node_id="node-202",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_execution
|
||||
mock_session.scalar.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.scalar.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
def test_get_node_last_execution_not_found(self, repository):
|
||||
"""Test getting the last execution for a node when it doesn't exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = repository.get_node_last_execution(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_id="workflow-789",
|
||||
node_id="node-202",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_executions_by_workflow_run(self, repository, mock_execution):
|
||||
"""Test getting all executions for a workflow run."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
executions = [mock_execution]
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == executions
|
||||
mock_session.execute.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.execute.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
def test_get_executions_by_workflow_run_empty(self, repository):
|
||||
"""Test getting executions for a workflow run when none exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_get_execution_by_id_found(self, repository, mock_execution):
|
||||
"""Test getting execution by ID when it exists."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = mock_execution
|
||||
|
||||
# Act
|
||||
result = repository.get_execution_by_id(mock_execution.id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_execution
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_execution_by_id_not_found(self, repository):
|
||||
"""Test getting execution by ID when it doesn't exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = repository.get_execution_by_id("non-existent-id")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_repository_implements_protocol(self, repository):
|
||||
"""Test that the repository implements the required protocol methods."""
|
||||
# Verify all protocol methods are implemented
|
||||
assert hasattr(repository, "get_node_last_execution")
|
||||
assert hasattr(repository, "get_executions_by_workflow_run")
|
||||
assert hasattr(repository, "get_execution_by_id")
|
||||
|
||||
# Verify methods are callable
|
||||
assert callable(repository.get_node_last_execution)
|
||||
assert callable(repository.get_executions_by_workflow_run)
|
||||
assert callable(repository.get_execution_by_id)
|
||||
assert callable(repository.delete_expired_executions)
|
||||
assert callable(repository.delete_executions_by_app)
|
||||
assert callable(repository.get_expired_executions_batch)
|
||||
assert callable(repository.delete_executions_by_ids)
|
||||
|
||||
def test_delete_expired_executions(self, repository):
|
||||
"""Test deleting expired executions."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the select query to return some IDs first time, then empty to stop loop
|
||||
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
|
||||
|
||||
# Mock execute method to handle both select and delete statements
|
||||
def mock_execute(stmt):
|
||||
mock_result = MagicMock()
|
||||
# For select statements, return execution IDs
|
||||
if hasattr(stmt, "limit"): # This is our select statement
|
||||
mock_result.scalars.return_value.all.return_value = execution_ids
|
||||
else: # This is our delete statement
|
||||
mock_result.rowcount = 2
|
||||
return mock_result
|
||||
|
||||
mock_session.execute.side_effect = mock_execute
|
||||
|
||||
before_date = datetime(2023, 1, 1)
|
||||
|
||||
# Act
|
||||
result = repository.delete_expired_executions(
|
||||
tenant_id="tenant-123",
|
||||
before_date=before_date,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 2
|
||||
assert mock_session.execute.call_count == 2 # One select call, one delete call
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_app(self, repository):
|
||||
"""Test deleting executions by app."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the select query to return some IDs first time, then empty to stop loop
|
||||
execution_ids = ["id1", "id2"]
|
||||
|
||||
# Mock execute method to handle both select and delete statements
|
||||
def mock_execute(stmt):
|
||||
mock_result = MagicMock()
|
||||
# For select statements, return execution IDs
|
||||
if hasattr(stmt, "limit"): # This is our select statement
|
||||
mock_result.scalars.return_value.all.return_value = execution_ids
|
||||
else: # This is our delete statement
|
||||
mock_result.rowcount = 2
|
||||
return mock_result
|
||||
|
||||
mock_session.execute.side_effect = mock_execute
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_app(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 2
|
||||
assert mock_session.execute.call_count == 2 # One select call, one delete call
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_get_expired_executions_batch(self, repository):
|
||||
"""Test getting expired executions batch for backup."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Create mock execution objects
|
||||
mock_execution1 = MagicMock()
|
||||
mock_execution1.id = "exec-1"
|
||||
mock_execution2 = MagicMock()
|
||||
mock_execution2.id = "exec-2"
|
||||
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
|
||||
|
||||
before_date = datetime(2023, 1, 1)
|
||||
|
||||
# Act
|
||||
result = repository.get_expired_executions_batch(
|
||||
tenant_id="tenant-123",
|
||||
before_date=before_date,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].id == "exec-1"
|
||||
assert result[1].id == "exec-2"
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_ids(self, repository):
|
||||
"""Test deleting executions by IDs."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the delete query result
|
||||
mock_result = MagicMock()
|
||||
mock_result.rowcount = 3
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
execution_ids = ["id1", "id2", "id3"]
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_ids(execution_ids)
|
||||
|
||||
# Assert
|
||||
assert result == 3
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_ids_empty_list(self, repository):
|
||||
"""Test deleting executions with empty ID list."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_ids([])
|
||||
|
||||
# Assert
|
||||
assert result == 0
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
@@ -0,0 +1,163 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class TestWorkflowService:
|
||||
@pytest.fixture
|
||||
def workflow_service(self):
|
||||
mock_session_maker = MagicMock()
|
||||
return WorkflowService(mock_session_maker)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
app = MagicMock(spec=App)
|
||||
app.id = "app-id-1"
|
||||
app.workflow_id = "workflow-id-1"
|
||||
app.tenant_id = "tenant-id-1"
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflows(self):
|
||||
workflows = []
|
||||
for i in range(5):
|
||||
workflow = MagicMock(spec=Workflow)
|
||||
workflow.id = f"workflow-id-{i}"
|
||||
workflow.app_id = "app-id-1"
|
||||
workflow.created_at = f"2023-01-0{5 - i}" # Descending date order
|
||||
workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2"
|
||||
workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else ""
|
||||
workflows.append(workflow)
|
||||
return workflows
|
||||
|
||||
def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app):
|
||||
mock_app.workflow_id = None
|
||||
mock_session = MagicMock()
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
|
||||
)
|
||||
|
||||
assert workflows == []
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_not_called()
|
||||
|
||||
def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
mock_scalar_result.all.return_value = mock_workflows[:3]
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
|
||||
)
|
||||
|
||||
assert workflows == mock_workflows[:3]
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Return 4 items when limit is 3, which should indicate has_more=True
|
||||
mock_scalar_result.all.return_value = mock_workflows[:4]
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None
|
||||
)
|
||||
|
||||
# Should return only the first 3 items
|
||||
assert len(workflows) == 3
|
||||
assert workflows == mock_workflows[:3]
|
||||
assert has_more is True
|
||||
|
||||
# Test page 2
|
||||
mock_scalar_result.all.return_value = mock_workflows[3:]
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None
|
||||
)
|
||||
|
||||
assert len(workflows) == 2
|
||||
assert has_more is False
|
||||
|
||||
def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Filter workflows for user-id-1
|
||||
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"]
|
||||
mock_scalar_result.all.return_value = filtered_workflows
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1"
|
||||
)
|
||||
|
||||
assert workflows == filtered_workflows
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
# Verify that the select contains a user filter clause
|
||||
args = mock_session.scalars.call_args[0][0]
|
||||
assert "created_by" in str(args)
|
||||
|
||||
def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Filter workflows that have a marked_name
|
||||
named_workflows = [w for w in mock_workflows if w.marked_name]
|
||||
mock_scalar_result.all.return_value = named_workflows
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True
|
||||
)
|
||||
|
||||
assert workflows == named_workflows
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
# Verify that the select contains a named_only filter clause
|
||||
args = mock_session.scalars.call_args[0][0]
|
||||
assert "marked_name !=" in str(args)
|
||||
|
||||
def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
# Combined filter: user-id-1 and has marked_name
|
||||
filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name]
|
||||
mock_scalar_result.all.return_value = filtered_workflows
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True
|
||||
)
|
||||
|
||||
assert workflows == filtered_workflows
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
# Verify that both filters are applied
|
||||
args = mock_session.scalars.call_args[0][0]
|
||||
assert "created_by" in str(args)
|
||||
assert "marked_name !=" in str(args)
|
||||
|
||||
def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app):
|
||||
mock_session = MagicMock()
|
||||
mock_scalar_result = MagicMock()
|
||||
mock_scalar_result.all.return_value = []
|
||||
mock_session.scalars.return_value = mock_scalar_result
|
||||
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None
|
||||
)
|
||||
|
||||
assert workflows == []
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
Reference in New Issue
Block a user