This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,107 @@
import pytest
from libs.validators import validate_description_length
class TestDescriptionValidationUnit:
"""Unit tests for the centralized description validation function."""
def test_validate_description_length_valid(self):
"""Test validation function with valid descriptions."""
# Empty string should be valid
assert validate_description_length("") == ""
# None should be valid
assert validate_description_length(None) is None
# Short description should be valid
short_desc = "Short description"
assert validate_description_length(short_desc) == short_desc
# Exactly 400 characters should be valid
exactly_400 = "x" * 400
assert validate_description_length(exactly_400) == exactly_400
# Just under limit should be valid
just_under = "x" * 399
assert validate_description_length(just_under) == just_under
def test_validate_description_length_invalid(self):
"""Test validation function with invalid descriptions."""
# 401 characters should fail
just_over = "x" * 401
with pytest.raises(ValueError) as exc_info:
validate_description_length(just_over)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
# 500 characters should fail
way_over = "x" * 500
with pytest.raises(ValueError) as exc_info:
validate_description_length(way_over)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
# 1000 characters should fail
very_long = "x" * 1000
with pytest.raises(ValueError) as exc_info:
validate_description_length(very_long)
assert "Description cannot exceed 400 characters." in str(exc_info.value)
def test_boundary_values(self):
"""Test boundary values around the 400 character limit."""
boundary_tests = [
(0, True), # Empty
(1, True), # Minimum
(399, True), # Just under limit
(400, True), # Exactly at limit
(401, False), # Just over limit
(402, False), # Over limit
(500, False), # Way over limit
]
for length, should_pass in boundary_tests:
test_desc = "x" * length
if should_pass:
# Should not raise exception
assert validate_description_length(test_desc) == test_desc
else:
# Should raise ValueError
with pytest.raises(ValueError):
validate_description_length(test_desc)
def test_special_characters(self):
"""Test validation with special characters, Unicode, etc."""
# Unicode characters
unicode_desc = "测试描述" * 100 # Chinese characters
if len(unicode_desc) <= 400:
assert validate_description_length(unicode_desc) == unicode_desc
# Special characters
special_desc = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?" * 10
if len(special_desc) <= 400:
assert validate_description_length(special_desc) == special_desc
# Mixed content
mixed_desc = "Mixed content: 测试 123 !@# " * 15
if len(mixed_desc) <= 400:
assert validate_description_length(mixed_desc) == mixed_desc
elif len(mixed_desc) > 400:
with pytest.raises(ValueError):
validate_description_length(mixed_desc)
def test_whitespace_handling(self):
"""Test validation with various whitespace scenarios."""
# Leading/trailing whitespace
whitespace_desc = " Description with whitespace "
if len(whitespace_desc) <= 400:
assert validate_description_length(whitespace_desc) == whitespace_desc
# Newlines and tabs
multiline_desc = "Line 1\nLine 2\tTabbed content"
if len(multiline_desc) <= 400:
assert validate_description_length(multiline_desc) == multiline_desc
# Only whitespace over limit
only_spaces = " " * 401
with pytest.raises(ValueError):
validate_description_length(only_spaces)

View File

@@ -0,0 +1,411 @@
import uuid
from collections import OrderedDict
from typing import Any, NamedTuple
from unittest.mock import MagicMock, patch
import pytest
from flask_restx import marshal
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
_serialize_full_content,
)
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from services.workflow_draft_variable_service import WorkflowDraftVariableList
_TEST_APP_ID = "test_app_id"
_TEST_NODE_EXEC_ID = str(uuid.uuid4())
class TestWorkflowDraftVariableFields:
def test_serialize_full_content(self):
"""Test that _serialize_full_content uses pre-loaded relationships."""
# Create mock objects with relationships pre-loaded
mock_variable_file = MagicMock(spec=WorkflowDraftVariableFile)
mock_variable_file.size = 100000
mock_variable_file.length = 50
mock_variable_file.value_type = SegmentType.OBJECT
mock_variable_file.upload_file_id = "test-upload-file-id"
mock_variable = MagicMock(spec=WorkflowDraftVariable)
mock_variable.file_id = "test-file-id"
mock_variable.variable_file = mock_variable_file
# Mock the file helpers
with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"
# Call the function
result = _serialize_full_content(mock_variable)
# Verify it returns the expected structure
assert result is not None
assert result["size_bytes"] == 100000
assert result["length"] == 50
assert result["value_type"] == "object"
assert "download_url" in result
assert result["download_url"] == "http://example.com/signed-url"
# Verify it used the pre-loaded relationships (no database queries)
mock_file_helpers.get_signed_file_url.assert_called_once_with("test-upload-file-id", as_attachment=True)
def test_serialize_full_content_handles_none_cases(self):
"""Test that _serialize_full_content handles None cases properly."""
# Test with no file_id
draft_var = WorkflowDraftVariable()
draft_var.file_id = None
result = _serialize_full_content(draft_var)
assert result is None
def test_serialize_full_content_should_raises_when_file_id_exists_but_file_is_none(self):
# Test with no file_id
draft_var = WorkflowDraftVariable()
draft_var.file_id = str(uuid.uuid4())
draft_var.variable_file = None
with pytest.raises(AssertionError):
result = _serialize_full_content(draft_var)
def test_conversation_variable(self):
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
)
conv_var.id = str(uuid.uuid4())
conv_var.visible = True
expected_without_value: OrderedDict[str, Any] = OrderedDict(
{
"id": str(conv_var.id),
"type": conv_var.get_variable_type().value,
"name": "conv_var",
"description": "",
"selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
"value_type": "number",
"edited": False,
"visible": True,
"is_truncated": False,
}
)
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = 1
expected_with_value["full_content"] = None
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
def test_create_sys_variable(self):
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=_TEST_APP_ID,
name="sys_var",
value=build_segment("a"),
editable=True,
node_execution_id=_TEST_NODE_EXEC_ID,
)
sys_var.id = str(uuid.uuid4())
sys_var.last_edited_at = naive_utc_now()
sys_var.visible = True
expected_without_value = OrderedDict(
{
"id": str(sys_var.id),
"type": sys_var.get_variable_type().value,
"name": "sys_var",
"description": "",
"selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"],
"value_type": "string",
"edited": True,
"visible": True,
"is_truncated": False,
}
)
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = "a"
expected_with_value["full_content"] = None
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
def test_node_variable(self):
node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="node_var",
value=build_segment([1, "a"]),
visible=False,
node_execution_id=_TEST_NODE_EXEC_ID,
)
node_var.id = str(uuid.uuid4())
node_var.last_edited_at = naive_utc_now()
expected_without_value: OrderedDict[str, Any] = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "node_var",
"description": "",
"selector": ["test_node", "node_var"],
"value_type": "array[any]",
"edited": True,
"visible": False,
"is_truncated": False,
}
)
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = [1, "a"]
expected_with_value["full_content"] = None
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
def test_node_variable_with_file(self):
node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="node_var",
value=build_segment([1, "a"]),
visible=False,
node_execution_id=_TEST_NODE_EXEC_ID,
)
node_var.id = str(uuid.uuid4())
node_var.last_edited_at = naive_utc_now()
variable_file = WorkflowDraftVariableFile(
id=str(uuidv7()),
upload_file_id=str(uuid.uuid4()),
size=1024,
length=10,
value_type=SegmentType.ARRAY_STRING,
)
node_var.variable_file = variable_file
node_var.file_id = variable_file.id
expected_without_value: OrderedDict[str, Any] = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "node_var",
"description": "",
"selector": ["test_node", "node_var"],
"value_type": "array[any]",
"edited": True,
"visible": False,
"is_truncated": True,
}
)
with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()
expected_with_value["value"] = [1, "a"]
expected_with_value["full_content"] = {
"size_bytes": 1024,
"value_type": "array[string]",
"length": 10,
"download_url": "http://example.com/signed-url",
}
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
class TestWorkflowDraftVariableList:
def test_workflow_draft_variable_list(self):
class TestCase(NamedTuple):
name: str
var_list: WorkflowDraftVariableList
expected: dict
node_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="test_var",
value=build_segment("a"),
visible=True,
node_execution_id=_TEST_NODE_EXEC_ID,
)
node_var.id = str(uuid.uuid4())
node_var_dict = OrderedDict(
{
"id": str(node_var.id),
"type": node_var.get_variable_type().value,
"name": "test_var",
"description": "",
"selector": ["test_node", "test_var"],
"value_type": "string",
"edited": False,
"visible": True,
"is_truncated": False,
}
)
cases = [
TestCase(
name="empty variable list",
var_list=WorkflowDraftVariableList(variables=[]),
expected=OrderedDict(
{
"items": [],
"total": None,
}
),
),
TestCase(
name="empty variable list with total",
var_list=WorkflowDraftVariableList(variables=[], total=10),
expected=OrderedDict(
{
"items": [],
"total": 10,
}
),
),
TestCase(
name="non-empty variable list",
var_list=WorkflowDraftVariableList(variables=[node_var], total=None),
expected=OrderedDict(
{
"items": [node_var_dict],
"total": None,
}
),
),
TestCase(
name="non-empty variable list with total",
var_list=WorkflowDraftVariableList(variables=[node_var], total=10),
expected=OrderedDict(
{
"items": [node_var_dict],
"total": 10,
}
),
),
]
for idx, case in enumerate(cases, 1):
assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, (
f"Test case {idx} failed, {case.name=}"
)
def test_workflow_node_variables_fields():
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
)
resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "conv_var"
assert item_dict["value"] == 1
def test_workflow_file_variable_with_signed_url():
"""Test that File type variables include signed URLs in API responses."""
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_upload_file_id",
filename="test.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=12345,
)
# Create a WorkflowDraftVariable with the File
file_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="file_var",
value=build_segment(test_file),
node_execution_id=_TEST_NODE_EXEC_ID,
)
# Marshal the variable using the API fields
resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
# Verify the response structure
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "file_var"
# Verify the value is a dict (File.to_dict() result) and contains expected fields
value = item_dict["value"]
assert isinstance(value, dict)
# Verify the File fields are preserved
assert value["id"] == test_file.id
assert value["filename"] == test_file.filename
assert value["type"] == test_file.type.value
assert value["transfer_method"] == test_file.transfer_method.value
assert value["size"] == test_file.size
# Verify the URL is present (it should be a signed URL for LOCAL_FILE transfer method)
remote_url = value["remote_url"]
assert remote_url is not None
assert isinstance(remote_url, str)
# For LOCAL_FILE, the URL should contain signature parameters
assert "timestamp=" in remote_url
assert "nonce=" in remote_url
assert "sign=" in remote_url
def test_workflow_file_variable_remote_url():
"""Test that File type variables with REMOTE_URL transfer method return the remote URL."""
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
# Create a File object with REMOTE_URL transfer method
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",
extension=".jpg",
mime_type="image/jpeg",
size=12345,
)
# Create a WorkflowDraftVariable with the File
file_var = WorkflowDraftVariable.new_node_variable(
app_id=_TEST_APP_ID,
node_id="test_node",
name="file_var",
value=build_segment(test_file),
node_execution_id=_TEST_NODE_EXEC_ID,
)
# Marshal the variable using the API fields
resp = marshal(WorkflowDraftVariableList(variables=[file_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
# Verify the response structure
assert isinstance(resp, dict)
assert len(resp["items"]) == 1
item_dict = resp["items"][0]
assert item_dict["name"] == "file_var"
# Verify the value is a dict (File.to_dict() result) and contains expected fields
value = item_dict["value"]
assert isinstance(value, dict)
remote_url = value["remote_url"]
# For REMOTE_URL, the URL should be the original remote URL
assert remote_url == test_file.remote_url

View File

@@ -0,0 +1,456 @@
"""
Test suite for account activation flows.
This module tests the account activation mechanism including:
- Invitation token validation
- Account activation with user preferences
- Workspace member onboarding
- Initial login after activation
"""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.activate import ActivateApi, ActivateCheckApi
from controllers.console.error import AlreadyActivateError
from models.account import AccountStatus
class TestActivateCheckApi:
"""Test cases for checking activation token validity."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_invitation(self):
"""Create mock invitation object."""
tenant = MagicMock()
tenant.id = "workspace-123"
tenant.name = "Test Workspace"
return {
"data": {"email": "invitee@example.com"},
"tenant": tenant,
}
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
"""
Test checking valid invitation token.
Verifies that:
- Valid token returns invitation data
- Workspace information is included
- Invitee email is returned
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
# Act
with app.test_request_context(
"/activate/check?workspace_id=workspace-123&email=invitee@example.com&token=valid_token"
):
api = ActivateCheckApi()
response = api.get()
# Assert
assert response["is_valid"] is True
assert response["data"]["workspace_name"] == "Test Workspace"
assert response["data"]["workspace_id"] == "workspace-123"
assert response["data"]["email"] == "invitee@example.com"
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
def test_check_invalid_invitation_token(self, mock_get_invitation, app):
"""
Test checking invalid invitation token.
Verifies that:
- Invalid token returns is_valid as False
- No data is returned for invalid tokens
"""
# Arrange
mock_get_invitation.return_value = None
# Act
with app.test_request_context(
"/activate/check?workspace_id=workspace-123&email=test@example.com&token=invalid_token"
):
api = ActivateCheckApi()
response = api.get()
# Assert
assert response["is_valid"] is False
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
"""
Test checking token without workspace ID.
Verifies that:
- Token can be checked without workspace_id parameter
- System handles None workspace_id gracefully
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
# Act
with app.test_request_context("/activate/check?email=invitee@example.com&token=valid_token"):
api = ActivateCheckApi()
response = api.get()
# Assert
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
"""
Test checking token without email parameter.
Verifies that:
- Token can be checked without email parameter
- System handles None email gracefully
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
# Act
with app.test_request_context("/activate/check?workspace_id=workspace-123&token=valid_token"):
api = ActivateCheckApi()
response = api.get()
# Assert
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
class TestActivateApi:
"""Test cases for account activation endpoint."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_account(self):
"""Create mock account object."""
account = MagicMock()
account.id = "account-123"
account.email = "invitee@example.com"
account.status = AccountStatus.PENDING
return account
@pytest.fixture
def mock_invitation(self, mock_account):
"""Create mock invitation with account."""
tenant = MagicMock()
tenant.id = "workspace-123"
tenant.name = "Test Workspace"
return {
"data": {"email": "invitee@example.com"},
"tenant": tenant,
"account": mock_account,
}
@pytest.fixture
def mock_token_pair(self):
"""Create mock token pair object."""
token_pair = MagicMock()
token_pair.access_token = "access_token"
token_pair.refresh_token = "refresh_token"
token_pair.csrf_token = "csrf_token"
token_pair.model_dump.return_value = {
"access_token": "access_token",
"refresh_token": "refresh_token",
"csrf_token": "csrf_token",
}
return token_pair
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_successful_account_activation(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
mock_token_pair,
):
"""
Test successful account activation.
Verifies that:
- Account is activated with user preferences
- Account status is set to ACTIVE
- User is logged in after activation
- Invitation token is revoked
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
"/activate",
method="POST",
json={
"workspace_id": "workspace-123",
"email": "invitee@example.com",
"token": "valid_token",
"name": "John Doe",
"interface_language": "en-US",
"timezone": "UTC",
},
):
api = ActivateApi()
response = api.post()
# Assert
assert response["result"] == "success"
assert mock_account.name == "John Doe"
assert mock_account.interface_language == "en-US"
assert mock_account.timezone == "UTC"
assert mock_account.status == AccountStatus.ACTIVE
assert mock_account.initialized_at is not None
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
mock_db.session.commit.assert_called_once()
mock_login.assert_called_once()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
def test_activation_with_invalid_token(self, mock_get_invitation, app):
"""
Test account activation with invalid token.
Verifies that:
- AlreadyActivateError is raised for invalid tokens
- No account changes are made
"""
# Arrange
mock_get_invitation.return_value = None
# Act & Assert
with app.test_request_context(
"/activate",
method="POST",
json={
"workspace_id": "workspace-123",
"email": "invitee@example.com",
"token": "invalid_token",
"name": "John Doe",
"interface_language": "en-US",
"timezone": "UTC",
},
):
api = ActivateApi()
with pytest.raises(AlreadyActivateError):
api.post()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_sets_interface_theme(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
mock_token_pair,
):
"""
Test that activation sets default interface theme.
Verifies that:
- Interface theme is set to 'light' by default
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
"/activate",
method="POST",
json={
"workspace_id": "workspace-123",
"email": "invitee@example.com",
"token": "valid_token",
"name": "John Doe",
"interface_language": "en-US",
"timezone": "UTC",
},
):
api = ActivateApi()
api.post()
# Assert
assert mock_account.interface_theme == "light"
@pytest.mark.parametrize(
("language", "timezone"),
[
("en-US", "UTC"),
("zh-Hans", "Asia/Shanghai"),
("ja-JP", "Asia/Tokyo"),
("es-ES", "Europe/Madrid"),
],
)
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_with_different_locales(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
mock_token_pair,
language,
timezone,
):
"""
Test account activation with various language and timezone combinations.
Verifies that:
- Different languages are accepted
- Different timezones are accepted
- User preferences are properly stored
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
"/activate",
method="POST",
json={
"workspace_id": "workspace-123",
"email": "invitee@example.com",
"token": "valid_token",
"name": "Test User",
"interface_language": language,
"timezone": timezone,
},
):
api = ActivateApi()
response = api.post()
# Assert
assert response["result"] == "success"
assert mock_account.interface_language == language
assert mock_account.timezone == timezone
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_returns_token_data(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_token_pair,
):
"""
Test that activation returns authentication tokens.
Verifies that:
- Token pair is returned in response
- All token types are included (access, refresh, csrf)
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
"/activate",
method="POST",
json={
"workspace_id": "workspace-123",
"email": "invitee@example.com",
"token": "valid_token",
"name": "John Doe",
"interface_language": "en-US",
"timezone": "UTC",
},
):
api = ActivateApi()
response = api.post()
# Assert
assert "data" in response
assert response["data"]["access_token"] == "access_token"
assert response["data"]["refresh_token"] == "refresh_token"
assert response["data"]["csrf_token"] == "csrf_token"
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_without_workspace_id(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_token_pair,
):
"""
Test account activation without workspace_id.
Verifies that:
- Activation can proceed without workspace_id
- Token revocation handles None workspace_id
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
"/activate",
method="POST",
json={
"email": "invitee@example.com",
"token": "valid_token",
"name": "John Doe",
"interface_language": "en-US",
"timezone": "UTC",
},
):
api = ActivateApi()
response = api.post()
# Assert
assert response["result"] == "success"
mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token")

View File

@@ -0,0 +1,138 @@
"""Test authentication security to prevent user enumeration."""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
import services.errors.account
from controllers.console.auth.error import AuthenticationFailedError
from controllers.console.auth.login import LoginApi
class TestAuthenticationSecurity:
"""Test authentication endpoints for security against user enumeration."""
def setup_method(self):
"""Set up test fixtures."""
self.app = Flask(__name__)
self.api = Api(self.app)
self.api.add_resource(LoginApi, "/login")
self.client = self.app.test_client()
self.app.config["TESTING"] = True
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
def test_login_invalid_email_with_registration_allowed(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
):
"""Test that invalid email raises AuthenticationFailedError when account not found."""
# Arrange
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = True
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
):
login_api = LoginApi()
# Assert
with pytest.raises(AuthenticationFailedError) as exc_info:
login_api.post()
assert exc_info.value.error_code == "authentication_failed"
assert exc_info.value.description == "Invalid email or password."
mock_add_rate_limit.assert_called_once_with("nonexistent@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
def test_login_wrong_password_returns_error(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
):
"""Test that wrong password returns AuthenticationFailedError."""
# Arrange
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"}
):
login_api = LoginApi()
# Assert
with pytest.raises(AuthenticationFailedError) as exc_info:
login_api.post()
assert exc_info.value.error_code == "authentication_failed"
assert exc_info.value.description == "Invalid email or password."
mock_add_rate_limit.assert_called_once_with("existing@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
def test_login_invalid_email_with_registration_disabled(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
):
"""Test that invalid email raises AuthenticationFailedError when account not found."""
# Arrange
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_features.return_value.is_allow_register = False
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
):
login_api = LoginApi()
# Assert
with pytest.raises(AuthenticationFailedError) as exc_info:
login_api.post()
assert exc_info.value.error_code == "authentication_failed"
assert exc_info.value.description == "Invalid email or password."
mock_add_rate_limit.assert_called_once_with("nonexistent@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
"""Test that reset password returns success with token for existing accounts."""
# Mock the setup check
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
# Test with existing account
mock_get_user.return_value = MagicMock(email="existing@example.com")
mock_send_email.return_value = "token123"
with self.app.test_request_context("/reset-password", method="POST", json={"email": "existing@example.com"}):
from controllers.console.auth.login import ResetPasswordSendEmailApi
api = ResetPasswordSendEmailApi()
result = api.post()
assert result == {"result": "success", "data": "token123"}

View File

@@ -0,0 +1,546 @@
"""
Test suite for email verification authentication flows.
This module tests the email code login mechanism including:
- Email code sending with rate limiting
- Code verification and validation
- Account creation via email verification
- Workspace creation for new users
"""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError
from controllers.console.auth.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
from controllers.console.error import (
AccountInFreezeError,
AccountNotFound,
EmailSendIpLimitError,
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
from services.errors.account import AccountRegisterError
class TestEmailCodeLoginSendEmailApi:
"""Test cases for sending email verification codes."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_account(self):
"""Create mock account object."""
account = MagicMock()
account.email = "test@example.com"
account.name = "Test User"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.AccountService.send_email_code_login_email")
def test_send_email_code_existing_user(
self, mock_send_email, mock_get_user, mock_is_ip_limit, mock_db, app, mock_account
):
"""
Test sending email code to existing user.
Verifies that:
- Email code is sent to existing account
- Token is generated and returned
- IP rate limiting is checked
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "email_token_123"
# Act
with app.test_request_context(
"/email-code-login", method="POST", json={"email": "test@example.com", "language": "en-US"}
):
api = EmailCodeLoginSendEmailApi()
response = api.post()
# Assert
assert response["result"] == "success"
assert response["data"] == "email_token_123"
mock_send_email.assert_called_once_with(account=mock_account, language="en-US")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@patch("controllers.console.auth.login.AccountService.send_email_code_login_email")
def test_send_email_code_new_user_registration_allowed(
self, mock_send_email, mock_get_features, mock_get_user, mock_is_ip_limit, mock_db, app
):
"""
Test sending email code to new user when registration is allowed.
Verifies that:
- Email code is sent even for non-existent accounts
- Registration is allowed by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = True
mock_send_email.return_value = "email_token_123"
# Act
with app.test_request_context(
"/email-code-login", method="POST", json={"email": "newuser@example.com", "language": "en-US"}
):
api = EmailCodeLoginSendEmailApi()
response = api.post()
# Assert
assert response["result"] == "success"
mock_send_email.assert_called_once_with(email="newuser@example.com", language="en-US")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
def test_send_email_code_new_user_registration_disabled(
self, mock_get_features, mock_get_user, mock_is_ip_limit, mock_db, app
):
"""
Test sending email code to new user when registration is disabled.
Verifies that:
- AccountNotFound is raised for non-existent accounts
- Registration is blocked by system features
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = None
mock_get_features.return_value.is_allow_register = False
# Act & Assert
with app.test_request_context("/email-code-login", method="POST", json={"email": "newuser@example.com"}):
api = EmailCodeLoginSendEmailApi()
with pytest.raises(AccountNotFound):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
"""
Test email code sending blocked by IP rate limit.
Verifies that:
- EmailSendIpLimitError is raised when IP limit exceeded
- Prevents spam and abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
with app.test_request_context("/email-code-login", method="POST", json={"email": "test@example.com"}):
api = EmailCodeLoginSendEmailApi()
with pytest.raises(EmailSendIpLimitError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app):
"""
Test email code sending to frozen account.
Verifies that:
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.side_effect = AccountRegisterError("Account frozen")
# Act & Assert
with app.test_request_context("/email-code-login", method="POST", json={"email": "frozen@example.com"}):
api = EmailCodeLoginSendEmailApi()
with pytest.raises(AccountInFreezeError):
api.post()
@pytest.mark.parametrize(
("language_input", "expected_language"),
[
("zh-Hans", "zh-Hans"),
("en-US", "en-US"),
(None, "en-US"),
],
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.AccountService.send_email_code_login_email")
def test_send_email_code_language_handling(
self,
mock_send_email,
mock_get_user,
mock_is_ip_limit,
mock_db,
app,
mock_account,
language_input,
expected_language,
):
"""
Test email code sending with different language preferences.
Verifies that:
- Language parameter is correctly processed
- Defaults to en-US when not specified
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = False
mock_get_user.return_value = mock_account
mock_send_email.return_value = "token"
# Act
with app.test_request_context(
"/email-code-login", method="POST", json={"email": "test@example.com", "language": language_input}
):
api = EmailCodeLoginSendEmailApi()
api.post()
# Assert
call_args = mock_send_email.call_args
assert call_args.kwargs["language"] == expected_language
class TestEmailCodeLoginApi:
"""Test cases for email code verification and login."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_account(self):
"""Create mock account object."""
account = MagicMock()
account.email = "test@example.com"
account.name = "Test User"
return account
@pytest.fixture
def mock_token_pair(self):
"""Create mock token pair object."""
token_pair = MagicMock()
token_pair.access_token = "access_token"
token_pair.refresh_token = "refresh_token"
token_pair.csrf_token = "csrf_token"
return token_pair
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login")
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
def test_email_code_login_existing_user(
self,
mock_reset_rate_limit,
mock_login,
mock_get_tenants,
mock_get_user,
mock_revoke_token,
mock_get_data,
mock_db,
app,
mock_account,
mock_token_pair,
):
"""
Test successful email code login for existing user.
Verifies that:
- Email and code are validated
- Token is revoked after use
- User is logged in with token pair
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "valid_token"},
):
api = EmailCodeLoginApi()
response = api.post()
# Assert
assert response.json["result"] == "success"
mock_revoke_token.assert_called_once_with("valid_token")
mock_login.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.AccountService.create_account_and_tenant")
@patch("controllers.console.auth.login.AccountService.login")
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
def test_email_code_login_new_user_creates_account(
self,
mock_reset_rate_limit,
mock_login,
mock_create_account,
mock_get_user,
mock_revoke_token,
mock_get_data,
mock_db,
app,
mock_account,
mock_token_pair,
):
"""
Test email code login creates new account for new user.
Verifies that:
- New account is created when user doesn't exist
- Workspace is created for new user
- User is logged in after account creation
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
mock_get_user.return_value = None
mock_create_account.return_value = mock_account
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "newuser@example.com", "code": "123456", "token": "valid_token", "language": "en-US"},
):
api = EmailCodeLoginApi()
response = api.post()
# Assert
assert response.json["result"] == "success"
mock_create_account.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app):
"""
Test email code login with invalid token.
Verifies that:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "invalid_token"},
):
api = EmailCodeLoginApi()
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app):
"""
Test email code login with mismatched email.
Verifies that:
- InvalidEmailError is raised when email doesn't match token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
# Act & Assert
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "different@example.com", "code": "123456", "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(InvalidEmailError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app):
"""
Test email code login with incorrect code.
Verifies that:
- EmailCodeError is raised for wrong verification code
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
# Act & Assert
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "wrong_code", "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(EmailCodeError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
def test_email_code_login_creates_workspace_for_user_without_tenant(
self,
mock_get_features,
mock_get_tenants,
mock_get_user,
mock_revoke_token,
mock_get_data,
mock_db,
app,
mock_account,
):
"""
Test email code login creates workspace for user without tenant.
Verifies that:
- Workspace is created when user has no tenants
- User is added as owner of new workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
mock_features = MagicMock()
mock_features.is_allow_create_workspace = True
mock_features.license.workspaces.is_available.return_value = True
mock_get_features.return_value = mock_features
# Act & Assert - Should not raise WorkspacesLimitExceeded
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "token"},
):
api = EmailCodeLoginApi()
# This would complete the flow, but we're testing workspace creation logic
# In real implementation, TenantService.create_tenant would be called
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
def test_email_code_login_workspace_limit_exceeded(
self,
mock_get_features,
mock_get_tenants,
mock_get_user,
mock_revoke_token,
mock_get_data,
mock_db,
app,
mock_account,
):
"""
Test email code login fails when workspace limit exceeded.
Verifies that:
- WorkspacesLimitExceeded is raised when limit reached
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
mock_features = MagicMock()
mock_features.license.workspaces.is_available.return_value = False
mock_get_features.return_value = mock_features
# Act & Assert
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(WorkspacesLimitExceeded):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
def test_email_code_login_workspace_creation_not_allowed(
self,
mock_get_features,
mock_get_tenants,
mock_get_user,
mock_revoke_token,
mock_get_data,
mock_db,
app,
mock_account,
):
"""
Test email code login fails when workspace creation not allowed.
Verifies that:
- NotAllowedCreateWorkspace is raised when creation disabled
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_get_user.return_value = mock_account
mock_get_tenants.return_value = []
mock_features = MagicMock()
mock_features.is_allow_create_workspace = False
mock_get_features.return_value = mock_features
# Act & Assert
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(NotAllowedCreateWorkspace):
api.post()

View File

@@ -0,0 +1,433 @@
"""
Test suite for login and logout authentication flows.
This module tests the core authentication endpoints including:
- Email/password login with rate limiting
- Session management and logout
- Cookie-based token handling
- Account status validation
"""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailPasswordLoginLimitError,
InvalidEmailError,
)
from controllers.console.auth.login import LoginApi, LogoutApi
from controllers.console.error import (
AccountBannedError,
AccountInFreezeError,
WorkspacesLimitExceeded,
)
from services.errors.account import AccountLoginError, AccountPasswordError
class TestLoginApi:
"""Test cases for the LoginApi endpoint."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def api(self, app):
"""Create Flask-RESTX API instance."""
return Api(app)
@pytest.fixture
def client(self, app, api):
"""Create test client."""
api.add_resource(LoginApi, "/login")
return app.test_client()
@pytest.fixture
def mock_account(self):
"""Create mock account object."""
account = MagicMock()
account.id = "test-account-id"
account.email = "test@example.com"
account.name = "Test User"
return account
@pytest.fixture
def mock_token_pair(self):
"""Create mock token pair object."""
token_pair = MagicMock()
token_pair.access_token = "mock_access_token"
token_pair.refresh_token = "mock_refresh_token"
token_pair.csrf_token = "mock_csrf_token"
return token_pair
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login")
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
def test_successful_login_without_invitation(
self,
mock_reset_rate_limit,
mock_login,
mock_get_tenants,
mock_authenticate,
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
mock_account,
mock_token_pair,
):
"""
Test successful login flow without invitation token.
Verifies that:
- Valid credentials authenticate successfully
- Tokens are generated and set in cookies
- Rate limit is reset after successful login
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()] # Has at least one tenant
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"}
):
login_api = LoginApi()
response = login_api.post()
# Assert
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!")
mock_login.assert_called_once()
mock_reset_rate_limit.assert_called_once_with("test@example.com")
assert response.json["result"] == "success"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login")
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
def test_successful_login_with_valid_invitation(
self,
mock_reset_rate_limit,
mock_login,
mock_get_tenants,
mock_authenticate,
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
mock_account,
mock_token_pair,
):
"""
Test successful login with valid invitation token.
Verifies that:
- Invitation token is validated
- Email matches invitation email
- Authentication proceeds with invitation token
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
mock_authenticate.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
"/login",
method="POST",
json={"email": "test@example.com", "password": "ValidPass123!", "invite_token": "valid_token"},
):
login_api = LoginApi()
response = login_api.post()
# Assert
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token")
assert response.json["result"] == "success"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
"""
Test login rejection when rate limit is exceeded.
Verifies that:
- Rate limit check is performed before authentication
- EmailPasswordLoginLimitError is raised when limit exceeded
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
mock_get_invitation.return_value = None
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "password"}
):
login_api = LoginApi()
with pytest.raises(EmailPasswordLoginLimitError):
login_api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True)
@patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app):
"""
Test login rejection for frozen accounts.
Verifies that:
- Billing freeze status is checked when billing enabled
- AccountInFreezeError is raised for frozen accounts
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_frozen.return_value = True
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "frozen@example.com", "password": "password"}
):
login_api = LoginApi()
with pytest.raises(AccountInFreezeError):
login_api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
def test_login_fails_with_invalid_credentials(
self,
mock_add_rate_limit,
mock_authenticate,
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
):
"""
Test login failure with invalid credentials.
Verifies that:
- AuthenticationFailedError is raised for wrong password
- Login error rate limit counter is incremented
- Generic error message prevents user enumeration
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "WrongPass123!"}
):
login_api = LoginApi()
with pytest.raises(AuthenticationFailedError):
login_api.post()
mock_add_rate_limit.assert_called_once_with("test@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.AccountService.authenticate")
def test_login_fails_for_banned_account(
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app
):
"""
Test login rejection for banned accounts.
Verifies that:
- AccountBannedError is raised for banned accounts
- Login is prevented even with valid credentials
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.side_effect = AccountLoginError("Account is banned")
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "banned@example.com", "password": "ValidPass123!"}
):
login_api = LoginApi()
with pytest.raises(AccountBannedError):
login_api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
def test_login_fails_when_no_workspace_and_limit_exceeded(
self,
mock_get_features,
mock_get_tenants,
mock_authenticate,
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
mock_account,
):
"""
Test login failure when user has no workspace and workspace limit exceeded.
Verifies that:
- WorkspacesLimitExceeded is raised when limit reached
- User cannot login without an assigned workspace
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = None
mock_authenticate.return_value = mock_account
mock_get_tenants.return_value = [] # No tenants
mock_features = MagicMock()
mock_features.is_allow_create_workspace = True
mock_features.license.workspaces.is_available.return_value = False
mock_get_features.return_value = mock_features
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"}
):
login_api = LoginApi()
with pytest.raises(WorkspacesLimitExceeded):
login_api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
"""
Test login failure when invitation email doesn't match login email.
Verifies that:
- InvalidEmailError is raised for email mismatch
- Security check prevents invitation token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}
# Act & Assert
with app.test_request_context(
"/login",
method="POST",
json={"email": "different@example.com", "password": "ValidPass123!", "invite_token": "token"},
):
login_api = LoginApi()
with pytest.raises(InvalidEmailError):
login_api.post()
class TestLogoutApi:
"""Test cases for the LogoutApi endpoint."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_account(self):
"""Create mock account object."""
account = MagicMock()
account.id = "test-account-id"
account.email = "test@example.com"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.current_account_with_tenant")
@patch("controllers.console.auth.login.AccountService.logout")
@patch("controllers.console.auth.login.flask_login.logout_user")
def test_successful_logout(
self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app, mock_account
):
"""
Test successful logout flow.
Verifies that:
- User session is terminated
- AccountService.logout is called
- All authentication cookies are cleared
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_current_account.return_value = (mock_account, MagicMock())
# Act
with app.test_request_context("/logout", method="POST"):
logout_api = LogoutApi()
response = logout_api.post()
# Assert
mock_service_logout.assert_called_once_with(account=mock_account)
mock_logout_user.assert_called_once()
assert response.json["result"] == "success"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.current_account_with_tenant")
@patch("controllers.console.auth.login.flask_login")
def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app):
"""
Test logout for anonymous (not logged in) user.
Verifies that:
- Anonymous users can call logout endpoint
- No errors are raised
- Success response is returned
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
# Create a mock anonymous user that will pass isinstance check
anonymous_user = MagicMock()
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})
anonymous_user.__class__ = mock_flask_login.AnonymousUserMixin
mock_current_account.return_value = (anonymous_user, None)
# Act
with app.test_request_context("/logout", method="POST"):
logout_api = LogoutApi()
response = logout_api.post()
# Assert
assert response.json["result"] == "success"

View File

@@ -0,0 +1,500 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.oauth import (
OAuthCallback,
OAuthLogin,
_generate_account,
_get_account_by_openid_or_email,
get_oauth_providers,
)
from libs.oauth import OAuthUserInfo
from models.account import AccountStatus
from services.errors.account import AccountRegisterError
class TestGetOAuthProviders:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.mark.parametrize(
("github_config", "google_config", "expected_github", "expected_google"),
[
# Both providers configured
(
{"id": "github_id", "secret": "github_secret"},
{"id": "google_id", "secret": "google_secret"},
True,
True,
),
# Only GitHub configured
({"id": "github_id", "secret": "github_secret"}, {"id": None, "secret": None}, True, False),
# Only Google configured
({"id": None, "secret": None}, {"id": "google_id", "secret": "google_secret"}, False, True),
# No providers configured
({"id": None, "secret": None}, {"id": None, "secret": None}, False, False),
],
)
@patch("controllers.console.auth.oauth.dify_config")
def test_should_configure_oauth_providers_correctly(
self, mock_config, app, github_config, google_config, expected_github, expected_google
):
mock_config.GITHUB_CLIENT_ID = github_config["id"]
mock_config.GITHUB_CLIENT_SECRET = github_config["secret"]
mock_config.GOOGLE_CLIENT_ID = google_config["id"]
mock_config.GOOGLE_CLIENT_SECRET = google_config["secret"]
mock_config.CONSOLE_API_URL = "http://localhost"
with app.app_context():
providers = get_oauth_providers()
assert (providers["github"] is not None) == expected_github
assert (providers["google"] is not None) == expected_google
class TestOAuthLogin:
@pytest.fixture
def resource(self):
return OAuthLogin()
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_oauth_provider(self):
provider = MagicMock()
provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?..."
return provider
@pytest.mark.parametrize(
("invite_token", "expected_token"),
[
(None, None),
("test_invite_token", "test_invite_token"),
("", None),
],
)
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth.redirect")
def test_should_handle_oauth_login_with_various_tokens(
self,
mock_redirect,
mock_get_providers,
resource,
app,
mock_oauth_provider,
invite_token,
expected_token,
):
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
query_string = f"invite_token={invite_token}" if invite_token else ""
with app.test_request_context(f"/auth/oauth/github?{query_string}"):
resource.get("github")
mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
@pytest.mark.parametrize(
("provider", "expected_error"),
[
("invalid_provider", "Invalid provider"),
("github", "Invalid provider"), # When GitHub is not configured
("google", "Invalid provider"), # When Google is not configured
],
)
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_should_return_error_for_invalid_providers(
self, mock_get_providers, resource, app, provider, expected_error
):
mock_get_providers.return_value = {"github": None, "google": None}
with app.test_request_context(f"/auth/oauth/{provider}"):
response, status_code = resource.get(provider)
assert status_code == 400
assert response["error"] == expected_error
class TestOAuthCallback:
@pytest.fixture
def resource(self):
return OAuthCallback()
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def oauth_setup(self):
"""Common OAuth setup for callback tests"""
oauth_provider = MagicMock()
oauth_provider.get_access_token.return_value = "access_token"
oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
account = MagicMock()
account.status = AccountStatus.ACTIVE
token_pair = MagicMock()
token_pair.access_token = "jwt_access_token"
token_pair.refresh_token = "jwt_refresh_token"
return {"provider": oauth_provider, "account": account, "token_pair": token_pair}
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.redirect")
def test_should_handle_successful_oauth_callback(
self,
mock_redirect,
mock_tenant_service,
mock_account_service,
mock_generate_account,
mock_get_providers,
mock_config,
resource,
app,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_generate_account.return_value = oauth_setup["account"]
mock_account_service.login.return_value = oauth_setup["token_pair"]
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
mock_redirect.assert_called_once_with("http://localhost:3000")
@pytest.mark.parametrize(
("exception", "expected_error"),
[
(Exception("OAuth error"), "OAuth process failed"),
(ValueError("Invalid token"), "OAuth process failed"),
(KeyError("Missing key"), "OAuth process failed"),
],
)
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_should_handle_oauth_exceptions(
self, mock_get_providers, mock_db, resource, app, exception, expected_error
):
# Mock database session
mock_db.session = MagicMock()
mock_db.session.rollback = MagicMock()
# Import the real requests module to create a proper exception
import httpx
request_exception = httpx.RequestError("OAuth error")
request_exception.response = MagicMock()
request_exception.response.text = str(exception)
mock_oauth_provider = MagicMock()
mock_oauth_provider.get_access_token.side_effect = request_exception
mock_get_providers.return_value = {"github": mock_oauth_provider}
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
response, status_code = resource.get("github")
assert status_code == 400
assert response["error"] == expected_error
@pytest.mark.parametrize(
("account_status", "expected_redirect"),
[
(AccountStatus.BANNED, "http://localhost:3000/signin?message=Account is banned."),
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
(
AccountStatus.CLOSED.value,
"http://localhost:3000",
),
],
)
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.redirect")
def test_should_redirect_based_on_account_status(
self,
mock_redirect,
mock_generate_account,
mock_get_providers,
mock_config,
mock_db,
mock_tenant_service,
mock_account_service,
resource,
app,
oauth_setup,
account_status,
expected_redirect,
):
# Mock database session
mock_db.session = MagicMock()
mock_db.session.rollback = MagicMock()
mock_db.session.commit = MagicMock()
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
account = MagicMock()
account.status = account_status
account.id = "123"
mock_generate_account.return_value = account
# Mock login for CLOSED status
mock_token_pair = MagicMock()
mock_token_pair.access_token = "jwt_access_token"
mock_token_pair.refresh_token = "jwt_refresh_token"
mock_token_pair.csrf_token = "csrf_token"
mock_account_service.login.return_value = mock_token_pair
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
mock_redirect.assert_called_once_with(expected_redirect)
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.AccountService")
def test_should_activate_pending_account(
self,
mock_account_service,
mock_tenant_service,
mock_db,
mock_generate_account,
mock_get_providers,
mock_config,
resource,
app,
oauth_setup,
):
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_account = MagicMock()
mock_account.status = AccountStatus.PENDING
mock_generate_account.return_value = mock_account
mock_token_pair = MagicMock()
mock_token_pair.access_token = "jwt_access_token"
mock_token_pair.refresh_token = "jwt_refresh_token"
mock_token_pair.csrf_token = "csrf_token"
mock_account_service.login.return_value = mock_token_pair
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
assert mock_account.status == AccountStatus.ACTIVE
assert mock_account.initialized_at is not None
mock_db.session.commit.assert_called_once()
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.redirect")
def test_defensive_check_for_closed_account_status(
self,
mock_redirect,
mock_account_service,
mock_tenant_service,
mock_db,
mock_generate_account,
mock_get_providers,
mock_config,
resource,
app,
oauth_setup,
):
"""Defensive test for CLOSED account status handling in OAuth callback.
This is a defensive test documenting expected security behavior for CLOSED accounts.
Current behavior: CLOSED status is NOT checked, allowing closed accounts to login.
Expected behavior: CLOSED accounts should be rejected like BANNED accounts.
Context:
- AccountStatus.CLOSED is defined in the enum but never used in production
- The close_account() method exists but is never called
- Account deletion uses external service instead of status change
- All authentication services (OAuth, password, email) don't check CLOSED status
TODO: If CLOSED status is implemented in the future:
1. Update OAuth callback to check for CLOSED status
2. Add similar checks to all authentication services for consistency
3. Update this test to verify the rejection behavior
Security consideration: Until properly implemented, CLOSED status provides no protection.
"""
# Setup
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
# Create account with CLOSED status
closed_account = MagicMock()
closed_account.status = AccountStatus.CLOSED
closed_account.id = "123"
closed_account.name = "Closed Account"
mock_generate_account.return_value = closed_account
# Mock successful login (current behavior)
mock_token_pair = MagicMock()
mock_token_pair.access_token = "jwt_access_token"
mock_token_pair.refresh_token = "jwt_refresh_token"
mock_token_pair.csrf_token = "csrf_token"
mock_account_service.login.return_value = mock_token_pair
# Execute OAuth callback
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
# Verify current behavior: login succeeds (this is NOT ideal)
mock_redirect.assert_called_once_with("http://localhost:3000")
mock_account_service.login.assert_called_once()
# Document expected behavior in comments:
# Expected: mock_redirect.assert_called_once_with(
# "http://localhost:3000/signin?message=Account is closed."
# )
# Expected: mock_account_service.login.assert_not_called()
class TestAccountGeneration:
@pytest.fixture
def user_info(self):
return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
@pytest.fixture
def mock_account(self):
account = MagicMock()
account.name = "Test User"
return account
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.Session")
@patch("controllers.console.auth.oauth.select")
def test_should_get_account_by_openid_or_email(
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
):
# Mock db.engine for Session creation
mock_db.engine = MagicMock()
# Test OpenID found
mock_account_model.get_by_openid.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
# Test fallback to email
mock_account_model.get_by_openid.return_value = None
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
@pytest.mark.parametrize(
("allow_register", "existing_account", "should_create"),
[
(True, None, True), # New account creation allowed
(True, "existing", False), # Existing account
(False, None, False), # Registration not allowed
],
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_handle_account_generation_scenarios(
self,
mock_db,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
app,
user_info,
mock_account,
allow_register,
existing_account,
should_create,
):
mock_get_account.return_value = mock_account if existing_account else None
mock_feature_service.get_system_features.return_value.is_allow_register = allow_register
mock_register_service.register.return_value = mock_account
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
if not allow_register and not existing_account:
with pytest.raises(AccountRegisterError):
_generate_account("github", user_info)
else:
result = _generate_account("github", user_info)
assert result == mock_account
if should_create:
mock_register_service.register.assert_called_once_with(
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.tenant_was_created")
def test_should_create_workspace_for_account_without_tenant(
self,
mock_event,
mock_account_service,
mock_feature_service,
mock_tenant_service,
mock_get_account,
app,
user_info,
mock_account,
):
mock_get_account.return_value = mock_account
mock_tenant_service.get_join_tenants.return_value = []
mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
mock_new_tenant = MagicMock()
mock_tenant_service.create_tenant.return_value = mock_new_tenant
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
result = _generate_account("github", user_info)
assert result == mock_account
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
mock_tenant_service.create_tenant_member.assert_called_once_with(
mock_new_tenant, mock_account, role="owner"
)
mock_event.send.assert_called_once_with(mock_new_tenant)

View File

@@ -0,0 +1,508 @@
"""
Test suite for password reset authentication flows.
This module tests the password reset mechanism including:
- Password reset email sending
- Verification code validation
- Password reset with token
- Rate limiting and security checks
"""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.error import (
EmailCodeError,
EmailPasswordResetLimitError,
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.auth.forgot_password import (
ForgotPasswordCheckApi,
ForgotPasswordResetApi,
ForgotPasswordSendEmailApi,
)
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
class TestForgotPasswordSendEmailApi:
"""Test cases for sending password reset emails."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_account(self):
"""Create mock account object."""
account = MagicMock()
account.email = "test@example.com"
account.name = "Test User"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.select")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
def test_send_reset_email_success(
self,
mock_get_features,
mock_send_email,
mock_select,
mock_session,
mock_is_ip_limit,
mock_forgot_db,
mock_wraps_db,
app,
mock_account,
):
"""
Test successful password reset email sending.
Verifies that:
- Email is sent to valid account
- Reset token is generated and returned
- IP rate limiting is checked
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_is_ip_limit.return_value = False
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_send_email.return_value = "reset_token_123"
mock_get_features.return_value.is_allow_register = True
# Act
with app.test_request_context(
"/forgot-password", method="POST", json={"email": "test@example.com", "language": "en-US"}
):
api = ForgotPasswordSendEmailApi()
response = api.post()
# Assert
assert response["result"] == "success"
assert response["data"] == "reset_token_123"
mock_send_email.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
"""
Test password reset email blocked by IP rate limit.
Verifies that:
- EmailSendIpLimitError is raised when IP limit exceeded
- No email is sent when rate limited
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_ip_limit.return_value = True
# Act & Assert
with app.test_request_context("/forgot-password", method="POST", json={"email": "test@example.com"}):
api = ForgotPasswordSendEmailApi()
with pytest.raises(EmailSendIpLimitError):
api.post()
@pytest.mark.parametrize(
("language_input", "expected_language"),
[
("zh-Hans", "zh-Hans"),
("en-US", "en-US"),
("fr-FR", "en-US"), # Defaults to en-US for unsupported
(None, "en-US"), # Defaults to en-US when not provided
],
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.select")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
def test_send_reset_email_language_handling(
self,
mock_get_features,
mock_send_email,
mock_select,
mock_session,
mock_is_ip_limit,
mock_forgot_db,
mock_wraps_db,
app,
mock_account,
language_input,
expected_language,
):
"""
Test password reset email with different language preferences.
Verifies that:
- Language parameter is correctly processed
- Unsupported languages default to en-US
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_is_ip_limit.return_value = False
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_send_email.return_value = "token"
mock_get_features.return_value.is_allow_register = True
# Act
with app.test_request_context(
"/forgot-password", method="POST", json={"email": "test@example.com", "language": language_input}
):
api = ForgotPasswordSendEmailApi()
api.post()
# Assert
call_args = mock_send_email.call_args
assert call_args.kwargs["language"] == expected_language
class TestForgotPasswordCheckApi:
"""Test cases for verifying password reset codes."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
def test_verify_code_success(
self,
mock_reset_rate_limit,
mock_generate_token,
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
mock_db,
app,
):
"""
Test successful verification code validation.
Verifies that:
- Valid code is accepted
- Old token is revoked
- New token is generated for reset phase
- Rate limit is reset on success
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
mock_generate_token.return_value = (None, "new_token")
# Act
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "old_token"},
):
api = ForgotPasswordCheckApi()
response = api.post()
# Assert
assert response["is_valid"] is True
assert response["email"] == "test@example.com"
assert response["token"] == "new_token"
mock_revoke_token.assert_called_once_with("old_token")
mock_reset_rate_limit.assert_called_once_with("test@example.com")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
"""
Test code verification blocked by rate limit.
Verifies that:
- EmailPasswordResetLimitError is raised when limit exceeded
- Prevents brute force attacks on verification codes
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = True
# Act & Assert
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "token"},
):
api = ForgotPasswordCheckApi()
with pytest.raises(EmailPasswordResetLimitError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app):
"""
Test code verification with invalid token.
Verifies that:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = None
# Act & Assert
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "invalid_token"},
):
api = ForgotPasswordCheckApi()
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app):
"""
Test code verification with mismatched email.
Verifies that:
- InvalidEmailError is raised when email doesn't match token
- Prevents token abuse
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
# Act & Assert
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "different@example.com", "code": "123456", "token": "token"},
):
api = ForgotPasswordCheckApi()
with pytest.raises(InvalidEmailError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app):
"""
Test code verification with incorrect code.
Verifies that:
- EmailCodeError is raised for wrong code
- Rate limit counter is incremented
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
# Act & Assert
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "test@example.com", "code": "wrong_code", "token": "token"},
):
api = ForgotPasswordCheckApi()
with pytest.raises(EmailCodeError):
api.post()
mock_add_rate_limit.assert_called_once_with("test@example.com")
class TestForgotPasswordResetApi:
"""Test cases for resetting password with verified token."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_account(self):
"""Create mock account object."""
account = MagicMock()
account.email = "test@example.com"
account.name = "Test User"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.select")
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
def test_reset_password_success(
self,
mock_get_tenants,
mock_select,
mock_session,
mock_revoke_token,
mock_get_data,
mock_forgot_db,
mock_wraps_db,
app,
mock_account,
):
"""
Test successful password reset.
Verifies that:
- Password is updated with new hashed value
- Token is revoked after use
- Success response is returned
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_tenants.return_value = [MagicMock()]
# Act
with app.test_request_context(
"/forgot-password/resets",
method="POST",
json={"token": "valid_token", "new_password": "NewPass123!", "password_confirm": "NewPass123!"},
):
api = ForgotPasswordResetApi()
response = api.post()
# Assert
assert response["result"] == "success"
mock_revoke_token.assert_called_once_with("valid_token")
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_mismatch(self, mock_get_data, mock_db, app):
"""
Test password reset with mismatched passwords.
Verifies that:
- PasswordMismatchError is raised when passwords don't match
- No password update occurs
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
# Act & Assert
with app.test_request_context(
"/forgot-password/resets",
method="POST",
json={"token": "token", "new_password": "NewPass123!", "password_confirm": "DifferentPass123!"},
):
api = ForgotPasswordResetApi()
with pytest.raises(PasswordMismatchError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_invalid_token(self, mock_get_data, mock_db, app):
"""
Test password reset with invalid token.
Verifies that:
- InvalidTokenError is raised for invalid/expired tokens
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = None
# Act & Assert
with app.test_request_context(
"/forgot-password/resets",
method="POST",
json={"token": "invalid_token", "new_password": "NewPass123!", "password_confirm": "NewPass123!"},
):
api = ForgotPasswordResetApi()
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app):
"""
Test password reset with token not in reset phase.
Verifies that:
- InvalidTokenError is raised when token is not in reset phase
- Prevents use of verification-phase tokens for reset
"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock()
mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"}
# Act & Assert
with app.test_request_context(
"/forgot-password/resets",
method="POST",
json={"token": "token", "new_password": "NewPass123!", "password_confirm": "NewPass123!"},
):
api = ForgotPasswordResetApi()
with pytest.raises(InvalidTokenError):
api.post()
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password.select")
def test_reset_password_account_not_found(
self, mock_select, mock_session, mock_revoke_token, mock_get_data, mock_forgot_db, mock_wraps_db, app
):
"""
Test password reset for non-existent account.
Verifies that:
- AccountNotFound is raised when account doesn't exist
"""
# Arrange
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
mock_forgot_db.engine = MagicMock()
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
mock_session.return_value.__enter__.return_value = mock_session_instance
# Act & Assert
with app.test_request_context(
"/forgot-password/resets",
method="POST",
json={"token": "token", "new_password": "NewPass123!", "password_confirm": "NewPass123!"},
):
api = ForgotPasswordResetApi()
with pytest.raises(AccountNotFound):
api.post()

View File

@@ -0,0 +1,198 @@
"""
Test suite for token refresh authentication flows.
This module tests the token refresh mechanism including:
- Access token refresh using refresh token
- Cookie-based token extraction and renewal
- Token expiration and validation
- Error handling for invalid tokens
"""
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from controllers.console.auth.login import RefreshTokenApi
class TestRefreshTokenApi:
"""Test cases for the RefreshTokenApi endpoint."""
@pytest.fixture
def app(self):
"""Create Flask test application."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def api(self, app):
"""Create Flask-RESTX API instance."""
return Api(app)
@pytest.fixture
def client(self, app, api):
"""Create test client."""
api.add_resource(RefreshTokenApi, "/refresh-token")
return app.test_client()
@pytest.fixture
def mock_token_pair(self):
"""Create mock token pair object."""
token_pair = MagicMock()
token_pair.access_token = "new_access_token"
token_pair.refresh_token = "new_refresh_token"
token_pair.csrf_token = "new_csrf_token"
return token_pair
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.AccountService.refresh_token")
def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
"""
Test successful token refresh flow.
Verifies that:
- Refresh token is extracted from cookies
- New token pair is generated
- New tokens are set in response cookies
- Success response is returned
"""
# Arrange
mock_extract_token.return_value = "valid_refresh_token"
mock_refresh_token.return_value = mock_token_pair
# Act
with app.test_request_context("/refresh-token", method="POST"):
refresh_api = RefreshTokenApi()
response = refresh_api.post()
# Assert
mock_extract_token.assert_called_once()
mock_refresh_token.assert_called_once_with("valid_refresh_token")
assert response.json["result"] == "success"
@patch("controllers.console.auth.login.extract_refresh_token")
def test_refresh_fails_without_token(self, mock_extract_token, app):
"""
Test token refresh failure when no refresh token provided.
Verifies that:
- Error is returned when refresh token is missing
- 401 status code is returned
- Appropriate error message is provided
"""
# Arrange
mock_extract_token.return_value = None
# Act
with app.test_request_context("/refresh-token", method="POST"):
refresh_api = RefreshTokenApi()
response, status_code = refresh_api.post()
# Assert
assert status_code == 401
assert response["result"] == "fail"
assert "No refresh token provided" in response["message"]
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.AccountService.refresh_token")
def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app):
"""
Test token refresh failure with invalid refresh token.
Verifies that:
- Exception is caught when token is invalid
- 401 status code is returned
- Error message is included in response
"""
# Arrange
mock_extract_token.return_value = "invalid_refresh_token"
mock_refresh_token.side_effect = Exception("Invalid refresh token")
# Act
with app.test_request_context("/refresh-token", method="POST"):
refresh_api = RefreshTokenApi()
response, status_code = refresh_api.post()
# Assert
assert status_code == 401
assert response["result"] == "fail"
assert "Invalid refresh token" in response["message"]
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.AccountService.refresh_token")
def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app):
"""
Test token refresh failure with expired refresh token.
Verifies that:
- Expired tokens are rejected
- 401 status code is returned
- Appropriate error handling
"""
# Arrange
mock_extract_token.return_value = "expired_refresh_token"
mock_refresh_token.side_effect = Exception("Refresh token expired")
# Act
with app.test_request_context("/refresh-token", method="POST"):
refresh_api = RefreshTokenApi()
response, status_code = refresh_api.post()
# Assert
assert status_code == 401
assert response["result"] == "fail"
assert "expired" in response["message"].lower()
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.AccountService.refresh_token")
def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app):
"""
Test token refresh with empty string token.
Verifies that:
- Empty string is treated as no token
- 401 status code is returned
"""
# Arrange
mock_extract_token.return_value = ""
# Act
with app.test_request_context("/refresh-token", method="POST"):
refresh_api = RefreshTokenApi()
response, status_code = refresh_api.post()
# Assert
assert status_code == 401
assert response["result"] == "fail"
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.AccountService.refresh_token")
def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
"""
Test that token refresh updates all three tokens.
Verifies that:
- Access token is updated
- Refresh token is rotated
- CSRF token is regenerated
"""
# Arrange
mock_extract_token.return_value = "valid_refresh_token"
mock_refresh_token.return_value = mock_token_pair
# Act
with app.test_request_context("/refresh-token", method="POST"):
refresh_api = RefreshTokenApi()
response = refresh_api.post()
# Assert
assert response.json["result"] == "success"
# Verify new token pair was generated
mock_refresh_token.assert_called_once_with("valid_refresh_token")
# In real implementation, cookies would be set with new values
assert mock_token_pair.access_token == "new_access_token"
assert mock_token_pair.refresh_token == "new_refresh_token"
assert mock_token_pair.csrf_token == "new_csrf_token"

View File

@@ -0,0 +1,253 @@
import base64
import json
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest
from controllers.console.billing.billing import PartnerTenants
from models.account import Account
class TestPartnerTenants:
"""Unit tests for PartnerTenants controller."""
@pytest.fixture
def app(self):
"""Create Flask app for testing."""
app = Flask(__name__)
app.config["TESTING"] = True
app.config["SECRET_KEY"] = "test-secret-key"
return app
@pytest.fixture
def mock_account(self):
"""Create a mock account."""
account = MagicMock(spec=Account)
account.id = "account-123"
account.email = "test@example.com"
account.current_tenant_id = "tenant-456"
account.is_authenticated = True
return account
@pytest.fixture
def mock_billing_service(self):
"""Mock BillingService."""
with patch("controllers.console.billing.billing.BillingService") as mock_service:
yield mock_service
@pytest.fixture
def mock_decorators(self):
"""Mock decorators to avoid database access."""
with (
patch("controllers.console.wraps.db") as mock_db,
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
patch("libs.login.dify_config.LOGIN_DISABLED", False),
patch("libs.login.check_csrf_token") as mock_csrf,
):
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
mock_csrf.return_value = None
yield {"db": mock_db, "csrf": mock_csrf}
def test_put_success(self, app, mock_account, mock_billing_service, mock_decorators):
"""Test successful partner tenants bindings sync."""
# Arrange
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
click_id = "click-id-789"
expected_response = {"result": "success", "data": {"synced": True}}
mock_billing_service.sync_partner_tenants_bindings.return_value = expected_response
with app.test_request_context(
method="PUT",
json={"click_id": click_id},
path=f"/billing/partners/{partner_key_encoded}/tenants",
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
):
resource = PartnerTenants()
result = resource.put(partner_key_encoded)
# Assert
assert result == expected_response
mock_billing_service.sync_partner_tenants_bindings.assert_called_once_with(
mock_account.id, "partner-key-123", click_id
)
def test_put_invalid_partner_key_base64(self, app, mock_account, mock_billing_service, mock_decorators):
"""Test that invalid base64 partner_key raises BadRequest."""
# Arrange
invalid_partner_key = "invalid-base64-!@#$"
click_id = "click-id-789"
with app.test_request_context(
method="PUT",
json={"click_id": click_id},
path=f"/billing/partners/{invalid_partner_key}/tenants",
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
):
resource = PartnerTenants()
# Act & Assert
with pytest.raises(BadRequest) as exc_info:
resource.put(invalid_partner_key)
assert "Invalid partner_key" in str(exc_info.value)
def test_put_missing_click_id(self, app, mock_account, mock_billing_service, mock_decorators):
"""Test that missing click_id raises BadRequest."""
# Arrange
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
with app.test_request_context(
method="PUT",
json={},
path=f"/billing/partners/{partner_key_encoded}/tenants",
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
):
resource = PartnerTenants()
# Act & Assert
# reqparse will raise BadRequest for missing required field
with pytest.raises(BadRequest):
resource.put(partner_key_encoded)
def test_put_billing_service_json_decode_error(self, app, mock_account, mock_billing_service, mock_decorators):
"""Test handling of billing service JSON decode error.
When billing service returns non-200 status code with invalid JSON response,
response.json() raises JSONDecodeError. This exception propagates to the controller
and should be handled by the global error handler (handle_general_exception),
which returns a 500 status code with error details.
Note: In unit tests, when directly calling resource.put(), the exception is raised
directly. In actual Flask application, the error handler would catch it and return
a 500 response with JSON: {"code": "unknown", "message": "...", "status": 500}
"""
# Arrange
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
click_id = "click-id-789"
# Simulate JSON decode error when billing service returns invalid JSON
# This happens when billing service returns non-200 with empty/invalid response body
json_decode_error = json.JSONDecodeError("Expecting value", "", 0)
mock_billing_service.sync_partner_tenants_bindings.side_effect = json_decode_error
with app.test_request_context(
method="PUT",
json={"click_id": click_id},
path=f"/billing/partners/{partner_key_encoded}/tenants",
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
):
resource = PartnerTenants()
# Act & Assert
# JSONDecodeError will be raised from the controller
# In actual Flask app, this would be caught by handle_general_exception
# which returns: {"code": "unknown", "message": str(e), "status": 500}
with pytest.raises(json.JSONDecodeError) as exc_info:
resource.put(partner_key_encoded)
# Verify the exception is JSONDecodeError
assert isinstance(exc_info.value, json.JSONDecodeError)
assert "Expecting value" in str(exc_info.value)
def test_put_empty_click_id(self, app, mock_account, mock_billing_service, mock_decorators):
"""Test that empty click_id raises BadRequest."""
# Arrange
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
click_id = ""
with app.test_request_context(
method="PUT",
json={"click_id": click_id},
path=f"/billing/partners/{partner_key_encoded}/tenants",
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
):
resource = PartnerTenants()
# Act & Assert
with pytest.raises(BadRequest) as exc_info:
resource.put(partner_key_encoded)
assert "Invalid partner information" in str(exc_info.value)
def test_put_empty_partner_key_after_decode(self, app, mock_account, mock_billing_service, mock_decorators):
"""Test that empty partner_key after decode raises BadRequest."""
# Arrange
# Base64 encode an empty string
empty_partner_key_encoded = base64.b64encode(b"").decode("utf-8")
click_id = "click-id-789"
with app.test_request_context(
method="PUT",
json={"click_id": click_id},
path=f"/billing/partners/{empty_partner_key_encoded}/tenants",
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
):
resource = PartnerTenants()
# Act & Assert
with pytest.raises(BadRequest) as exc_info:
resource.put(empty_partner_key_encoded)
assert "Invalid partner information" in str(exc_info.value)
def test_put_empty_user_id(self, app, mock_account, mock_billing_service, mock_decorators):
"""Test that empty user id raises BadRequest."""
# Arrange
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
click_id = "click-id-789"
mock_account.id = None # Empty user id
with app.test_request_context(
method="PUT",
json={"click_id": click_id},
path=f"/billing/partners/{partner_key_encoded}/tenants",
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
):
resource = PartnerTenants()
# Act & Assert
with pytest.raises(BadRequest) as exc_info:
resource.put(partner_key_encoded)
assert "Invalid partner information" in str(exc_info.value)

View File

@@ -0,0 +1,278 @@
import io
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
class TestFileUploadSecurity:
"""Test file upload security logic without complex framework setup"""
# Test 1: Basic file validation
def test_should_validate_file_presence(self):
"""Test that missing file is detected"""
from flask import Flask, request
app = Flask(__name__)
with app.test_request_context(method="POST", data={}):
# Simulate the check in FileApi.post()
if "file" not in request.files:
with pytest.raises(NoFileUploadedError):
raise NoFileUploadedError()
def test_should_validate_multiple_files(self):
"""Test that multiple files are rejected"""
from flask import Flask, request
app = Flask(__name__)
file_data = {
"file": (io.BytesIO(b"content1"), "file1.txt", "text/plain"),
"file2": (io.BytesIO(b"content2"), "file2.txt", "text/plain"),
}
with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"):
# Simulate the check in FileApi.post()
if len(request.files) > 1:
with pytest.raises(TooManyFilesError):
raise TooManyFilesError()
def test_should_validate_empty_filename(self):
"""Test that empty filename is rejected"""
from flask import Flask, request
app = Flask(__name__)
file_data = {"file": (io.BytesIO(b"content"), "", "text/plain")}
with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"):
file = request.files["file"]
if not file.filename:
with pytest.raises(FilenameNotExistsError):
raise FilenameNotExistsError
# Test 2: Security - Filename sanitization
def test_should_detect_path_traversal_in_filename(self):
"""Test protection against directory traversal attacks"""
dangerous_filenames = [
"../../../etc/passwd",
"..\\..\\windows\\system32\\config\\sam",
"../../../../etc/shadow",
"./../../../sensitive.txt",
]
for filename in dangerous_filenames:
# Any filename containing .. should be considered dangerous
assert ".." in filename, f"Filename {filename} should be detected as path traversal"
def test_should_detect_null_byte_injection(self):
"""Test protection against null byte injection"""
dangerous_filenames = [
"file.jpg\x00.php",
"document.pdf\x00.exe",
"image.png\x00.sh",
]
for filename in dangerous_filenames:
# Null bytes should be detected
assert "\x00" in filename, f"Filename {filename} should be detected as null byte injection"
def test_should_sanitize_special_characters(self):
"""Test that special characters in filenames are handled safely"""
# Characters that could be problematic in various contexts
dangerous_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\x00"]
for char in dangerous_chars:
filename = f"file{char}name.txt"
# These characters should be detected or sanitized
assert any(c in filename for c in dangerous_chars)
# Test 3: Permission validation
def test_should_validate_dataset_permissions(self):
"""Test dataset upload permission logic"""
class MockUser:
is_dataset_editor = False
user = MockUser()
source = "datasets"
# Simulate the permission check in FileApi.post()
if source == "datasets" and not user.is_dataset_editor:
with pytest.raises(Forbidden):
raise Forbidden()
def test_should_allow_general_upload_without_permission(self):
"""Test general upload doesn't require dataset permission"""
class MockUser:
is_dataset_editor = False
user = MockUser()
source = None # General upload
# This should not raise an exception
if source == "datasets" and not user.is_dataset_editor:
raise Forbidden()
# Test passes if no exception is raised
# Test 4: Service error handling
@patch("services.file_service.FileService.upload_file")
def test_should_handle_file_too_large_error(self, mock_upload):
"""Test that service FileTooLargeError is properly converted"""
mock_upload.side_effect = ServiceFileTooLargeError("File too large")
try:
mock_upload(filename="test.txt", content=b"data", mimetype="text/plain", user=None, source=None)
except ServiceFileTooLargeError as e:
# Simulate the error conversion in FileApi.post()
with pytest.raises(FileTooLargeError):
raise FileTooLargeError(e.description)
@patch("services.file_service.FileService.upload_file")
def test_should_handle_unsupported_file_type_error(self, mock_upload):
"""Test that service UnsupportedFileTypeError is properly converted"""
mock_upload.side_effect = ServiceUnsupportedFileTypeError()
try:
mock_upload(
filename="test.exe", content=b"data", mimetype="application/octet-stream", user=None, source=None
)
except ServiceUnsupportedFileTypeError:
# Simulate the error conversion in FileApi.post()
with pytest.raises(UnsupportedFileTypeError):
raise UnsupportedFileTypeError()
# Test 5: File type security
def test_should_identify_dangerous_file_extensions(self):
"""Test detection of potentially dangerous file extensions"""
dangerous_extensions = [
".php",
".PHP",
".pHp", # PHP files (case variations)
".exe",
".EXE", # Executables
".sh",
".SH", # Shell scripts
".bat",
".BAT", # Batch files
".cmd",
".CMD", # Command files
".ps1",
".PS1", # PowerShell
".jar",
".JAR", # Java archives
".vbs",
".VBS", # VBScript
]
safe_extensions = [".txt", ".pdf", ".jpg", ".png", ".doc", ".docx"]
# Just verify our test data is correct
for ext in dangerous_extensions:
assert ext.lower() in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"]
for ext in safe_extensions:
assert ext.lower() not in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"]
def test_should_detect_double_extensions(self):
"""Test detection of double extension attacks"""
suspicious_filenames = [
"image.jpg.php",
"document.pdf.exe",
"photo.png.sh",
"file.txt.bat",
]
for filename in suspicious_filenames:
# Check that these have multiple extensions
parts = filename.split(".")
assert len(parts) > 2, f"Filename {filename} should have multiple extensions"
# Test 6: Configuration validation
def test_upload_configuration_structure(self):
"""Test that upload configuration has correct structure"""
# Simulate the configuration returned by FileApi.get()
config = {
"file_size_limit": 15,
"batch_count_limit": 5,
"image_file_size_limit": 10,
"video_file_size_limit": 500,
"audio_file_size_limit": 50,
"workflow_file_upload_limit": 10,
}
# Verify all required fields are present
required_fields = [
"file_size_limit",
"batch_count_limit",
"image_file_size_limit",
"video_file_size_limit",
"audio_file_size_limit",
"workflow_file_upload_limit",
]
for field in required_fields:
assert field in config, f"Missing required field: {field}"
assert isinstance(config[field], int), f"Field {field} should be an integer"
assert config[field] > 0, f"Field {field} should be positive"
# Test 7: Source parameter handling
def test_source_parameter_normalization(self):
"""Test that source parameter is properly normalized"""
test_cases = [
("datasets", "datasets"),
("other", None),
("", None),
(None, None),
]
for input_source, expected in test_cases:
# Simulate the source normalization in FileApi.post()
source = "datasets" if input_source == "datasets" else None
if source not in ("datasets", None):
source = None
assert source == expected
# Test 8: Boundary conditions
def test_should_handle_edge_case_file_sizes(self):
"""Test handling of boundary file sizes"""
test_cases = [
(0, "Empty file"), # 0 bytes
(1, "Single byte"), # 1 byte
(15 * 1024 * 1024 - 1, "Just under limit"), # Just under 15MB
(15 * 1024 * 1024, "At limit"), # Exactly 15MB
(15 * 1024 * 1024 + 1, "Just over limit"), # Just over 15MB
]
for size, description in test_cases:
# Just verify our test data
assert isinstance(size, int), f"{description}: Size should be integer"
assert size >= 0, f"{description}: Size should be non-negative"
def test_should_handle_special_mime_types(self):
"""Test handling of various MIME types"""
mime_type_tests = [
("application/octet-stream", "Generic binary"),
("text/plain", "Plain text"),
("image/jpeg", "JPEG image"),
("application/pdf", "PDF document"),
("", "Empty MIME type"),
(None, "None MIME type"),
]
for mime_type, description in mime_type_tests:
# Verify test data structure
if mime_type is not None:
assert isinstance(mime_type, str), f"{description}: MIME type should be string or None"

View File

@@ -0,0 +1,396 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_login import LoginManager, UserMixin
from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
from controllers.console.workspace.error import AccountNotInitializedError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
enterprise_license_required,
only_edition_cloud,
only_edition_enterprise,
only_edition_self_hosted,
setup_required,
)
from models.account import AccountStatus
from services.feature_service import LicenseStatus
class MockUser(UserMixin):
"""Simple User class for testing."""
def __init__(self, user_id: str):
self.id = user_id
self.current_tenant_id = "tenant123"
def get_id(self) -> str:
return self.id
def create_app_with_login():
"""Create a Flask app with LoginManager configured."""
app = Flask(__name__)
app.config["SECRET_KEY"] = "test-secret-key"
login_manager = LoginManager()
login_manager.init_app(app)
@login_manager.user_loader
def load_user(user_id: str):
return MockUser(user_id)
return app
class TestAccountInitialization:
"""Test account initialization decorator"""
def test_should_allow_initialized_account(self):
"""Test that initialized accounts can access protected views"""
# Arrange
mock_user = MagicMock()
mock_user.status = AccountStatus.ACTIVE
@account_initialization_required
def protected_view():
return "success"
# Act
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
result = protected_view()
# Assert
assert result == "success"
def test_should_reject_uninitialized_account(self):
"""Test that uninitialized accounts raise AccountNotInitializedError"""
# Arrange
mock_user = MagicMock()
mock_user.status = AccountStatus.UNINITIALIZED
@account_initialization_required
def protected_view():
return "success"
# Act & Assert
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
with pytest.raises(AccountNotInitializedError):
protected_view()
class TestEditionChecks:
"""Test edition-specific decorators"""
def test_only_edition_cloud_allows_cloud_edition(self):
"""Test cloud edition decorator allows CLOUD edition"""
# Arrange
@only_edition_cloud
def cloud_view():
return "cloud_success"
# Act
with patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"):
result = cloud_view()
# Assert
assert result == "cloud_success"
def test_only_edition_cloud_rejects_other_editions(self):
"""Test cloud edition decorator rejects non-CLOUD editions"""
# Arrange
app = Flask(__name__)
@only_edition_cloud
def cloud_view():
return "cloud_success"
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
with pytest.raises(Exception) as exc_info:
cloud_view()
assert exc_info.value.code == 404
def test_only_edition_enterprise_allows_when_enabled(self):
"""Test enterprise edition decorator allows when ENTERPRISE_ENABLED is True"""
# Arrange
@only_edition_enterprise
def enterprise_view():
return "enterprise_success"
# Act
with patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True):
result = enterprise_view()
# Assert
assert result == "enterprise_success"
def test_only_edition_self_hosted_allows_self_hosted(self):
"""Test self-hosted edition decorator allows SELF_HOSTED edition"""
# Arrange
@only_edition_self_hosted
def self_hosted_view():
return "self_hosted_success"
# Act
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
result = self_hosted_view()
# Assert
assert result == "self_hosted_success"
class TestBillingResourceLimits:
"""Test billing resource limit decorators"""
def test_should_allow_when_under_resource_limit(self):
"""Test that requests are allowed when under resource limits"""
# Arrange
mock_features = MagicMock()
mock_features.billing.enabled = True
mock_features.members.limit = 10
mock_features.members.size = 5
@cloud_edition_billing_resource_check("members")
def add_member():
return "member_added"
# Act
with patch(
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = add_member()
# Assert
assert result == "member_added"
def test_should_reject_when_over_resource_limit(self):
"""Test that requests are rejected when over resource limits"""
# Arrange
app = create_app_with_login()
mock_features = MagicMock()
mock_features.billing.enabled = True
mock_features.members.limit = 10
mock_features.members.size = 10
@cloud_edition_billing_resource_check("members")
def add_member():
return "member_added"
# Act & Assert
with app.test_request_context():
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
add_member()
assert exc_info.value.code == 403
assert "members has reached the limit" in str(exc_info.value.description)
def test_should_check_source_for_documents_limit(self):
"""Test document limit checks request source"""
# Arrange
app = create_app_with_login()
mock_features = MagicMock()
mock_features.billing.enabled = True
mock_features.documents_upload_quota.limit = 100
mock_features.documents_upload_quota.size = 100
@cloud_edition_billing_resource_check("documents")
def upload_document():
return "document_uploaded"
# Test 1: Should reject when source is datasets
with app.test_request_context("/?source=datasets"):
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
upload_document()
assert exc_info.value.code == 403
# Test 2: Should allow when source is not datasets
with app.test_request_context("/?source=other"):
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = upload_document()
assert result == "document_uploaded"
class TestRateLimiting:
"""Test rate limiting decorator"""
@patch("controllers.console.wraps.redis_client")
@patch("controllers.console.wraps.db")
def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis):
"""Test that requests within rate limit are allowed"""
# Arrange
mock_rate_limit = MagicMock()
mock_rate_limit.enabled = True
mock_rate_limit.limit = 10
mock_redis.zcard.return_value = 5 # 5 requests in window
@cloud_edition_billing_rate_limit_check("knowledge")
def knowledge_request():
return "knowledge_success"
# Act
with patch(
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
result = knowledge_request()
# Assert
assert result == "knowledge_success"
mock_redis.zadd.assert_called_once()
mock_redis.zremrangebyscore.assert_called_once()
@patch("controllers.console.wraps.redis_client")
@patch("controllers.console.wraps.db")
def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis):
"""Test that requests over rate limit are rejected and logged"""
# Arrange
app = create_app_with_login()
mock_rate_limit = MagicMock()
mock_rate_limit.enabled = True
mock_rate_limit.limit = 10
mock_rate_limit.subscription_plan = "pro"
mock_redis.zcard.return_value = 11 # Over limit
mock_session = MagicMock()
mock_db.session = mock_session
@cloud_edition_billing_rate_limit_check("knowledge")
def knowledge_request():
return "knowledge_success"
# Act & Assert
with app.test_request_context():
with patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MockUser("test_user"), "tenant123"),
):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
with pytest.raises(Exception) as exc_info:
knowledge_request()
# Verify error
assert exc_info.value.code == 403
assert "rate limit" in str(exc_info.value.description)
# Verify rate limit log was created
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
class TestSystemSetup:
"""Test system setup decorator"""
@patch("controllers.console.wraps.db")
def test_should_allow_when_setup_complete(self, mock_db):
"""Test that requests are allowed when setup is complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
@setup_required
def admin_view():
return "admin_success"
# Act
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
result = admin_view()
# Assert
assert result == "admin_success"
@patch("controllers.console.wraps.db")
@patch("controllers.console.wraps.os.environ.get")
def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
"""Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = None # No setup
mock_environ_get.return_value = "some_password"
@setup_required
def admin_view():
return "admin_success"
# Act & Assert
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
with pytest.raises(NotInitValidateError):
admin_view()
@patch("controllers.console.wraps.db")
@patch("controllers.console.wraps.os.environ.get")
def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
"""Test NotSetupError when no INIT_PASSWORD and setup not complete"""
# Arrange
mock_db.session.query.return_value.first.return_value = None # No setup
mock_environ_get.return_value = None # No INIT_PASSWORD
@setup_required
def admin_view():
return "admin_success"
# Act & Assert
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
with pytest.raises(NotSetupError):
admin_view()
class TestEnterpriseLicense:
"""Test enterprise license decorator"""
def test_should_allow_with_valid_license(self):
"""Test that valid licenses allow access"""
# Arrange
mock_settings = MagicMock()
mock_settings.license.status = LicenseStatus.ACTIVE
@enterprise_license_required
def enterprise_feature():
return "enterprise_success"
# Act
with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
result = enterprise_feature()
# Assert
assert result == "enterprise_success"
@pytest.mark.parametrize("invalid_status", [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST])
def test_should_reject_with_invalid_license(self, invalid_status):
"""Test that invalid licenses raise UnauthorizedAndForceLogout"""
# Arrange
mock_settings = MagicMock()
mock_settings.license.status = invalid_status
@enterprise_license_required
def enterprise_feature():
return "enterprise_success"
# Act & Assert
with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
with pytest.raises(UnauthorizedAndForceLogout) as exc_info:
enterprise_feature()
assert "license is invalid" in str(exc_info.value)