dify
This commit is contained in:
0
dify/api/tests/unit_tests/models/__init__.py
Normal file
0
dify/api/tests/unit_tests/models/__init__.py
Normal file
14
dify/api/tests/unit_tests/models/test_account.py
Normal file
14
dify/api/tests/unit_tests/models/test_account.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
|
||||
def test_account_is_privileged_role():
|
||||
assert TenantAccountRole.ADMIN == "admin"
|
||||
assert TenantAccountRole.OWNER == "owner"
|
||||
assert TenantAccountRole.EDITOR == "editor"
|
||||
assert TenantAccountRole.NORMAL == "normal"
|
||||
|
||||
assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN)
|
||||
assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER)
|
||||
assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL)
|
||||
assert not TenantAccountRole.is_privileged_role(TenantAccountRole.EDITOR)
|
||||
assert not TenantAccountRole.is_privileged_role("")
|
||||
886
dify/api/tests/unit_tests/models/test_account_models.py
Normal file
886
dify/api/tests/unit_tests/models/test_account_models.py
Normal file
@@ -0,0 +1,886 @@
|
||||
"""
|
||||
Comprehensive unit tests for Account model.
|
||||
|
||||
This test suite covers:
|
||||
- Account model validation
|
||||
- Password hashing/verification
|
||||
- Account status transitions
|
||||
- Tenant relationship integrity
|
||||
- Email uniqueness constraints
|
||||
"""
|
||||
|
||||
import base64
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
|
||||
class TestAccountModelValidation:
|
||||
"""Test suite for Account model validation and basic operations."""
|
||||
|
||||
def test_account_creation_with_required_fields(self):
|
||||
"""Test creating an account with all required fields."""
|
||||
# Arrange & Act
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
password="hashed_password",
|
||||
password_salt="salt_value",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert account.name == "Test User"
|
||||
assert account.email == "test@example.com"
|
||||
assert account.password == "hashed_password"
|
||||
assert account.password_salt == "salt_value"
|
||||
assert account.status == "active" # Default value
|
||||
|
||||
def test_account_creation_with_optional_fields(self):
|
||||
"""Test creating an account with optional fields."""
|
||||
# Arrange & Act
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
avatar="https://example.com/avatar.png",
|
||||
interface_language="en-US",
|
||||
interface_theme="dark",
|
||||
timezone="America/New_York",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert account.avatar == "https://example.com/avatar.png"
|
||||
assert account.interface_language == "en-US"
|
||||
assert account.interface_theme == "dark"
|
||||
assert account.timezone == "America/New_York"
|
||||
|
||||
def test_account_creation_without_password(self):
|
||||
"""Test creating an account without password (for invite-based registration)."""
|
||||
# Arrange & Act
|
||||
account = Account(
|
||||
name="Invited User",
|
||||
email="invited@example.com",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert account.password is None
|
||||
assert account.password_salt is None
|
||||
assert not account.is_password_set
|
||||
|
||||
def test_account_is_password_set_property(self):
|
||||
"""Test the is_password_set property."""
|
||||
# Arrange
|
||||
account_with_password = Account(
|
||||
name="User With Password",
|
||||
email="withpass@example.com",
|
||||
password="hashed_password",
|
||||
)
|
||||
account_without_password = Account(
|
||||
name="User Without Password",
|
||||
email="nopass@example.com",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert account_with_password.is_password_set
|
||||
assert not account_without_password.is_password_set
|
||||
|
||||
def test_account_default_status(self):
|
||||
"""Test that account has default status of 'active'."""
|
||||
# Arrange & Act
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert account.status == "active"
|
||||
|
||||
def test_account_get_status_method(self):
|
||||
"""Test the get_status method returns AccountStatus enum."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
status="pending",
|
||||
)
|
||||
|
||||
# Act
|
||||
status = account.get_status()
|
||||
|
||||
# Assert
|
||||
assert status == AccountStatus.PENDING
|
||||
assert isinstance(status, AccountStatus)
|
||||
|
||||
|
||||
class TestPasswordHashingAndVerification:
|
||||
"""Test suite for password hashing and verification functionality."""
|
||||
|
||||
def test_password_hashing_produces_consistent_result(self):
|
||||
"""Test that hashing the same password with the same salt produces the same result."""
|
||||
# Arrange
|
||||
password = "TestPassword123"
|
||||
salt = secrets.token_bytes(16)
|
||||
|
||||
# Act
|
||||
hash1 = hash_password(password, salt)
|
||||
hash2 = hash_password(password, salt)
|
||||
|
||||
# Assert
|
||||
assert hash1 == hash2
|
||||
|
||||
def test_password_hashing_different_salts_produce_different_hashes(self):
|
||||
"""Test that different salts produce different hashes for the same password."""
|
||||
# Arrange
|
||||
password = "TestPassword123"
|
||||
salt1 = secrets.token_bytes(16)
|
||||
salt2 = secrets.token_bytes(16)
|
||||
|
||||
# Act
|
||||
hash1 = hash_password(password, salt1)
|
||||
hash2 = hash_password(password, salt2)
|
||||
|
||||
# Assert
|
||||
assert hash1 != hash2
|
||||
|
||||
def test_password_comparison_success(self):
|
||||
"""Test successful password comparison."""
|
||||
# Arrange
|
||||
password = "TestPassword123"
|
||||
salt = secrets.token_bytes(16)
|
||||
password_hashed = hash_password(password, salt)
|
||||
|
||||
# Encode to base64 as done in the application
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
# Act
|
||||
result = compare_password(password, base64_password_hashed, base64_salt)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_password_comparison_failure(self):
|
||||
"""Test password comparison with wrong password."""
|
||||
# Arrange
|
||||
correct_password = "TestPassword123"
|
||||
wrong_password = "WrongPassword456"
|
||||
salt = secrets.token_bytes(16)
|
||||
password_hashed = hash_password(correct_password, salt)
|
||||
|
||||
# Encode to base64
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
# Act
|
||||
result = compare_password(wrong_password, base64_password_hashed, base64_salt)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_valid_password_with_correct_format(self):
|
||||
"""Test password validation with correct format."""
|
||||
# Arrange
|
||||
valid_passwords = [
|
||||
"Password123",
|
||||
"Test1234",
|
||||
"MySecure1Pass",
|
||||
"abcdefgh1",
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for password in valid_passwords:
|
||||
result = valid_password(password)
|
||||
assert result == password
|
||||
|
||||
def test_valid_password_with_incorrect_format(self):
|
||||
"""Test password validation with incorrect format."""
|
||||
# Arrange
|
||||
invalid_passwords = [
|
||||
"short1", # Too short
|
||||
"NoNumbers", # No numbers
|
||||
"12345678", # No letters
|
||||
"Pass1", # Too short
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
for password in invalid_passwords:
|
||||
with pytest.raises(ValueError, match="Password must contain letters and numbers"):
|
||||
valid_password(password)
|
||||
|
||||
def test_password_hashing_integration_with_account(self):
|
||||
"""Test password hashing integration with Account model."""
|
||||
# Arrange
|
||||
password = "SecurePass123"
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
password_hashed = hash_password(password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
# Act
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
password=base64_password_hashed,
|
||||
password_salt=base64_salt,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert account.is_password_set
|
||||
assert compare_password(password, account.password, account.password_salt)
|
||||
|
||||
|
||||
class TestAccountStatusTransitions:
|
||||
"""Test suite for account status transitions."""
|
||||
|
||||
def test_account_status_enum_values(self):
|
||||
"""Test that AccountStatus enum has all expected values."""
|
||||
# Assert
|
||||
assert AccountStatus.PENDING == "pending"
|
||||
assert AccountStatus.UNINITIALIZED == "uninitialized"
|
||||
assert AccountStatus.ACTIVE == "active"
|
||||
assert AccountStatus.BANNED == "banned"
|
||||
assert AccountStatus.CLOSED == "closed"
|
||||
|
||||
def test_account_status_transition_pending_to_active(self):
|
||||
"""Test transitioning account status from pending to active."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
status=AccountStatus.PENDING,
|
||||
)
|
||||
|
||||
# Act
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = datetime.now(UTC)
|
||||
|
||||
# Assert
|
||||
assert account.get_status() == AccountStatus.ACTIVE
|
||||
assert account.initialized_at is not None
|
||||
|
||||
def test_account_status_transition_active_to_banned(self):
|
||||
"""Test transitioning account status from active to banned."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
|
||||
# Act
|
||||
account.status = AccountStatus.BANNED
|
||||
|
||||
# Assert
|
||||
assert account.get_status() == AccountStatus.BANNED
|
||||
|
||||
def test_account_status_transition_active_to_closed(self):
|
||||
"""Test transitioning account status from active to closed."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
|
||||
# Act
|
||||
account.status = AccountStatus.CLOSED
|
||||
|
||||
# Assert
|
||||
assert account.get_status() == AccountStatus.CLOSED
|
||||
|
||||
def test_account_status_uninitialized(self):
|
||||
"""Test account with uninitialized status."""
|
||||
# Arrange & Act
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
status=AccountStatus.UNINITIALIZED,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert account.get_status() == AccountStatus.UNINITIALIZED
|
||||
assert account.initialized_at is None
|
||||
|
||||
|
||||
class TestTenantRelationshipIntegrity:
|
||||
"""Test suite for tenant relationship integrity."""
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_account_current_tenant_property(self, mock_db):
|
||||
"""Test the current_tenant property getter."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
account._current_tenant = tenant
|
||||
|
||||
# Act
|
||||
result = account.current_tenant
|
||||
|
||||
# Assert
|
||||
assert result == tenant
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_current_tenant_setter_with_valid_tenant(self, mock_db, mock_session_class):
|
||||
"""Test setting current_tenant with a valid tenant relationship."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock TenantAccountJoin query result
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
)
|
||||
mock_session.scalar.return_value = tenant_join
|
||||
|
||||
# Mock Tenant query result
|
||||
mock_session.scalars.return_value.one.return_value = tenant
|
||||
|
||||
# Act
|
||||
account.current_tenant = tenant
|
||||
|
||||
# Assert
|
||||
assert account._current_tenant == tenant
|
||||
assert account.role == TenantAccountRole.OWNER
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_current_tenant_setter_without_relationship(self, mock_db, mock_session_class):
|
||||
"""Test setting current_tenant when no relationship exists."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock no TenantAccountJoin found
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
account.current_tenant = tenant
|
||||
|
||||
# Assert
|
||||
assert account._current_tenant is None
|
||||
|
||||
def test_account_current_tenant_id_property(self):
|
||||
"""Test the current_tenant_id property."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
# Act - with tenant
|
||||
account._current_tenant = tenant
|
||||
tenant_id = account.current_tenant_id
|
||||
|
||||
# Assert
|
||||
assert tenant_id == tenant.id
|
||||
|
||||
# Act - without tenant
|
||||
account._current_tenant = None
|
||||
tenant_id_none = account.current_tenant_id
|
||||
|
||||
# Assert
|
||||
assert tenant_id_none is None
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_set_tenant_id_method(self, mock_db, mock_session_class):
|
||||
"""Test the set_tenant_id method."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.ADMIN,
|
||||
)
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.first.return_value = (tenant, tenant_join)
|
||||
|
||||
# Act
|
||||
account.set_tenant_id(tenant.id)
|
||||
|
||||
# Assert
|
||||
assert account._current_tenant == tenant
|
||||
assert account.role == TenantAccountRole.ADMIN
|
||||
|
||||
@patch("models.account.Session")
|
||||
@patch("models.account.db")
|
||||
def test_account_set_tenant_id_with_no_relationship(self, mock_db, mock_session_class):
|
||||
"""Test set_tenant_id when no relationship exists."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.id = str(uuid4())
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Mock the session and queries
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
account.set_tenant_id(tenant_id)
|
||||
|
||||
# Assert - should not set tenant when no relationship exists
|
||||
# The method returns early without setting _current_tenant
|
||||
|
||||
|
||||
class TestAccountRolePermissions:
|
||||
"""Test suite for account role permissions."""
|
||||
|
||||
def test_is_admin_or_owner_with_admin_role(self):
|
||||
"""Test is_admin_or_owner property with admin role."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.role = TenantAccountRole.ADMIN
|
||||
|
||||
# Act & Assert
|
||||
assert account.is_admin_or_owner
|
||||
|
||||
def test_is_admin_or_owner_with_owner_role(self):
|
||||
"""Test is_admin_or_owner property with owner role."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.role = TenantAccountRole.OWNER
|
||||
|
||||
# Act & Assert
|
||||
assert account.is_admin_or_owner
|
||||
|
||||
def test_is_admin_or_owner_with_normal_role(self):
|
||||
"""Test is_admin_or_owner property with normal role."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.role = TenantAccountRole.NORMAL
|
||||
|
||||
# Act & Assert
|
||||
assert not account.is_admin_or_owner
|
||||
|
||||
def test_is_admin_property(self):
|
||||
"""Test is_admin property."""
|
||||
# Arrange
|
||||
admin_account = Account(name="Admin", email="admin@example.com")
|
||||
admin_account.role = TenantAccountRole.ADMIN
|
||||
|
||||
owner_account = Account(name="Owner", email="owner@example.com")
|
||||
owner_account.role = TenantAccountRole.OWNER
|
||||
|
||||
# Act & Assert
|
||||
assert admin_account.is_admin
|
||||
assert not owner_account.is_admin
|
||||
|
||||
def test_has_edit_permission_with_editing_roles(self):
|
||||
"""Test has_edit_permission property with roles that have edit permission."""
|
||||
# Arrange
|
||||
roles_with_edit = [
|
||||
TenantAccountRole.OWNER,
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
]
|
||||
|
||||
for role in roles_with_edit:
|
||||
account = Account(name="Test User", email=f"test_{role}@example.com")
|
||||
account.role = role
|
||||
|
||||
# Act & Assert
|
||||
assert account.has_edit_permission, f"Role {role} should have edit permission"
|
||||
|
||||
def test_has_edit_permission_without_editing_roles(self):
|
||||
"""Test has_edit_permission property with roles that don't have edit permission."""
|
||||
# Arrange
|
||||
roles_without_edit = [
|
||||
TenantAccountRole.NORMAL,
|
||||
TenantAccountRole.DATASET_OPERATOR,
|
||||
]
|
||||
|
||||
for role in roles_without_edit:
|
||||
account = Account(name="Test User", email=f"test_{role}@example.com")
|
||||
account.role = role
|
||||
|
||||
# Act & Assert
|
||||
assert not account.has_edit_permission, f"Role {role} should not have edit permission"
|
||||
|
||||
def test_is_dataset_editor_property(self):
|
||||
"""Test is_dataset_editor property."""
|
||||
# Arrange
|
||||
dataset_roles = [
|
||||
TenantAccountRole.OWNER,
|
||||
TenantAccountRole.ADMIN,
|
||||
TenantAccountRole.EDITOR,
|
||||
TenantAccountRole.DATASET_OPERATOR,
|
||||
]
|
||||
|
||||
for role in dataset_roles:
|
||||
account = Account(name="Test User", email=f"test_{role}@example.com")
|
||||
account.role = role
|
||||
|
||||
# Act & Assert
|
||||
assert account.is_dataset_editor, f"Role {role} should have dataset edit permission"
|
||||
|
||||
# Test normal role doesn't have dataset edit permission
|
||||
normal_account = Account(name="Normal User", email="normal@example.com")
|
||||
normal_account.role = TenantAccountRole.NORMAL
|
||||
assert not normal_account.is_dataset_editor
|
||||
|
||||
def test_is_dataset_operator_property(self):
|
||||
"""Test is_dataset_operator property."""
|
||||
# Arrange
|
||||
dataset_operator = Account(name="Dataset Operator", email="operator@example.com")
|
||||
dataset_operator.role = TenantAccountRole.DATASET_OPERATOR
|
||||
|
||||
normal_account = Account(name="Normal User", email="normal@example.com")
|
||||
normal_account.role = TenantAccountRole.NORMAL
|
||||
|
||||
# Act & Assert
|
||||
assert dataset_operator.is_dataset_operator
|
||||
assert not normal_account.is_dataset_operator
|
||||
|
||||
def test_current_role_property(self):
|
||||
"""Test current_role property."""
|
||||
# Arrange
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account.role = TenantAccountRole.EDITOR
|
||||
|
||||
# Act
|
||||
current_role = account.current_role
|
||||
|
||||
# Assert
|
||||
assert current_role == TenantAccountRole.EDITOR
|
||||
|
||||
|
||||
class TestAccountGetByOpenId:
|
||||
"""Test suite for get_by_openid class method."""
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_get_by_openid_success(self, mock_db):
|
||||
"""Test successful retrieval of account by OpenID."""
|
||||
# Arrange
|
||||
provider = "google"
|
||||
open_id = "google_user_123"
|
||||
account_id = str(uuid4())
|
||||
|
||||
mock_account_integrate = MagicMock()
|
||||
mock_account_integrate.account_id = account_id
|
||||
|
||||
mock_account = Account(name="Test User", email="test@example.com")
|
||||
mock_account.id = account_id
|
||||
|
||||
# Mock the query chain
|
||||
mock_query = MagicMock()
|
||||
mock_where = MagicMock()
|
||||
mock_where.one_or_none.return_value = mock_account_integrate
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_db.session.query.return_value = mock_query
|
||||
|
||||
# Mock the second query for account
|
||||
mock_account_query = MagicMock()
|
||||
mock_account_where = MagicMock()
|
||||
mock_account_where.one_or_none.return_value = mock_account
|
||||
mock_account_query.where.return_value = mock_account_where
|
||||
|
||||
# Setup query to return different results based on model
|
||||
def query_side_effect(model):
|
||||
if model.__name__ == "AccountIntegrate":
|
||||
return mock_query
|
||||
elif model.__name__ == "Account":
|
||||
return mock_account_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
# Act
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_account
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_get_by_openid_not_found(self, mock_db):
|
||||
"""Test get_by_openid when account integrate doesn't exist."""
|
||||
# Arrange
|
||||
provider = "github"
|
||||
open_id = "github_user_456"
|
||||
|
||||
# Mock the query chain to return None
|
||||
mock_query = MagicMock()
|
||||
mock_where = MagicMock()
|
||||
mock_where.one_or_none.return_value = None
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_db.session.query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestTenantAccountJoinModel:
|
||||
"""Test suite for TenantAccountJoin model."""
|
||||
|
||||
def test_tenant_account_join_creation(self):
|
||||
"""Test creating a TenantAccountJoin record."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
role=TenantAccountRole.NORMAL,
|
||||
current=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert join.tenant_id == tenant_id
|
||||
assert join.account_id == account_id
|
||||
assert join.role == TenantAccountRole.NORMAL
|
||||
assert join.current is True
|
||||
|
||||
def test_tenant_account_join_default_values(self):
|
||||
"""Test default values for TenantAccountJoin."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert join.current is False # Default value
|
||||
assert join.role == "normal" # Default value
|
||||
assert join.invited_by is None # Default value
|
||||
|
||||
def test_tenant_account_join_with_invited_by(self):
|
||||
"""Test TenantAccountJoin with invited_by field."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account_id = str(uuid4())
|
||||
inviter_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
role=TenantAccountRole.EDITOR,
|
||||
invited_by=inviter_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert join.invited_by == inviter_id
|
||||
|
||||
|
||||
class TestTenantModel:
|
||||
"""Test suite for Tenant model."""
|
||||
|
||||
def test_tenant_creation(self):
|
||||
"""Test creating a Tenant."""
|
||||
# Arrange & Act
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
|
||||
# Assert
|
||||
assert tenant.name == "Test Workspace"
|
||||
assert tenant.status == "normal" # Default value
|
||||
assert tenant.plan == "basic" # Default value
|
||||
|
||||
def test_tenant_custom_config_dict_property(self):
|
||||
"""Test custom_config_dict property getter."""
|
||||
# Arrange
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
config = {"feature1": True, "feature2": "value"}
|
||||
tenant.custom_config = '{"feature1": true, "feature2": "value"}'
|
||||
|
||||
# Act
|
||||
result = tenant.custom_config_dict
|
||||
|
||||
# Assert
|
||||
assert result["feature1"] is True
|
||||
assert result["feature2"] == "value"
|
||||
|
||||
def test_tenant_custom_config_dict_property_empty(self):
|
||||
"""Test custom_config_dict property with empty config."""
|
||||
# Arrange
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
tenant.custom_config = None
|
||||
|
||||
# Act
|
||||
result = tenant.custom_config_dict
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_tenant_custom_config_dict_setter(self):
|
||||
"""Test custom_config_dict property setter."""
|
||||
# Arrange
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
config = {"feature1": True, "feature2": "value"}
|
||||
|
||||
# Act
|
||||
tenant.custom_config_dict = config
|
||||
|
||||
# Assert
|
||||
assert tenant.custom_config == '{"feature1": true, "feature2": "value"}'
|
||||
|
||||
@patch("models.account.db")
|
||||
def test_tenant_get_accounts(self, mock_db):
|
||||
"""Test getting accounts associated with a tenant."""
|
||||
# Arrange
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
tenant.id = str(uuid4())
|
||||
|
||||
account1 = Account(name="User 1", email="user1@example.com")
|
||||
account1.id = str(uuid4())
|
||||
account2 = Account(name="User 2", email="user2@example.com")
|
||||
account2.id = str(uuid4())
|
||||
|
||||
# Mock the query chain
|
||||
mock_scalars = MagicMock()
|
||||
mock_scalars.all.return_value = [account1, account2]
|
||||
mock_db.session.scalars.return_value = mock_scalars
|
||||
|
||||
# Act
|
||||
accounts = tenant.get_accounts()
|
||||
|
||||
# Assert
|
||||
assert len(accounts) == 2
|
||||
assert account1 in accounts
|
||||
assert account2 in accounts
|
||||
|
||||
|
||||
class TestTenantStatusEnum:
|
||||
"""Test suite for TenantStatus enum."""
|
||||
|
||||
def test_tenant_status_enum_values(self):
|
||||
"""Test TenantStatus enum values."""
|
||||
# Arrange & Act
|
||||
from models.account import TenantStatus
|
||||
|
||||
# Assert
|
||||
assert TenantStatus.NORMAL == "normal"
|
||||
assert TenantStatus.ARCHIVE == "archive"
|
||||
|
||||
|
||||
class TestAccountIntegration:
|
||||
"""Integration tests for Account model with related models."""
|
||||
|
||||
def test_account_with_multiple_tenants(self):
|
||||
"""Test account associated with multiple tenants."""
|
||||
# Arrange
|
||||
account = Account(name="Multi-Tenant User", email="multi@example.com")
|
||||
account.id = str(uuid4())
|
||||
|
||||
tenant1_id = str(uuid4())
|
||||
tenant2_id = str(uuid4())
|
||||
|
||||
join1 = TenantAccountJoin(
|
||||
tenant_id=tenant1_id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
|
||||
join2 = TenantAccountJoin(
|
||||
tenant_id=tenant2_id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.NORMAL,
|
||||
current=False,
|
||||
)
|
||||
|
||||
# Assert - verify the joins are created correctly
|
||||
assert join1.account_id == account.id
|
||||
assert join2.account_id == account.id
|
||||
assert join1.current is True
|
||||
assert join2.current is False
|
||||
|
||||
def test_account_last_login_tracking(self):
|
||||
"""Test account last login tracking."""
|
||||
# Arrange
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
login_time = datetime.now(UTC)
|
||||
login_ip = "192.168.1.1"
|
||||
|
||||
# Act
|
||||
account.last_login_at = login_time
|
||||
account.last_login_ip = login_ip
|
||||
|
||||
# Assert
|
||||
assert account.last_login_at == login_time
|
||||
assert account.last_login_ip == login_ip
|
||||
|
||||
def test_account_initialization_tracking(self):
|
||||
"""Test account initialization tracking."""
|
||||
# Arrange
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
status=AccountStatus.PENDING,
|
||||
)
|
||||
|
||||
# Act - simulate initialization
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = datetime.now(UTC)
|
||||
|
||||
# Assert
|
||||
assert account.get_status() == AccountStatus.ACTIVE
|
||||
assert account.initialized_at is not None
|
||||
1151
dify/api/tests/unit_tests/models/test_app_models.py
Normal file
1151
dify/api/tests/unit_tests/models/test_app_models.py
Normal file
File diff suppressed because it is too large
Load Diff
11
dify/api/tests/unit_tests/models/test_base.py
Normal file
11
dify/api/tests/unit_tests/models/test_base.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from models.base import DefaultFieldsMixin
|
||||
|
||||
|
||||
class FooModel(DefaultFieldsMixin):
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
|
||||
|
||||
def test_repr():
|
||||
foo_model = FooModel(id="test-id")
|
||||
assert repr(foo_model) == "<FooModel(id=test-id)>"
|
||||
@@ -0,0 +1,26 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from core.variables import SegmentType
|
||||
from factories import variable_factory
|
||||
from models import ConversationVariable
|
||||
|
||||
|
||||
def test_from_variable_and_to_variable():
|
||||
variable = variable_factory.build_conversation_variable_from_mapping(
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"name": "name",
|
||||
"value_type": SegmentType.OBJECT,
|
||||
"value": {
|
||||
"key": {
|
||||
"key": "value",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
conversation_variable = ConversationVariable.from_variable(
|
||||
app_id="app_id", conversation_id="conversation_id", variable=variable
|
||||
)
|
||||
|
||||
assert conversation_variable.to_variable() == variable
|
||||
1341
dify/api/tests/unit_tests/models/test_dataset_models.py
Normal file
1341
dify/api/tests/unit_tests/models/test_dataset_models.py
Normal file
File diff suppressed because it is too large
Load Diff
83
dify/api/tests/unit_tests/models/test_model.py
Normal file
83
dify/api/tests/unit_tests/models/test_model.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import importlib
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import Message
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_file_helpers(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Patch file_helpers.get_signed_file_url to a deterministic stub.
|
||||
"""
|
||||
model_module = importlib.import_module("models.model")
|
||||
dummy = types.SimpleNamespace(get_signed_file_url=lambda fid: f"https://signed.example/{fid}")
|
||||
# Inject/override file_helpers on models.model
|
||||
monkeypatch.setattr(model_module, "file_helpers", dummy, raising=False)
|
||||
|
||||
|
||||
def _wrap_md(url: str) -> str:
|
||||
"""
|
||||
Wrap a raw URL into the markdown that re_sign_file_url_answer expects:
|
||||
[link](<url>)
|
||||
"""
|
||||
return f"please click [file]({url}) to download."
|
||||
|
||||
|
||||
def test_file_preview_valid_replaced():
|
||||
"""
|
||||
Valid file-preview URL must be re-signed:
|
||||
- Extract upload_file_id correctly
|
||||
- Replace the original URL with the signed URL
|
||||
"""
|
||||
upload_id = "abc-123"
|
||||
url = f"/files/{upload_id}/file-preview?timestamp=111&nonce=222&sign=333"
|
||||
msg = Message(answer=_wrap_md(url))
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
assert f"https://signed.example/{upload_id}" in out
|
||||
assert url not in out
|
||||
|
||||
|
||||
def test_file_preview_misspelled_not_replaced():
|
||||
"""
|
||||
Misspelled endpoint 'file-previe?timestamp=' should NOT be rewritten.
|
||||
"""
|
||||
upload_id = "zzz-001"
|
||||
# path deliberately misspelled: file-previe? (missing 'w')
|
||||
# and we append ¬e=file-preview to trick the old `"file-preview" in url` check.
|
||||
url = f"/files/{upload_id}/file-previe?timestamp=111&nonce=222&sign=333¬e=file-preview"
|
||||
original = _wrap_md(url)
|
||||
msg = Message(answer=original)
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
# Expect NO replacement, should not rewrite misspelled file-previe URL
|
||||
assert out == original
|
||||
|
||||
|
||||
def test_image_preview_valid_replaced():
|
||||
"""
|
||||
Valid image-preview URL must be re-signed.
|
||||
"""
|
||||
upload_id = "img-789"
|
||||
url = f"/files/{upload_id}/image-preview?timestamp=123&nonce=456&sign=789"
|
||||
msg = Message(answer=_wrap_md(url))
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
assert f"https://signed.example/{upload_id}" in out
|
||||
assert url not in out
|
||||
|
||||
|
||||
def test_image_preview_misspelled_not_replaced():
|
||||
"""
|
||||
Misspelled endpoint 'image-previe?timestamp=' should NOT be rewritten.
|
||||
"""
|
||||
upload_id = "img-err-42"
|
||||
url = f"/files/{upload_id}/image-previe?timestamp=1&nonce=2&sign=3¬e=image-preview"
|
||||
original = _wrap_md(url)
|
||||
msg = Message(answer=original)
|
||||
|
||||
out = msg.re_sign_file_url_answer
|
||||
# Expect NO replacement, should not rewrite misspelled image-previe URL
|
||||
assert out == original
|
||||
22
dify/api/tests/unit_tests/models/test_plugin_entities.py
Normal file
22
dify/api/tests/unit_tests/models/test_plugin_entities.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import binascii
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.request import TriggerDispatchResponse
|
||||
|
||||
|
||||
def test_trigger_dispatch_response():
|
||||
raw_http_response = b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{"message": "Hello, world!"}'
|
||||
|
||||
data: Mapping[str, Any] = {
|
||||
"user_id": "123",
|
||||
"events": ["event1", "event2"],
|
||||
"response": binascii.hexlify(raw_http_response).decode(),
|
||||
"payload": {"key": "value"},
|
||||
}
|
||||
|
||||
response = TriggerDispatchResponse(**data)
|
||||
|
||||
assert response.response.status_code == 200
|
||||
assert response.response.headers["Content-Type"] == "application/json"
|
||||
assert response.response.get_data(as_text=True) == '{"message": "Hello, world!"}'
|
||||
966
dify/api/tests/unit_tests/models/test_tool_models.py
Normal file
966
dify/api/tests/unit_tests/models/test_tool_models.py
Normal file
@@ -0,0 +1,966 @@
|
||||
"""
|
||||
Comprehensive unit tests for Tool models.
|
||||
|
||||
This test suite covers:
|
||||
- ToolProvider model validation (BuiltinToolProvider, ApiToolProvider)
|
||||
- BuiltinToolProvider relationships and credential management
|
||||
- ApiToolProvider credential storage and encryption
|
||||
- Tool OAuth client models
|
||||
- ToolLabelBinding relationships
|
||||
"""
|
||||
|
||||
import json
|
||||
from uuid import uuid4
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from models.tools import (
|
||||
ApiToolProvider,
|
||||
BuiltinToolProvider,
|
||||
ToolLabelBinding,
|
||||
ToolOAuthSystemClient,
|
||||
ToolOAuthTenantClient,
|
||||
)
|
||||
|
||||
|
||||
class TestBuiltinToolProviderValidation:
|
||||
"""Test suite for BuiltinToolProvider model validation and operations."""
|
||||
|
||||
def test_builtin_tool_provider_creation_with_required_fields(self):
|
||||
"""Test creating a builtin tool provider with all required fields."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
provider_name = "google"
|
||||
credentials = {"api_key": "test_key_123"}
|
||||
|
||||
# Act
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
name="Google API Key 1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert builtin_provider.tenant_id == tenant_id
|
||||
assert builtin_provider.user_id == user_id
|
||||
assert builtin_provider.provider == provider_name
|
||||
assert builtin_provider.name == "Google API Key 1"
|
||||
assert builtin_provider.encrypted_credentials == json.dumps(credentials)
|
||||
|
||||
def test_builtin_tool_provider_credentials_property(self):
|
||||
"""Test credentials property parses JSON correctly."""
|
||||
# Arrange
|
||||
credentials_data = {
|
||||
"api_key": "sk-test123",
|
||||
"auth_type": "api_key",
|
||||
"endpoint": "https://api.example.com",
|
||||
}
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="custom_provider",
|
||||
name="Custom Provider Key",
|
||||
encrypted_credentials=json.dumps(credentials_data),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = builtin_provider.credentials
|
||||
|
||||
# Assert
|
||||
assert result == credentials_data
|
||||
assert result["api_key"] == "sk-test123"
|
||||
assert result["auth_type"] == "api_key"
|
||||
|
||||
def test_builtin_tool_provider_credentials_empty_when_none(self):
|
||||
"""Test credentials property returns empty dict when encrypted_credentials is None."""
|
||||
# Arrange
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="test_provider",
|
||||
name="Test Provider",
|
||||
encrypted_credentials=None,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = builtin_provider.credentials
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_builtin_tool_provider_credentials_empty_when_empty_string(self):
|
||||
"""Test credentials property returns empty dict when encrypted_credentials is empty."""
|
||||
# Arrange
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="test_provider",
|
||||
name="Test Provider",
|
||||
encrypted_credentials="",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = builtin_provider.credentials
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_builtin_tool_provider_default_values(self):
|
||||
"""Test builtin tool provider default values."""
|
||||
# Arrange & Act
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="test_provider",
|
||||
name="Test Provider",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert builtin_provider.is_default is False
|
||||
assert builtin_provider.credential_type == "api-key"
|
||||
assert builtin_provider.expires_at == -1
|
||||
|
||||
def test_builtin_tool_provider_with_oauth_credential_type(self):
|
||||
"""Test builtin tool provider with OAuth credential type."""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"access_token": "oauth_token_123",
|
||||
"refresh_token": "refresh_token_456",
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
|
||||
# Act
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="google",
|
||||
name="Google OAuth",
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
credential_type="oauth2",
|
||||
expires_at=1735689600,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert builtin_provider.credential_type == "oauth2"
|
||||
assert builtin_provider.expires_at == 1735689600
|
||||
assert builtin_provider.credentials["access_token"] == "oauth_token_123"
|
||||
|
||||
def test_builtin_tool_provider_is_default_flag(self):
|
||||
"""Test is_default flag for builtin tool provider."""
|
||||
# Arrange
|
||||
provider1 = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="google",
|
||||
name="Google Key 1",
|
||||
is_default=True,
|
||||
)
|
||||
provider2 = BuiltinToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
provider="google",
|
||||
name="Google Key 2",
|
||||
is_default=False,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert provider1.is_default is True
|
||||
assert provider2.is_default is False
|
||||
|
||||
def test_builtin_tool_provider_unique_constraint_fields(self):
|
||||
"""Test unique constraint fields (tenant_id, provider, name)."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
provider_name = "google"
|
||||
credential_name = "My Google Key"
|
||||
|
||||
# Act
|
||||
builtin_provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=str(uuid4()),
|
||||
provider=provider_name,
|
||||
name=credential_name,
|
||||
)
|
||||
|
||||
# Assert - these fields form unique constraint
|
||||
assert builtin_provider.tenant_id == tenant_id
|
||||
assert builtin_provider.provider == provider_name
|
||||
assert builtin_provider.name == credential_name
|
||||
|
||||
def test_builtin_tool_provider_multiple_credentials_same_provider(self):
|
||||
"""Test multiple credential sets for the same provider."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
provider = "openai"
|
||||
|
||||
# Act - create multiple credentials for same provider
|
||||
provider1 = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
name="OpenAI Key 1",
|
||||
encrypted_credentials=json.dumps({"api_key": "key1"}),
|
||||
)
|
||||
provider2 = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
name="OpenAI Key 2",
|
||||
encrypted_credentials=json.dumps({"api_key": "key2"}),
|
||||
)
|
||||
|
||||
# Assert - different names allow multiple credentials
|
||||
assert provider1.provider == provider2.provider
|
||||
assert provider1.name != provider2.name
|
||||
assert provider1.credentials != provider2.credentials
|
||||
|
||||
|
||||
class TestApiToolProviderValidation:
|
||||
"""Test suite for ApiToolProvider model validation and operations."""
|
||||
|
||||
def test_api_tool_provider_creation_with_required_fields(self):
|
||||
"""Test creating an API tool provider with all required fields."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
provider_name = "Custom API"
|
||||
schema = '{"openapi": "3.0.0", "info": {"title": "Test API"}}'
|
||||
tools = [{"name": "test_tool", "description": "A test tool"}]
|
||||
credentials = {"auth_type": "api_key", "api_key_value": "test123"}
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=provider_name,
|
||||
icon='{"type": "emoji", "value": "🔧"}',
|
||||
schema=schema,
|
||||
schema_type_str="openapi",
|
||||
description="Custom API for testing",
|
||||
tools_str=json.dumps(tools),
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.tenant_id == tenant_id
|
||||
assert api_provider.user_id == user_id
|
||||
assert api_provider.name == provider_name
|
||||
assert api_provider.schema == schema
|
||||
assert api_provider.schema_type_str == "openapi"
|
||||
assert api_provider.description == "Custom API for testing"
|
||||
|
||||
def test_api_tool_provider_schema_type_property(self):
|
||||
"""Test schema_type property converts string to enum."""
|
||||
# Arrange
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Test API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Test",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = api_provider.schema_type
|
||||
|
||||
# Assert
|
||||
assert result == ApiProviderSchemaType.OPENAPI
|
||||
|
||||
def test_api_tool_provider_tools_property(self):
|
||||
"""Test tools property parses JSON and returns ApiToolBundle list."""
|
||||
# Arrange
|
||||
tools_data = [
|
||||
{
|
||||
"author": "test",
|
||||
"server_url": "https://api.weather.com",
|
||||
"method": "get",
|
||||
"summary": "Get weather information",
|
||||
"operation_id": "getWeather",
|
||||
"parameters": [],
|
||||
"openapi": {
|
||||
"operation_id": "getWeather",
|
||||
"parameters": [],
|
||||
"method": "get",
|
||||
"path": "/weather",
|
||||
"server_url": "https://api.weather.com",
|
||||
},
|
||||
},
|
||||
{
|
||||
"author": "test",
|
||||
"server_url": "https://api.location.com",
|
||||
"method": "get",
|
||||
"summary": "Get location data",
|
||||
"operation_id": "getLocation",
|
||||
"parameters": [],
|
||||
"openapi": {
|
||||
"operation_id": "getLocation",
|
||||
"parameters": [],
|
||||
"method": "get",
|
||||
"path": "/location",
|
||||
"server_url": "https://api.location.com",
|
||||
},
|
||||
},
|
||||
]
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Weather API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Weather API",
|
||||
tools_str=json.dumps(tools_data),
|
||||
credentials_str="{}",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = api_provider.tools
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].operation_id == "getWeather"
|
||||
assert result[1].operation_id == "getLocation"
|
||||
|
||||
def test_api_tool_provider_credentials_property(self):
|
||||
"""Test credentials property parses JSON correctly."""
|
||||
# Arrange
|
||||
credentials_data = {
|
||||
"auth_type": "api_key_header",
|
||||
"api_key_header": "Authorization",
|
||||
"api_key_value": "Bearer test_token",
|
||||
"api_key_header_prefix": "bearer",
|
||||
}
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Secure API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Secure API",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials_data),
|
||||
)
|
||||
|
||||
# Act
|
||||
result = api_provider.credentials
|
||||
|
||||
# Assert
|
||||
assert result["auth_type"] == "api_key_header"
|
||||
assert result["api_key_header"] == "Authorization"
|
||||
assert result["api_key_value"] == "Bearer test_token"
|
||||
|
||||
def test_api_tool_provider_with_privacy_policy(self):
|
||||
"""Test API tool provider with privacy policy."""
|
||||
# Arrange
|
||||
privacy_policy_url = "https://example.com/privacy"
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Privacy API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="API with privacy policy",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
privacy_policy=privacy_policy_url,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.privacy_policy == privacy_policy_url
|
||||
|
||||
def test_api_tool_provider_with_custom_disclaimer(self):
|
||||
"""Test API tool provider with custom disclaimer."""
|
||||
# Arrange
|
||||
disclaimer = "This API is provided as-is without warranty."
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Disclaimer API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="API with disclaimer",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
custom_disclaimer=disclaimer,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.custom_disclaimer == disclaimer
|
||||
|
||||
def test_api_tool_provider_default_custom_disclaimer(self):
|
||||
"""Test API tool provider default custom_disclaimer is empty string."""
|
||||
# Arrange & Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Default API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="API",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.custom_disclaimer == ""
|
||||
|
||||
def test_api_tool_provider_unique_constraint_fields(self):
|
||||
"""Test unique constraint fields (name, tenant_id)."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
provider_name = "Unique API"
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=str(uuid4()),
|
||||
name=provider_name,
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Unique API",
|
||||
tools_str="[]",
|
||||
credentials_str="{}",
|
||||
)
|
||||
|
||||
# Assert - these fields form unique constraint
|
||||
assert api_provider.tenant_id == tenant_id
|
||||
assert api_provider.name == provider_name
|
||||
|
||||
def test_api_tool_provider_with_no_auth(self):
|
||||
"""Test API tool provider with no authentication."""
|
||||
# Arrange
|
||||
credentials = {"auth_type": "none"}
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Public API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="Public API with no auth",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.credentials["auth_type"] == "none"
|
||||
|
||||
def test_api_tool_provider_with_api_key_query_auth(self):
|
||||
"""Test API tool provider with API key in query parameter."""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"auth_type": "api_key_query",
|
||||
"api_key_query_param": "apikey",
|
||||
"api_key_value": "my_secret_key",
|
||||
}
|
||||
|
||||
# Act
|
||||
api_provider = ApiToolProvider(
|
||||
tenant_id=str(uuid4()),
|
||||
user_id=str(uuid4()),
|
||||
name="Query Auth API",
|
||||
icon="{}",
|
||||
schema="{}",
|
||||
schema_type_str="openapi",
|
||||
description="API with query auth",
|
||||
tools_str="[]",
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert api_provider.credentials["auth_type"] == "api_key_query"
|
||||
assert api_provider.credentials["api_key_query_param"] == "apikey"
|
||||
|
||||
|
||||
class TestToolOAuthModels:
|
||||
"""Test suite for OAuth client models (system and tenant level)."""
|
||||
|
||||
def test_oauth_system_client_creation(self):
|
||||
"""Test creating a system-level OAuth client."""
|
||||
# Arrange
|
||||
plugin_id = "builtin.google"
|
||||
provider = "google"
|
||||
oauth_params = json.dumps(
|
||||
{"client_id": "system_client_id", "client_secret": "system_secret", "scope": "email profile"}
|
||||
)
|
||||
|
||||
# Act
|
||||
oauth_client = ToolOAuthSystemClient(
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
encrypted_oauth_params=oauth_params,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert oauth_client.plugin_id == plugin_id
|
||||
assert oauth_client.provider == provider
|
||||
assert oauth_client.encrypted_oauth_params == oauth_params
|
||||
|
||||
def test_oauth_system_client_unique_constraint(self):
|
||||
"""Test unique constraint on plugin_id and provider."""
|
||||
# Arrange
|
||||
plugin_id = "builtin.github"
|
||||
provider = "github"
|
||||
|
||||
# Act
|
||||
oauth_client = ToolOAuthSystemClient(
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
encrypted_oauth_params="{}",
|
||||
)
|
||||
|
||||
# Assert - these fields form unique constraint
|
||||
assert oauth_client.plugin_id == plugin_id
|
||||
assert oauth_client.provider == provider
|
||||
|
||||
def test_oauth_tenant_client_creation(self):
|
||||
"""Test creating a tenant-level OAuth client."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
plugin_id = "builtin.google"
|
||||
provider = "google"
|
||||
|
||||
# Act
|
||||
oauth_client = ToolOAuthTenantClient(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
)
|
||||
# Set encrypted_oauth_params after creation (it has init=False)
|
||||
oauth_params = json.dumps({"client_id": "tenant_client_id", "client_secret": "tenant_secret"})
|
||||
oauth_client.encrypted_oauth_params = oauth_params
|
||||
|
||||
# Assert
|
||||
assert oauth_client.tenant_id == tenant_id
|
||||
assert oauth_client.plugin_id == plugin_id
|
||||
assert oauth_client.provider == provider
|
||||
|
||||
def test_oauth_tenant_client_enabled_default(self):
|
||||
"""Test OAuth tenant client enabled flag has init=False and uses server default."""
|
||||
# Arrange & Act
|
||||
oauth_client = ToolOAuthTenantClient(
|
||||
tenant_id=str(uuid4()),
|
||||
plugin_id="builtin.slack",
|
||||
provider="slack",
|
||||
)
|
||||
|
||||
# Assert - enabled has init=False, so it won't be set until saved to DB
|
||||
# We can manually set it to test the field exists
|
||||
oauth_client.enabled = True
|
||||
assert oauth_client.enabled is True
|
||||
|
||||
def test_oauth_tenant_client_oauth_params_property(self):
|
||||
"""Test oauth_params property parses JSON correctly."""
|
||||
# Arrange
|
||||
params_data = {
|
||||
"client_id": "test_client_123",
|
||||
"client_secret": "secret_456",
|
||||
"redirect_uri": "https://app.example.com/callback",
|
||||
}
|
||||
oauth_client = ToolOAuthTenantClient(
|
||||
tenant_id=str(uuid4()),
|
||||
plugin_id="builtin.dropbox",
|
||||
provider="dropbox",
|
||||
)
|
||||
# Set encrypted_oauth_params after creation (it has init=False)
|
||||
oauth_client.encrypted_oauth_params = json.dumps(params_data)
|
||||
|
||||
# Act
|
||||
result = oauth_client.oauth_params
|
||||
|
||||
# Assert
|
||||
assert result == params_data
|
||||
assert result["client_id"] == "test_client_123"
|
||||
assert result["redirect_uri"] == "https://app.example.com/callback"
|
||||
|
||||
def test_oauth_tenant_client_oauth_params_empty_when_none(self):
|
||||
"""Test oauth_params returns empty dict when encrypted_oauth_params is None."""
|
||||
# Arrange
|
||||
oauth_client = ToolOAuthTenantClient(
|
||||
tenant_id=str(uuid4()),
|
||||
plugin_id="builtin.test",
|
||||
provider="test",
|
||||
)
|
||||
# encrypted_oauth_params has init=False, set it to None
|
||||
oauth_client.encrypted_oauth_params = None
|
||||
|
||||
# Act
|
||||
result = oauth_client.oauth_params
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_oauth_tenant_client_disabled_state(self):
|
||||
"""Test OAuth tenant client can be disabled."""
|
||||
# Arrange
|
||||
oauth_client = ToolOAuthTenantClient(
|
||||
tenant_id=str(uuid4()),
|
||||
plugin_id="builtin.microsoft",
|
||||
provider="microsoft",
|
||||
)
|
||||
|
||||
# Act
|
||||
oauth_client.enabled = False
|
||||
|
||||
# Assert
|
||||
assert oauth_client.enabled is False
|
||||
|
||||
|
||||
class TestToolLabelBinding:
    """Test suite for ToolLabelBinding model."""

    def test_tool_label_binding_creation(self):
        """Test creating a tool label binding."""
        # Arrange
        tool_id = "google.search"
        tool_type = "builtin"
        label_name = "search"

        # Act
        label_binding = ToolLabelBinding(
            tool_id=tool_id,
            tool_type=tool_type,
            label_name=label_name,
        )

        # Assert
        assert label_binding.tool_id == tool_id
        assert label_binding.tool_type == tool_type
        assert label_binding.label_name == label_name

    def test_tool_label_binding_unique_constraint(self):
        """Test unique constraint on tool_id and label_name."""
        # Arrange
        tool_id = "openai.text_generation"
        label_name = "text"

        # Act
        label_binding = ToolLabelBinding(
            tool_id=tool_id,
            tool_type="builtin",
            label_name=label_name,
        )

        # Assert - these fields form unique constraint
        assert label_binding.tool_id == tool_id
        assert label_binding.label_name == label_name

    def test_tool_label_binding_multiple_labels_same_tool(self):
        """Test multiple labels can be bound to the same tool."""
        # Arrange
        tool_id = "google.search"
        tool_type = "builtin"

        # Act
        binding1 = ToolLabelBinding(
            tool_id=tool_id,
            tool_type=tool_type,
            label_name="search",
        )
        binding2 = ToolLabelBinding(
            tool_id=tool_id,
            tool_type=tool_type,
            label_name="productivity",
        )

        # Assert
        assert binding1.tool_id == binding2.tool_id
        assert binding1.label_name != binding2.label_name

    def test_tool_label_binding_different_tool_types(self):
        """Test label bindings for different tool types."""
        # Arrange
        tool_types = ["builtin", "api", "workflow"]

        # Act & Assert
        for tool_type in tool_types:
            binding = ToolLabelBinding(
                tool_id=f"test_tool_{tool_type}",
                tool_type=tool_type,
                label_name="test",
            )
            assert binding.tool_type == tool_type


class TestCredentialStorage:
    """Test suite for credential storage and encryption patterns."""

    def test_builtin_provider_credential_storage_format(self):
        """Test builtin provider stores credentials as JSON string."""
        # Arrange
        credentials = {
            "api_key": "sk-test123",
            "endpoint": "https://api.example.com",
            "timeout": 30,
        }

        # Act
        provider = BuiltinToolProvider(
            tenant_id=str(uuid4()),
            user_id=str(uuid4()),
            provider="test",
            name="Test Provider",
            encrypted_credentials=json.dumps(credentials),
        )

        # Assert
        assert isinstance(provider.encrypted_credentials, str)
        assert provider.credentials == credentials

    def test_api_provider_credential_storage_format(self):
        """Test API provider stores credentials as JSON string."""
        # Arrange
        credentials = {
            "auth_type": "api_key_header",
            "api_key_header": "X-API-Key",
            "api_key_value": "secret_key_789",
        }

        # Act
        provider = ApiToolProvider(
            tenant_id=str(uuid4()),
            user_id=str(uuid4()),
            name="Test API",
            icon="{}",
            schema="{}",
            schema_type_str="openapi",
            description="Test",
            tools_str="[]",
            credentials_str=json.dumps(credentials),
        )

        # Assert
        assert isinstance(provider.credentials_str, str)
        assert provider.credentials == credentials

    def test_builtin_provider_complex_credential_structure(self):
        """Test builtin provider with complex nested credential structure."""
        # Arrange
        credentials = {
            "auth_type": "oauth2",
            "oauth_config": {
                "access_token": "token123",
                "refresh_token": "refresh456",
                "expires_in": 3600,
                "token_type": "Bearer",
            },
            "additional_headers": {"X-Custom-Header": "value"},
        }

        # Act
        provider = BuiltinToolProvider(
            tenant_id=str(uuid4()),
            user_id=str(uuid4()),
            provider="oauth_provider",
            name="OAuth Provider",
            encrypted_credentials=json.dumps(credentials),
        )

        # Assert
        assert provider.credentials["oauth_config"]["access_token"] == "token123"
        assert provider.credentials["additional_headers"]["X-Custom-Header"] == "value"

    def test_api_provider_credential_update_pattern(self):
        """Test pattern for updating API provider credentials."""
        # Arrange
        original_credentials = {"auth_type": "api_key_header", "api_key_value": "old_key"}
        provider = ApiToolProvider(
            tenant_id=str(uuid4()),
            user_id=str(uuid4()),
            name="Update Test",
            icon="{}",
            schema="{}",
            schema_type_str="openapi",
            description="Test",
            tools_str="[]",
            credentials_str=json.dumps(original_credentials),
        )

        # Act - simulate credential update
        new_credentials = {"auth_type": "api_key_header", "api_key_value": "new_key"}
        provider.credentials_str = json.dumps(new_credentials)

        # Assert
        assert provider.credentials["api_key_value"] == "new_key"

    def test_builtin_provider_credential_expiration(self):
        """Test builtin provider credential expiration tracking."""
        # Arrange
        future_timestamp = 1735689600  # Future date
        past_timestamp = 1609459200  # Past date

        # Act
        active_provider = BuiltinToolProvider(
            tenant_id=str(uuid4()),
            user_id=str(uuid4()),
            provider="active",
            name="Active Provider",
            expires_at=future_timestamp,
        )
        expired_provider = BuiltinToolProvider(
            tenant_id=str(uuid4()),
            user_id=str(uuid4()),
            provider="expired",
            name="Expired Provider",
            expires_at=past_timestamp,
        )
        never_expires_provider = BuiltinToolProvider(
            tenant_id=str(uuid4()),
            user_id=str(uuid4()),
            provider="permanent",
            name="Permanent Provider",
            expires_at=-1,
        )

        # Assert
        assert active_provider.expires_at == future_timestamp
        assert expired_provider.expires_at == past_timestamp
        assert never_expires_provider.expires_at == -1

    def test_oauth_client_credential_storage(self):
        """Test OAuth client credential storage pattern."""
        # Arrange
        oauth_credentials = {
            "client_id": "oauth_client_123",
            "client_secret": "oauth_secret_456",
            "authorization_url": "https://oauth.example.com/authorize",
            "token_url": "https://oauth.example.com/token",
            "scope": "read write",
        }

        # Act
        system_client = ToolOAuthSystemClient(
            plugin_id="builtin.oauth_test",
            provider="oauth_test",
            encrypted_oauth_params=json.dumps(oauth_credentials),
        )

        tenant_client = ToolOAuthTenantClient(
            tenant_id=str(uuid4()),
            plugin_id="builtin.oauth_test",
            provider="oauth_test",
        )
        # Set encrypted_oauth_params after creation (it has init=False)
        tenant_client.encrypted_oauth_params = json.dumps(oauth_credentials)

        # Assert
        assert system_client.encrypted_oauth_params == json.dumps(oauth_credentials)
        assert tenant_client.oauth_params == oauth_credentials

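The expiration test above treats expires_at as a Unix timestamp with -1 meaning "never expires". A small illustrative helper one could build on that convention (hypothetical, not part of the committed model code):

import time


def credentials_expired(expires_at: int, now: float | None = None) -> bool:
    # -1 is the sentinel for credentials that never expire.
    if expires_at == -1:
        return False
    return (now if now is not None else time.time()) >= expires_at
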
class TestToolProviderRelationships:
    """Test suite for tool provider relationships and associations."""

    def test_builtin_provider_tenant_relationship(self):
        """Test builtin provider belongs to a tenant."""
        # Arrange
        tenant_id = str(uuid4())

        # Act
        provider = BuiltinToolProvider(
            tenant_id=tenant_id,
            user_id=str(uuid4()),
            provider="test",
            name="Test Provider",
        )

        # Assert
        assert provider.tenant_id == tenant_id

    def test_api_provider_user_relationship(self):
        """Test API provider belongs to a user."""
        # Arrange
        user_id = str(uuid4())

        # Act
        provider = ApiToolProvider(
            tenant_id=str(uuid4()),
            user_id=user_id,
            name="User API",
            icon="{}",
            schema="{}",
            schema_type_str="openapi",
            description="Test",
            tools_str="[]",
            credentials_str="{}",
        )

        # Assert
        assert provider.user_id == user_id

    def test_multiple_providers_same_tenant(self):
        """Test multiple providers can belong to the same tenant."""
        # Arrange
        tenant_id = str(uuid4())
        user_id = str(uuid4())

        # Act
        builtin1 = BuiltinToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            provider="google",
            name="Google Key 1",
        )
        builtin2 = BuiltinToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            provider="openai",
            name="OpenAI Key 1",
        )
        api1 = ApiToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
            name="Custom API 1",
            icon="{}",
            schema="{}",
            schema_type_str="openapi",
            description="Test",
            tools_str="[]",
            credentials_str="{}",
        )

        # Assert
        assert builtin1.tenant_id == tenant_id
        assert builtin2.tenant_id == tenant_id
        assert api1.tenant_id == tenant_id

    def test_tool_label_bindings_for_provider_tools(self):
        """Test tool label bindings can be associated with provider tools."""
        # Arrange
        provider_name = "google"
        tool_id = f"{provider_name}.search"

        # Act
        binding1 = ToolLabelBinding(
            tool_id=tool_id,
            tool_type="builtin",
            label_name="search",
        )
        binding2 = ToolLabelBinding(
            tool_id=tool_id,
            tool_type="builtin",
            label_name="web",
        )

        # Assert
        assert binding1.tool_id == tool_id
        assert binding2.tool_id == tool_id
        assert binding1.label_name != binding2.label_name
191
dify/api/tests/unit_tests/models/test_types_enum_text.py
Normal file
191
dify/api/tests/unit_tests/models/test_types_enum_text.py
Normal file
@@ -0,0 +1,191 @@
from collections.abc import Callable, Iterable
from enum import StrEnum
from typing import Any, NamedTuple, TypeVar

import pytest
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
from sqlalchemy import insert
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.sql.sqltypes import VARCHAR

from models.types import EnumText

_user_type_admin = "admin"
_user_type_normal = "normal"


class _Base(DeclarativeBase):
    pass


class _UserType(StrEnum):
    admin = _user_type_admin
    normal = _user_type_normal


class _EnumWithLongValue(StrEnum):
    unknown = "unknown"
    a_really_long_enum_values = "a_really_long_enum_values"


class _User(_Base):
    __tablename__ = "users"

    id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
    name: Mapped[str] = mapped_column(sa.String(length=255), nullable=False)
    user_type: Mapped[_UserType] = mapped_column(
        EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal
    )
    user_type_nullable: Mapped[_UserType | None] = mapped_column(EnumText(enum_class=_UserType), nullable=True)


class _ColumnTest(_Base):
    __tablename__ = "column_test"

    id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)

    user_type: Mapped[_UserType] = mapped_column(
        EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal
    )
    explicit_length: Mapped[_UserType | None] = mapped_column(
        EnumText(_UserType, length=50), nullable=True, default=_UserType.normal
    )
    long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False)


_T = TypeVar("_T")


def _first(it: Iterable[_T]) -> _T:
    ls = list(it)
    if not ls:
        raise ValueError("List is empty")
    return ls[0]


class TestEnumText:
    def test_column_impl(self):
        engine = sa.create_engine("sqlite://", echo=False)
        _Base.metadata.create_all(engine)

        inspector = sa.inspect(engine)
        columns = inspector.get_columns(_ColumnTest.__tablename__)

        user_type_column = _first(c for c in columns if c["name"] == "user_type")
        sql_type = user_type_column["type"]
        assert isinstance(user_type_column["type"], VARCHAR)
        assert sql_type.length == 20
        assert user_type_column["nullable"] is False

        explicit_length_column = _first(c for c in columns if c["name"] == "explicit_length")
        sql_type = explicit_length_column["type"]
        assert isinstance(sql_type, VARCHAR)
        assert sql_type.length == 50
        assert explicit_length_column["nullable"] is True

        long_value_column = _first(c for c in columns if c["name"] == "long_value")
        sql_type = long_value_column["type"]
        assert isinstance(sql_type, VARCHAR)
        assert sql_type.length == len(_EnumWithLongValue.a_really_long_enum_values)

    def test_insert_and_select(self):
        engine = sa.create_engine("sqlite://", echo=False)
        _Base.metadata.create_all(engine)

        with Session(engine) as session:
            admin_user = _User(
                name="admin",
                user_type=_UserType.admin,
                user_type_nullable=None,
            )
            session.add(admin_user)
            session.flush()
            admin_user_id = admin_user.id

            normal_user = _User(
                name="normal",
                user_type=_UserType.normal.value,
                user_type_nullable=_UserType.normal.value,
            )
            session.add(normal_user)
            session.flush()
            normal_user_id = normal_user.id
            session.commit()

        with Session(engine) as session:
            user = session.query(_User).where(_User.id == admin_user_id).first()
            assert user.user_type == _UserType.admin
            assert user.user_type_nullable is None

        with Session(engine) as session:
            user = session.query(_User).where(_User.id == normal_user_id).first()
            assert user.user_type == _UserType.normal
            assert user.user_type_nullable == _UserType.normal

    def test_insert_invalid_values(self):
        def _session_insert_with_value(sess: Session, user_type: Any):
            user = _User(name="test_user", user_type=user_type)
            sess.add(user)
            sess.flush()

        def _insert_with_user(sess: Session, user_type: Any):
            stmt = insert(_User).values(
                {
                    "name": "test_user",
                    "user_type": user_type,
                }
            )
            sess.execute(stmt)

        class TestCase(NamedTuple):
            name: str
            action: Callable[[Session], None]
            exc_type: type[Exception]

        engine = sa.create_engine("sqlite://", echo=False)
        _Base.metadata.create_all(engine)
        cases = [
            TestCase(
                name="session insert with invalid value",
                action=lambda s: _session_insert_with_value(s, "invalid"),
                exc_type=ValueError,
            ),
            TestCase(
                name="session insert with invalid type",
                action=lambda s: _session_insert_with_value(s, 1),
                exc_type=ValueError,
            ),
            TestCase(
                name="insert with invalid value",
                action=lambda s: _insert_with_user(s, "invalid"),
                exc_type=ValueError,
            ),
            TestCase(
                name="insert with invalid type",
                action=lambda s: _insert_with_user(s, 1),
                exc_type=ValueError,
            ),
        ]
        for idx, c in enumerate(cases, 1):
            with pytest.raises(sa_exc.StatementError) as exc:
                with Session(engine) as session:
                    c.action(session)

            assert isinstance(exc.value.orig, c.exc_type), f"test case {idx} failed, name={c.name}"

    def test_select_invalid_values(self):
        engine = sa.create_engine("sqlite://", echo=False)
        _Base.metadata.create_all(engine)

        insertion_sql = """
        INSERT INTO users (id, name, user_type) VALUES
        (1, 'invalid_value', 'invalid');
        """
        with Session(engine) as session:
            session.execute(sa.text(insertion_sql))
            session.commit()

        with pytest.raises(ValueError) as exc:
            with Session(engine) as session:
                _user = session.query(_User).where(_User.id == 1).first()
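For context, a minimal sketch of the kind of TypeDecorator these tests exercise: the value is stored as VARCHAR sized to the longest member value with a floor of 20, and validated on both insert and select. This is an illustrative approximation, not the actual models.types.EnumText implementation.

from enum import StrEnum

import sqlalchemy as sa
from sqlalchemy.types import TypeDecorator


class _EnumTextSketch(TypeDecorator):
    """Illustrative approximation of an EnumText-style column type."""

    impl = sa.VARCHAR
    cache_ok = True

    def __init__(self, enum_class: type[StrEnum], length: int | None = None):
        self._enum_class = enum_class
        # Width defaults to the longest member value, with a floor of 20,
        # matching what test_column_impl asserts about the generated columns.
        default_length = max(20, max(len(m.value) for m in enum_class))
        super().__init__(length or default_length)

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        if not isinstance(value, (str, self._enum_class)):
            raise ValueError(f"expected {self._enum_class!r}, got {type(value)!r}")
        # Raises ValueError for strings that are not valid members.
        return self._enum_class(value).value

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        # Rows written outside the ORM with invalid values fail here on select.
        return self._enum_class(value)
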
299
dify/api/tests/unit_tests/models/test_workflow.py
Normal file
299
dify/api/tests/unit_tests/models/test_workflow.py
Normal file
@@ -0,0 +1,299 @@
import dataclasses
import json
from unittest import mock
from uuid import uuid4

from constants import HIDDEN_VALUE
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from core.variables.segments import IntegerSegment, Segment
from factories.variable_factory import build_segment
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable


def test_environment_variables():
    # tenant_id context variable removed - using current_user.current_tenant_id directly

    # Create a Workflow instance
    workflow = Workflow(
        tenant_id="tenant_id",
        app_id="app_id",
        type="workflow",
        version="draft",
        graph="{}",
        features="{}",
        created_by="account_id",
        environment_variables=[],
        conversation_variables=[],
    )

    # Create some EnvironmentVariable instances
    variable1 = StringVariable.model_validate(
        {"name": "var1", "value": "value1", "id": str(uuid4()), "selector": ["env", "var1"]}
    )
    variable2 = IntegerVariable.model_validate(
        {"name": "var2", "value": 123, "id": str(uuid4()), "selector": ["env", "var2"]}
    )
    variable3 = SecretVariable.model_validate(
        {"name": "var3", "value": "secret", "id": str(uuid4()), "selector": ["env", "var3"]}
    )
    variable4 = FloatVariable.model_validate(
        {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
    )

    with (
        mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
        mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
    ):
        # Set the environment_variables property of the Workflow instance
        variables = [variable1, variable2, variable3, variable4]
        workflow.environment_variables = variables

        # Get the environment_variables property and assert its value
        assert workflow.environment_variables == variables


def test_update_environment_variables():
    # tenant_id context variable removed - using current_user.current_tenant_id directly

    # Create a Workflow instance
    workflow = Workflow(
        tenant_id="tenant_id",
        app_id="app_id",
        type="workflow",
        version="draft",
        graph="{}",
        features="{}",
        created_by="account_id",
        environment_variables=[],
        conversation_variables=[],
    )

    # Create some EnvironmentVariable instances
    variable1 = StringVariable.model_validate(
        {"name": "var1", "value": "value1", "id": str(uuid4()), "selector": ["env", "var1"]}
    )
    variable2 = IntegerVariable.model_validate(
        {"name": "var2", "value": 123, "id": str(uuid4()), "selector": ["env", "var2"]}
    )
    variable3 = SecretVariable.model_validate(
        {"name": "var3", "value": "secret", "id": str(uuid4()), "selector": ["env", "var3"]}
    )
    variable4 = FloatVariable.model_validate(
        {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]}
    )

    with (
        mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
        mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
    ):
        variables = [variable1, variable2, variable3, variable4]

        # Set the environment_variables property of the Workflow instance
        workflow.environment_variables = variables
        assert workflow.environment_variables == [variable1, variable2, variable3, variable4]

        # Update the name of variable3 and keep the value as it is
        variables[2] = variable3.model_copy(
            update={
                "name": "new name",
                "value": HIDDEN_VALUE,
            }
        )

        workflow.environment_variables = variables
        assert workflow.environment_variables[2].name == "new name"
        assert workflow.environment_variables[2].value == variable3.value

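The update test above depends on the setter keeping the previously stored secret whenever the incoming value is the HIDDEN_VALUE placeholder. A hedged sketch of that merge rule (simplified; the real setter also handles encryption and validation):

from constants import HIDDEN_VALUE


def merge_secret_value(old_value: str, incoming_value: str) -> str:
    # Clients echo back HIDDEN_VALUE instead of the real secret, so the
    # placeholder means "keep whatever is already stored".
    if incoming_value == HIDDEN_VALUE:
        return old_value
    return incoming_value
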
def test_to_dict():
    # tenant_id context variable removed - using current_user.current_tenant_id directly

    # Create a Workflow instance
    workflow = Workflow(
        tenant_id="tenant_id",
        app_id="app_id",
        type="workflow",
        version="draft",
        graph="{}",
        features="{}",
        created_by="account_id",
        environment_variables=[],
        conversation_variables=[],
    )

    # Create some EnvironmentVariable instances

    with (
        mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
        mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
    ):
        # Set the environment_variables property of the Workflow instance
        workflow.environment_variables = [
            SecretVariable.model_validate({"name": "secret", "value": "secret", "id": str(uuid4())}),
            StringVariable.model_validate({"name": "text", "value": "text", "id": str(uuid4())}),
        ]

        workflow_dict = workflow.to_dict()
        assert workflow_dict["environment_variables"][0]["value"] == ""
        assert workflow_dict["environment_variables"][1]["value"] == "text"

        workflow_dict = workflow.to_dict(include_secret=True)
        assert workflow_dict["environment_variables"][0]["value"] == "secret"
        assert workflow_dict["environment_variables"][1]["value"] == "text"


class TestWorkflowNodeExecution:
    def test_execution_metadata_dict(self):
        node_exec = WorkflowNodeExecutionModel()
        node_exec.execution_metadata = None
        assert node_exec.execution_metadata_dict == {}

        original = {"a": 1, "b": ["2"]}
        node_exec.execution_metadata = json.dumps(original)
        assert node_exec.execution_metadata_dict == original


class TestIsSystemVariableEditable:
    def test_is_system_variable(self):
        cases = [
            ("query", True),
            ("files", True),
            ("dialogue_count", False),
            ("conversation_id", False),
            ("user_id", False),
            ("app_id", False),
            ("workflow_id", False),
            ("workflow_run_id", False),
        ]
        for name, editable in cases:
            assert editable == is_system_variable_editable(name)

        assert is_system_variable_editable("invalid_or_new_system_variable") is False


class TestWorkflowDraftVariableGetValue:
    def test_get_value_by_case(self):
        @dataclasses.dataclass
        class TestCase:
            name: str
            value: Segment

        tenant_id = "test_tenant_id"

        test_file = File(
            tenant_id=tenant_id,
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/example.jpg",
            filename="example.jpg",
            extension=".jpg",
            mime_type="image/jpeg",
            size=100,
        )
        cases: list[TestCase] = [
            TestCase(
                name="number/int",
                value=build_segment(1),
            ),
            TestCase(
                name="number/float",
                value=build_segment(1.0),
            ),
            TestCase(
                name="string",
                value=build_segment("a"),
            ),
            TestCase(
                name="object",
                value=build_segment({}),
            ),
            TestCase(
                name="file",
                value=build_segment(test_file),
            ),
            TestCase(
                name="array[any]",
                value=build_segment([1, "a"]),
            ),
            TestCase(
                name="array[string]",
                value=build_segment(["a", "b"]),
            ),
            TestCase(
                name="array[number]/int",
                value=build_segment([1, 2]),
            ),
            TestCase(
                name="array[number]/float",
                value=build_segment([1.0, 2.0]),
            ),
            TestCase(
                name="array[number]/mixed",
                value=build_segment([1, 2.0]),
            ),
            TestCase(
                name="array[object]",
                value=build_segment([{}, {"a": 1}]),
            ),
            TestCase(
                name="none",
                value=build_segment(None),
            ),
        ]

        for idx, c in enumerate(cases, 1):
            fail_msg = f"test case {c.name} failed, index={idx}"
            draft_var = WorkflowDraftVariable()
            draft_var.set_value(c.value)
            assert c.value == draft_var.get_value(), fail_msg

    def test_file_variable_preserves_all_fields(self):
        """Test that File type variables preserve all fields during encoding/decoding."""
        tenant_id = "test_tenant_id"

        # Create a File with specific field values
        test_file = File(
            id="test_file_id",
            tenant_id=tenant_id,
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/test.jpg",
            filename="test.jpg",
            extension=".jpg",
            mime_type="image/jpeg",
            size=12345,  # Specific size to test preservation
            storage_key="test_storage_key",
        )

        # Create a FileSegment and WorkflowDraftVariable
        file_segment = build_segment(test_file)
        draft_var = WorkflowDraftVariable()
        draft_var.set_value(file_segment)

        # Retrieve the value and verify all fields are preserved
        retrieved_segment = draft_var.get_value()
        retrieved_file = retrieved_segment.value

        # Verify all important fields are preserved
        assert retrieved_file.id == test_file.id
        assert retrieved_file.tenant_id == test_file.tenant_id
        assert retrieved_file.type == test_file.type
        assert retrieved_file.transfer_method == test_file.transfer_method
        assert retrieved_file.remote_url == test_file.remote_url
        assert retrieved_file.filename == test_file.filename
        assert retrieved_file.extension == test_file.extension
        assert retrieved_file.mime_type == test_file.mime_type
        assert retrieved_file.size == test_file.size  # This was the main issue being fixed
        # Note: storage_key is not serialized in model_dump() so it won't be preserved

        # Verify the segments have the same type and the important fields match
        assert file_segment.value_type == retrieved_segment.value_type

    def test_get_and_set_value(self):
        draft_var = WorkflowDraftVariable()
        int_var = IntegerSegment(value=1)
        draft_var.set_value(int_var)
        value = draft_var.get_value()
        assert value == int_var
1044
dify/api/tests/unit_tests/models/test_workflow_models.py
Normal file
1044
dify/api/tests/unit_tests/models/test_workflow_models.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,212 @@
"""
Unit tests for WorkflowNodeExecutionOffload model, focusing on process_data truncation functionality.
"""

from unittest.mock import Mock

import pytest

from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload


class TestWorkflowNodeExecutionModel:
    """Test WorkflowNodeExecutionModel with process_data truncation features."""

    def create_mock_offload_data(
        self,
        inputs_file_id: str | None = None,
        outputs_file_id: str | None = None,
        process_data_file_id: str | None = None,
    ) -> WorkflowNodeExecutionOffload:
        """Create a mock offload data object."""
        offload = Mock(spec=WorkflowNodeExecutionOffload)
        offload.inputs_file_id = inputs_file_id
        offload.outputs_file_id = outputs_file_id
        offload.process_data_file_id = process_data_file_id

        # Mock file objects
        if inputs_file_id:
            offload.inputs_file = Mock(spec=UploadFile)
        else:
            offload.inputs_file = None

        if outputs_file_id:
            offload.outputs_file = Mock(spec=UploadFile)
        else:
            offload.outputs_file = None

        if process_data_file_id:
            offload.process_data_file = Mock(spec=UploadFile)
        else:
            offload.process_data_file = None

        return offload

    def test_process_data_truncated_property_false_when_no_offload_data(self):
        """Test process_data_truncated returns False when no offload_data."""
        execution = WorkflowNodeExecutionModel()
        execution.offload_data = []

        assert execution.process_data_truncated is False

    def test_process_data_truncated_property_false_when_no_process_data_file(self):
        """Test process_data_truncated returns False when no process_data file."""
        from models.enums import ExecutionOffLoadType

        execution = WorkflowNodeExecutionModel()

        # Create real offload instances for inputs and outputs but not process_data
        inputs_offload = WorkflowNodeExecutionOffload()
        inputs_offload.type_ = ExecutionOffLoadType.INPUTS
        inputs_offload.file_id = "inputs-file"

        outputs_offload = WorkflowNodeExecutionOffload()
        outputs_offload.type_ = ExecutionOffLoadType.OUTPUTS
        outputs_offload.file_id = "outputs-file"

        execution.offload_data = [inputs_offload, outputs_offload]

        assert execution.process_data_truncated is False

    def test_process_data_truncated_property_true_when_process_data_file_exists(self):
        """Test process_data_truncated returns True when process_data file exists."""
        from models.enums import ExecutionOffLoadType

        execution = WorkflowNodeExecutionModel()

        # Create a real offload instance for process_data
        process_data_offload = WorkflowNodeExecutionOffload()
        process_data_offload.type_ = ExecutionOffLoadType.PROCESS_DATA
        process_data_offload.file_id = "process-data-file-id"
        execution.offload_data = [process_data_offload]

        assert execution.process_data_truncated is True

    def test_load_full_process_data_with_no_offload_data(self):
        """Test load_full_process_data when no offload data exists."""
        execution = WorkflowNodeExecutionModel()
        execution.offload_data = []
        execution.process_data = '{"test": "data"}'

        # Mock session and storage
        mock_session = Mock()
        mock_storage = Mock()

        result = execution.load_full_process_data(mock_session, mock_storage)

        assert result == {"test": "data"}

    def test_load_full_process_data_with_no_file(self):
        """Test load_full_process_data when no process_data file exists."""
        from models.enums import ExecutionOffLoadType

        execution = WorkflowNodeExecutionModel()

        # Create offload data for inputs only, not process_data
        inputs_offload = WorkflowNodeExecutionOffload()
        inputs_offload.type_ = ExecutionOffLoadType.INPUTS
        inputs_offload.file_id = "inputs-file"

        execution.offload_data = [inputs_offload]
        execution.process_data = '{"test": "data"}'

        # Mock session and storage
        mock_session = Mock()
        mock_storage = Mock()

        result = execution.load_full_process_data(mock_session, mock_storage)

        assert result == {"test": "data"}

    def test_load_full_process_data_with_file(self):
        """Test load_full_process_data when process_data file exists."""
        from models.enums import ExecutionOffLoadType

        execution = WorkflowNodeExecutionModel()

        # Create process_data offload
        process_data_offload = WorkflowNodeExecutionOffload()
        process_data_offload.type_ = ExecutionOffLoadType.PROCESS_DATA
        process_data_offload.file_id = "file-id"

        execution.offload_data = [process_data_offload]
        execution.process_data = '{"truncated": "data"}'

        # Mock session and storage
        mock_session = Mock()
        mock_storage = Mock()

        # Mock the _load_full_content method to return full data
        full_process_data = {"full": "data", "large_field": "x" * 10000}

        with pytest.MonkeyPatch.context() as mp:
            # Mock the _load_full_content method
            def mock_load_full_content(session, file_id, storage):
                assert session == mock_session
                assert file_id == "file-id"
                assert storage == mock_storage
                return full_process_data

            mp.setattr(execution, "_load_full_content", mock_load_full_content)

            result = execution.load_full_process_data(mock_session, mock_storage)

            assert result == full_process_data

    def test_consistency_with_inputs_outputs_truncation(self):
        """Test that process_data truncation behaves consistently with inputs/outputs."""
        from models.enums import ExecutionOffLoadType

        execution = WorkflowNodeExecutionModel()

        # Create offload data for all three types
        inputs_offload = WorkflowNodeExecutionOffload()
        inputs_offload.type_ = ExecutionOffLoadType.INPUTS
        inputs_offload.file_id = "inputs-file"

        outputs_offload = WorkflowNodeExecutionOffload()
        outputs_offload.type_ = ExecutionOffLoadType.OUTPUTS
        outputs_offload.file_id = "outputs-file"

        process_data_offload = WorkflowNodeExecutionOffload()
        process_data_offload.type_ = ExecutionOffLoadType.PROCESS_DATA
        process_data_offload.file_id = "process-data-file"

        execution.offload_data = [inputs_offload, outputs_offload, process_data_offload]

        # All three should be truncated
        assert execution.inputs_truncated is True
        assert execution.outputs_truncated is True
        assert execution.process_data_truncated is True

    def test_mixed_truncation_states(self):
        """Test mixed states of truncation."""
        from models.enums import ExecutionOffLoadType

        execution = WorkflowNodeExecutionModel()

        # Only process_data is truncated
        process_data_offload = WorkflowNodeExecutionOffload()
        process_data_offload.type_ = ExecutionOffLoadType.PROCESS_DATA
        process_data_offload.file_id = "process-data-file"

        execution.offload_data = [process_data_offload]

        assert execution.inputs_truncated is False
        assert execution.outputs_truncated is False
        assert execution.process_data_truncated is True

    def test_preload_offload_data_and_files_method_exists(self):
        """Test that the preload method includes process_data_file."""
        # This test verifies the method exists and can be called
        # The actual SQL behavior would be tested in integration tests
        from sqlalchemy import select

        stmt = select(WorkflowNodeExecutionModel)

        # This should not raise an exception
        preloaded_stmt = WorkflowNodeExecutionModel.preload_offload_data_and_files(stmt)

        # The statement should be modified (different object)
        assert preloaded_stmt is not stmt
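A rough sketch of how the truncation flags exercised above could be derived from offload_data (an assumption drawn from the tests, not the actual model code):

from models.enums import ExecutionOffLoadType


def is_process_data_truncated(offload_data: list) -> bool:
    # Truncated means an offload row of the matching type exists,
    # i.e. the full payload lives in external storage.
    return any(item.type_ == ExecutionOffLoadType.PROCESS_DATA for item in offload_data)
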
188
dify/api/tests/unit_tests/models/test_workflow_trigger_log.py
Normal file
188
dify/api/tests/unit_tests/models/test_workflow_trigger_log.py
Normal file
@@ -0,0 +1,188 @@
import types

import pytest

from models.engine import db
from models.enums import CreatorUserRole
from models.workflow import WorkflowNodeExecutionModel


@pytest.fixture
def fake_db_scalar(monkeypatch):
    """Provide a controllable fake for db.session.scalar (SQLAlchemy 2.0 style)."""
    calls = []

    def _install(side_effect):
        def _fake_scalar(statement):
            calls.append(statement)
            return side_effect(statement)

        # Patch the modern API used by the model implementation
        monkeypatch.setattr(db.session, "scalar", _fake_scalar)

        # Backward-compatibility: if the implementation still uses db.session.get,
        # make it delegate to the same side_effect so tests remain valid on older code.
        if hasattr(db.session, "get"):

            def _fake_get(*_args, **_kwargs):
                return side_effect(None)

            monkeypatch.setattr(db.session, "get", _fake_get)

        return calls

    return _install


def make_account(id_: str = "acc-1"):
    # Use a simple object to avoid constructing a full SQLAlchemy model instance
    # Python 3.12 forbids reassigning __class__ for SimpleNamespace; not needed here.
    obj = types.SimpleNamespace()
    obj.id = id_
    return obj


def make_end_user(id_: str = "user-1"):
    # Lightweight stand-in object; no need to spoof class identity.
    obj = types.SimpleNamespace()
    obj.id = id_
    return obj


def test_created_by_account_returns_account_when_role_account(fake_db_scalar):
    account = make_account("acc-1")

    # The implementation uses db.session.scalar(select(Account)...). We only need to
    # return the expected object when called; the exact SQL is irrelevant for this unit test.
    def side_effect(_statement):
        return account

    fake_db_scalar(side_effect)

    log = WorkflowNodeExecutionModel(
        tenant_id="t1",
        app_id="a1",
        workflow_id="w1",
        triggered_from="workflow-run",
        workflow_run_id=None,
        index=1,
        predecessor_node_id=None,
        node_execution_id=None,
        node_id="n1",
        node_type="start",
        title="Start",
        inputs=None,
        process_data=None,
        outputs=None,
        status="succeeded",
        error=None,
        elapsed_time=0.0,
        execution_metadata=None,
        created_by_role=CreatorUserRole.ACCOUNT.value,
        created_by="acc-1",
    )

    assert log.created_by_account is account


def test_created_by_account_returns_none_when_role_not_account(fake_db_scalar):
    # Even if an Account with matching id exists, property should return None when role is END_USER
    account = make_account("acc-1")

    def side_effect(_statement):
        return account

    fake_db_scalar(side_effect)

    log = WorkflowNodeExecutionModel(
        tenant_id="t1",
        app_id="a1",
        workflow_id="w1",
        triggered_from="workflow-run",
        workflow_run_id=None,
        index=1,
        predecessor_node_id=None,
        node_execution_id=None,
        node_id="n1",
        node_type="start",
        title="Start",
        inputs=None,
        process_data=None,
        outputs=None,
        status="succeeded",
        error=None,
        elapsed_time=0.0,
        execution_metadata=None,
        created_by_role=CreatorUserRole.END_USER.value,
        created_by="acc-1",
    )

    assert log.created_by_account is None


def test_created_by_end_user_returns_end_user_when_role_end_user(fake_db_scalar):
    end_user = make_end_user("user-1")

    def side_effect(_statement):
        return end_user

    fake_db_scalar(side_effect)

    log = WorkflowNodeExecutionModel(
        tenant_id="t1",
        app_id="a1",
        workflow_id="w1",
        triggered_from="workflow-run",
        workflow_run_id=None,
        index=1,
        predecessor_node_id=None,
        node_execution_id=None,
        node_id="n1",
        node_type="start",
        title="Start",
        inputs=None,
        process_data=None,
        outputs=None,
        status="succeeded",
        error=None,
        elapsed_time=0.0,
        execution_metadata=None,
        created_by_role=CreatorUserRole.END_USER.value,
        created_by="user-1",
    )

    assert log.created_by_end_user is end_user


def test_created_by_end_user_returns_none_when_role_not_end_user(fake_db_scalar):
    end_user = make_end_user("user-1")

    def side_effect(_statement):
        return end_user

    fake_db_scalar(side_effect)

    log = WorkflowNodeExecutionModel(
        tenant_id="t1",
        app_id="a1",
        workflow_id="w1",
        triggered_from="workflow-run",
        workflow_run_id=None,
        index=1,
        predecessor_node_id=None,
        node_execution_id=None,
        node_id="n1",
        node_type="start",
        title="Start",
        inputs=None,
        process_data=None,
        outputs=None,
        status="succeeded",
        error=None,
        elapsed_time=0.0,
        execution_metadata=None,
        created_by_role=CreatorUserRole.ACCOUNT.value,
        created_by="user-1",
    )

    assert log.created_by_end_user is None
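These tests pin down the role gate on the creator properties; a minimal sketch of the lookup they assume (hypothetical, using the same db.session.scalar call the fixture patches):

from sqlalchemy import select

from models.account import Account
from models.engine import db
from models.enums import CreatorUserRole


def created_by_account_sketch(log) -> Account | None:
    # Only resolve an Account when the row was created by an account;
    # end-user rows return None even if an account with that id exists.
    if log.created_by_role != CreatorUserRole.ACCOUNT.value:
        return None
    return db.session.scalar(select(Account).where(Account.id == log.created_by))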