dify
This commit is contained in:
@@ -0,0 +1,456 @@
|
||||
"""
|
||||
Test suite for account activation flows.
|
||||
|
||||
This module tests the account activation mechanism including:
|
||||
- Invitation token validation
|
||||
- Account activation with user preferences
|
||||
- Workspace member onboarding
|
||||
- Initial login after activation
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.activate import ActivateApi, ActivateCheckApi
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from models.account import AccountStatus
|
||||
|
||||
|
||||
class TestActivateCheckApi:
|
||||
"""Test cases for checking activation token validity."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_invitation(self):
|
||||
"""Create mock invitation object."""
|
||||
tenant = MagicMock()
|
||||
tenant.id = "workspace-123"
|
||||
tenant.name = "Test Workspace"
|
||||
|
||||
return {
|
||||
"data": {"email": "invitee@example.com"},
|
||||
"tenant": tenant,
|
||||
}
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking valid invitation token.
|
||||
|
||||
Verifies that:
|
||||
- Valid token returns invitation data
|
||||
- Workspace information is included
|
||||
- Invitee email is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/activate/check?workspace_id=workspace-123&email=invitee@example.com&token=valid_token"
|
||||
):
|
||||
api = ActivateCheckApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["is_valid"] is True
|
||||
assert response["data"]["workspace_name"] == "Test Workspace"
|
||||
assert response["data"]["workspace_id"] == "workspace-123"
|
||||
assert response["data"]["email"] == "invitee@example.com"
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
def test_check_invalid_invitation_token(self, mock_get_invitation, app):
|
||||
"""
|
||||
Test checking invalid invitation token.
|
||||
|
||||
Verifies that:
|
||||
- Invalid token returns is_valid as False
|
||||
- No data is returned for invalid tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/activate/check?workspace_id=workspace-123&email=test@example.com&token=invalid_token"
|
||||
):
|
||||
api = ActivateCheckApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["is_valid"] is False
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking token without workspace ID.
|
||||
|
||||
Verifies that:
|
||||
- Token can be checked without workspace_id parameter
|
||||
- System handles None workspace_id gracefully
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/activate/check?email=invitee@example.com&token=valid_token"):
|
||||
api = ActivateCheckApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking token without email parameter.
|
||||
|
||||
Verifies that:
|
||||
- Token can be checked without email parameter
|
||||
- System handles None email gracefully
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/activate/check?workspace_id=workspace-123&token=valid_token"):
|
||||
api = ActivateCheckApi()
|
||||
response = api.get()
|
||||
|
||||
# Assert
|
||||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
|
||||
|
||||
|
||||
class TestActivateApi:
|
||||
"""Test cases for account activation endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create mock account object."""
|
||||
account = MagicMock()
|
||||
account.id = "account-123"
|
||||
account.email = "invitee@example.com"
|
||||
account.status = AccountStatus.PENDING
|
||||
return account
|
||||
|
||||
@pytest.fixture
|
||||
def mock_invitation(self, mock_account):
|
||||
"""Create mock invitation with account."""
|
||||
tenant = MagicMock()
|
||||
tenant.id = "workspace-123"
|
||||
tenant.name = "Test Workspace"
|
||||
|
||||
return {
|
||||
"data": {"email": "invitee@example.com"},
|
||||
"tenant": tenant,
|
||||
"account": mock_account,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_pair(self):
|
||||
"""Create mock token pair object."""
|
||||
token_pair = MagicMock()
|
||||
token_pair.access_token = "access_token"
|
||||
token_pair.refresh_token = "refresh_token"
|
||||
token_pair.csrf_token = "csrf_token"
|
||||
token_pair.model_dump.return_value = {
|
||||
"access_token": "access_token",
|
||||
"refresh_token": "refresh_token",
|
||||
"csrf_token": "csrf_token",
|
||||
}
|
||||
return token_pair
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
@patch("controllers.console.auth.activate.AccountService.login")
|
||||
def test_successful_account_activation(
|
||||
self,
|
||||
mock_login,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test successful account activation.
|
||||
|
||||
Verifies that:
|
||||
- Account is activated with user preferences
|
||||
- Account status is set to ACTIVE
|
||||
- User is logged in after activation
|
||||
- Invitation token is revoked
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/activate",
|
||||
method="POST",
|
||||
json={
|
||||
"workspace_id": "workspace-123",
|
||||
"email": "invitee@example.com",
|
||||
"token": "valid_token",
|
||||
"name": "John Doe",
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
):
|
||||
api = ActivateApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert response["result"] == "success"
|
||||
assert mock_account.name == "John Doe"
|
||||
assert mock_account.interface_language == "en-US"
|
||||
assert mock_account.timezone == "UTC"
|
||||
assert mock_account.status == AccountStatus.ACTIVE
|
||||
assert mock_account.initialized_at is not None
|
||||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_login.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
def test_activation_with_invalid_token(self, mock_get_invitation, app):
|
||||
"""
|
||||
Test account activation with invalid token.
|
||||
|
||||
Verifies that:
|
||||
- AlreadyActivateError is raised for invalid tokens
|
||||
- No account changes are made
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/activate",
|
||||
method="POST",
|
||||
json={
|
||||
"workspace_id": "workspace-123",
|
||||
"email": "invitee@example.com",
|
||||
"token": "invalid_token",
|
||||
"name": "John Doe",
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
):
|
||||
api = ActivateApi()
|
||||
with pytest.raises(AlreadyActivateError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
@patch("controllers.console.auth.activate.AccountService.login")
|
||||
def test_activation_sets_interface_theme(
|
||||
self,
|
||||
mock_login,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test that activation sets default interface theme.
|
||||
|
||||
Verifies that:
|
||||
- Interface theme is set to 'light' by default
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/activate",
|
||||
method="POST",
|
||||
json={
|
||||
"workspace_id": "workspace-123",
|
||||
"email": "invitee@example.com",
|
||||
"token": "valid_token",
|
||||
"name": "John Doe",
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
):
|
||||
api = ActivateApi()
|
||||
api.post()
|
||||
|
||||
# Assert
|
||||
assert mock_account.interface_theme == "light"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("language", "timezone"),
|
||||
[
|
||||
("en-US", "UTC"),
|
||||
("zh-Hans", "Asia/Shanghai"),
|
||||
("ja-JP", "Asia/Tokyo"),
|
||||
("es-ES", "Europe/Madrid"),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
@patch("controllers.console.auth.activate.AccountService.login")
|
||||
def test_activation_with_different_locales(
|
||||
self,
|
||||
mock_login,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
language,
|
||||
timezone,
|
||||
):
|
||||
"""
|
||||
Test account activation with various language and timezone combinations.
|
||||
|
||||
Verifies that:
|
||||
- Different languages are accepted
|
||||
- Different timezones are accepted
|
||||
- User preferences are properly stored
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/activate",
|
||||
method="POST",
|
||||
json={
|
||||
"workspace_id": "workspace-123",
|
||||
"email": "invitee@example.com",
|
||||
"token": "valid_token",
|
||||
"name": "Test User",
|
||||
"interface_language": language,
|
||||
"timezone": timezone,
|
||||
},
|
||||
):
|
||||
api = ActivateApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert response["result"] == "success"
|
||||
assert mock_account.interface_language == language
|
||||
assert mock_account.timezone == timezone
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
@patch("controllers.console.auth.activate.AccountService.login")
|
||||
def test_activation_returns_token_data(
|
||||
self,
|
||||
mock_login,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test that activation returns authentication tokens.
|
||||
|
||||
Verifies that:
|
||||
- Token pair is returned in response
|
||||
- All token types are included (access, refresh, csrf)
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/activate",
|
||||
method="POST",
|
||||
json={
|
||||
"workspace_id": "workspace-123",
|
||||
"email": "invitee@example.com",
|
||||
"token": "valid_token",
|
||||
"name": "John Doe",
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
):
|
||||
api = ActivateApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert "data" in response
|
||||
assert response["data"]["access_token"] == "access_token"
|
||||
assert response["data"]["refresh_token"] == "refresh_token"
|
||||
assert response["data"]["csrf_token"] == "csrf_token"
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
@patch("controllers.console.auth.activate.AccountService.login")
|
||||
def test_activation_without_workspace_id(
|
||||
self,
|
||||
mock_login,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test account activation without workspace_id.
|
||||
|
||||
Verifies that:
|
||||
- Activation can proceed without workspace_id
|
||||
- Token revocation handles None workspace_id
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/activate",
|
||||
method="POST",
|
||||
json={
|
||||
"email": "invitee@example.com",
|
||||
"token": "valid_token",
|
||||
"name": "John Doe",
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
):
|
||||
api = ActivateApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert response["result"] == "success"
|
||||
mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token")
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Test authentication security to prevent user enumeration."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
import services.errors.account
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
from controllers.console.auth.login import LoginApi
|
||||
|
||||
|
||||
class TestAuthenticationSecurity:
|
||||
"""Test authentication endpoints for security against user enumeration."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.app = Flask(__name__)
|
||||
self.api = Api(self.app)
|
||||
self.api.add_resource(LoginApi, "/login")
|
||||
self.client = self.app.test_client()
|
||||
self.app.config["TESTING"] = True
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_allowed(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email raises AuthenticationFailedError when account not found."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = True
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AuthenticationFailedError) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "authentication_failed"
|
||||
assert exc_info.value.description == "Invalid email or password."
|
||||
mock_add_rate_limit.assert_called_once_with("nonexistent@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_wrong_password_returns_error(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
|
||||
):
|
||||
"""Test that wrong password returns AuthenticationFailedError."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AuthenticationFailedError) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "authentication_failed"
|
||||
assert exc_info.value.description == "Invalid email or password."
|
||||
mock_add_rate_limit.assert_called_once_with("existing@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invalid_email_with_registration_disabled(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
"""Test that invalid email raises AuthenticationFailedError when account not found."""
|
||||
# Arrange
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
mock_features.return_value.is_allow_register = False
|
||||
|
||||
# Act
|
||||
with self.app.test_request_context(
|
||||
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
|
||||
# Assert
|
||||
with pytest.raises(AuthenticationFailedError) as exc_info:
|
||||
login_api.post()
|
||||
|
||||
assert exc_info.value.error_code == "authentication_failed"
|
||||
assert exc_info.value.description == "Invalid email or password."
|
||||
mock_add_rate_limit.assert_called_once_with("nonexistent@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.AccountService.send_reset_password_email")
|
||||
def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
|
||||
"""Test that reset password returns success with token for existing accounts."""
|
||||
# Mock the setup check
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists
|
||||
|
||||
# Test with existing account
|
||||
mock_get_user.return_value = MagicMock(email="existing@example.com")
|
||||
mock_send_email.return_value = "token123"
|
||||
|
||||
with self.app.test_request_context("/reset-password", method="POST", json={"email": "existing@example.com"}):
|
||||
from controllers.console.auth.login import ResetPasswordSendEmailApi
|
||||
|
||||
api = ResetPasswordSendEmailApi()
|
||||
result = api.post()
|
||||
|
||||
assert result == {"result": "success", "data": "token123"}
|
||||
@@ -0,0 +1,546 @@
|
||||
"""
|
||||
Test suite for email verification authentication flows.
|
||||
|
||||
This module tests the email code login mechanism including:
|
||||
- Email code sending with rate limiting
|
||||
- Code verification and validation
|
||||
- Account creation via email verification
|
||||
- Workspace creation for new users
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError
|
||||
from controllers.console.auth.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
|
||||
from controllers.console.error import (
|
||||
AccountInFreezeError,
|
||||
AccountNotFound,
|
||||
EmailSendIpLimitError,
|
||||
NotAllowedCreateWorkspace,
|
||||
WorkspacesLimitExceeded,
|
||||
)
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
|
||||
class TestEmailCodeLoginSendEmailApi:
|
||||
"""Test cases for sending email verification codes."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create mock account object."""
|
||||
account = MagicMock()
|
||||
account.email = "test@example.com"
|
||||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.AccountService.send_email_code_login_email")
|
||||
def test_send_email_code_existing_user(
|
||||
self, mock_send_email, mock_get_user, mock_is_ip_limit, mock_db, app, mock_account
|
||||
):
|
||||
"""
|
||||
Test sending email code to existing user.
|
||||
|
||||
Verifies that:
|
||||
- Email code is sent to existing account
|
||||
- Token is generated and returned
|
||||
- IP rate limiting is checked
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "email_token_123"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/email-code-login", method="POST", json={"email": "test@example.com", "language": "en-US"}
|
||||
):
|
||||
api = EmailCodeLoginSendEmailApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert response["result"] == "success"
|
||||
assert response["data"] == "email_token_123"
|
||||
mock_send_email.assert_called_once_with(account=mock_account, language="en-US")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
@patch("controllers.console.auth.login.AccountService.send_email_code_login_email")
|
||||
def test_send_email_code_new_user_registration_allowed(
|
||||
self, mock_send_email, mock_get_features, mock_get_user, mock_is_ip_limit, mock_db, app
|
||||
):
|
||||
"""
|
||||
Test sending email code to new user when registration is allowed.
|
||||
|
||||
Verifies that:
|
||||
- Email code is sent even for non-existent accounts
|
||||
- Registration is allowed by system features
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = None
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
mock_send_email.return_value = "email_token_123"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/email-code-login", method="POST", json={"email": "newuser@example.com", "language": "en-US"}
|
||||
):
|
||||
api = EmailCodeLoginSendEmailApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert response["result"] == "success"
|
||||
mock_send_email.assert_called_once_with(email="newuser@example.com", language="en-US")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
def test_send_email_code_new_user_registration_disabled(
|
||||
self, mock_get_features, mock_get_user, mock_is_ip_limit, mock_db, app
|
||||
):
|
||||
"""
|
||||
Test sending email code to new user when registration is disabled.
|
||||
|
||||
Verifies that:
|
||||
- AccountNotFound is raised for non-existent accounts
|
||||
- Registration is blocked by system features
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = None
|
||||
mock_get_features.return_value.is_allow_register = False
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/email-code-login", method="POST", json={"email": "newuser@example.com"}):
|
||||
api = EmailCodeLoginSendEmailApi()
|
||||
with pytest.raises(AccountNotFound):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
|
||||
def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
|
||||
"""
|
||||
Test email code sending blocked by IP rate limit.
|
||||
|
||||
Verifies that:
|
||||
- EmailSendIpLimitError is raised when IP limit exceeded
|
||||
- Prevents spam and abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/email-code-login", method="POST", json={"email": "test@example.com"}):
|
||||
api = EmailCodeLoginSendEmailApi()
|
||||
with pytest.raises(EmailSendIpLimitError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app):
|
||||
"""
|
||||
Test email code sending to frozen account.
|
||||
|
||||
Verifies that:
|
||||
- AccountInFreezeError is raised for frozen accounts
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.side_effect = AccountRegisterError("Account frozen")
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/email-code-login", method="POST", json={"email": "frozen@example.com"}):
|
||||
api = EmailCodeLoginSendEmailApi()
|
||||
with pytest.raises(AccountInFreezeError):
|
||||
api.post()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("language_input", "expected_language"),
|
||||
[
|
||||
("zh-Hans", "zh-Hans"),
|
||||
("en-US", "en-US"),
|
||||
(None, "en-US"),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.AccountService.send_email_code_login_email")
|
||||
def test_send_email_code_language_handling(
|
||||
self,
|
||||
mock_send_email,
|
||||
mock_get_user,
|
||||
mock_is_ip_limit,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
language_input,
|
||||
expected_language,
|
||||
):
|
||||
"""
|
||||
Test email code sending with different language preferences.
|
||||
|
||||
Verifies that:
|
||||
- Language parameter is correctly processed
|
||||
- Defaults to en-US when not specified
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/email-code-login", method="POST", json={"email": "test@example.com", "language": language_input}
|
||||
):
|
||||
api = EmailCodeLoginSendEmailApi()
|
||||
api.post()
|
||||
|
||||
# Assert
|
||||
call_args = mock_send_email.call_args
|
||||
assert call_args.kwargs["language"] == expected_language
|
||||
|
||||
|
||||
class TestEmailCodeLoginApi:
|
||||
"""Test cases for email code verification and login."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create mock account object."""
|
||||
account = MagicMock()
|
||||
account.email = "test@example.com"
|
||||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_pair(self):
|
||||
"""Create mock token pair object."""
|
||||
token_pair = MagicMock()
|
||||
token_pair.access_token = "access_token"
|
||||
token_pair.refresh_token = "refresh_token"
|
||||
token_pair.csrf_token = "csrf_token"
|
||||
return token_pair
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
|
||||
def test_email_code_login_existing_user(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_login,
|
||||
mock_get_tenants,
|
||||
mock_get_user,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test successful email code login for existing user.
|
||||
|
||||
Verifies that:
|
||||
- Email and code are validated
|
||||
- Token is revoked after use
|
||||
- User is logged in with token pair
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "code": "123456", "token": "valid_token"},
|
||||
):
|
||||
api = EmailCodeLoginApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert response.json["result"] == "success"
|
||||
mock_revoke_token.assert_called_once_with("valid_token")
|
||||
mock_login.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.AccountService.create_account_and_tenant")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
|
||||
def test_email_code_login_new_user_creates_account(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_login,
|
||||
mock_create_account,
|
||||
mock_get_user,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test email code login creates new account for new user.
|
||||
|
||||
Verifies that:
|
||||
- New account is created when user doesn't exist
|
||||
- Workspace is created for new user
|
||||
- User is logged in after account creation
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = None
|
||||
mock_create_account.return_value = mock_account
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "newuser@example.com", "code": "123456", "token": "valid_token", "language": "en-US"},
|
||||
):
|
||||
api = EmailCodeLoginApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert response.json["result"] == "success"
|
||||
mock_create_account.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app):
|
||||
"""
|
||||
Test email code login with invalid token.
|
||||
|
||||
Verifies that:
|
||||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "code": "123456", "token": "invalid_token"},
|
||||
):
|
||||
api = EmailCodeLoginApi()
|
||||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app):
|
||||
"""
|
||||
Test email code login with mismatched email.
|
||||
|
||||
Verifies that:
|
||||
- InvalidEmailError is raised when email doesn't match token
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "different@example.com", "code": "123456", "token": "token"},
|
||||
):
|
||||
api = EmailCodeLoginApi()
|
||||
with pytest.raises(InvalidEmailError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app):
|
||||
"""
|
||||
Test email code login with incorrect code.
|
||||
|
||||
Verifies that:
|
||||
- EmailCodeError is raised for wrong verification code
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "code": "wrong_code", "token": "token"},
|
||||
):
|
||||
api = EmailCodeLoginApi()
|
||||
with pytest.raises(EmailCodeError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
def test_email_code_login_creates_workspace_for_user_without_tenant(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_get_tenants,
|
||||
mock_get_user,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test email code login creates workspace for user without tenant.
|
||||
|
||||
Verifies that:
|
||||
- Workspace is created when user has no tenants
|
||||
- User is added as owner of new workspace
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
mock_features = MagicMock()
|
||||
mock_features.is_allow_create_workspace = True
|
||||
mock_features.license.workspaces.is_available.return_value = True
|
||||
mock_get_features.return_value = mock_features
|
||||
|
||||
# Act & Assert - Should not raise WorkspacesLimitExceeded
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "code": "123456", "token": "token"},
|
||||
):
|
||||
api = EmailCodeLoginApi()
|
||||
# This would complete the flow, but we're testing workspace creation logic
|
||||
# In real implementation, TenantService.create_tenant would be called
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
def test_email_code_login_workspace_limit_exceeded(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_get_tenants,
|
||||
mock_get_user,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test email code login fails when workspace limit exceeded.
|
||||
|
||||
Verifies that:
|
||||
- WorkspacesLimitExceeded is raised when limit reached
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
mock_features = MagicMock()
|
||||
mock_features.license.workspaces.is_available.return_value = False
|
||||
mock_get_features.return_value = mock_features
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "code": "123456", "token": "token"},
|
||||
):
|
||||
api = EmailCodeLoginApi()
|
||||
with pytest.raises(WorkspacesLimitExceeded):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
@patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token")
|
||||
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
def test_email_code_login_workspace_creation_not_allowed(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_get_tenants,
|
||||
mock_get_user,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test email code login fails when workspace creation not allowed.
|
||||
|
||||
Verifies that:
|
||||
- NotAllowedCreateWorkspace is raised when creation disabled
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_get_tenants.return_value = []
|
||||
mock_features = MagicMock()
|
||||
mock_features.is_allow_create_workspace = False
|
||||
mock_get_features.return_value = mock_features
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "code": "123456", "token": "token"},
|
||||
):
|
||||
api = EmailCodeLoginApi()
|
||||
with pytest.raises(NotAllowedCreateWorkspace):
|
||||
api.post()
|
||||
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Test suite for login and logout authentication flows.
|
||||
|
||||
This module tests the core authentication endpoints including:
|
||||
- Email/password login with rate limiting
|
||||
- Session management and logout
|
||||
- Cookie-based token handling
|
||||
- Account status validation
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailPasswordLoginLimitError,
|
||||
InvalidEmailError,
|
||||
)
|
||||
from controllers.console.auth.login import LoginApi, LogoutApi
|
||||
from controllers.console.error import (
|
||||
AccountBannedError,
|
||||
AccountInFreezeError,
|
||||
WorkspacesLimitExceeded,
|
||||
)
|
||||
from services.errors.account import AccountLoginError, AccountPasswordError
|
||||
|
||||
|
||||
class TestLoginApi:
|
||||
"""Test cases for the LoginApi endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def api(self, app):
|
||||
"""Create Flask-RESTX API instance."""
|
||||
return Api(app)
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, app, api):
|
||||
"""Create test client."""
|
||||
api.add_resource(LoginApi, "/login")
|
||||
return app.test_client()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create mock account object."""
|
||||
account = MagicMock()
|
||||
account.id = "test-account-id"
|
||||
account.email = "test@example.com"
|
||||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_pair(self):
|
||||
"""Create mock token pair object."""
|
||||
token_pair = MagicMock()
|
||||
token_pair.access_token = "mock_access_token"
|
||||
token_pair.refresh_token = "mock_refresh_token"
|
||||
token_pair.csrf_token = "mock_csrf_token"
|
||||
return token_pair
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
|
||||
def test_successful_login_without_invitation(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_login,
|
||||
mock_get_tenants,
|
||||
mock_authenticate,
|
||||
mock_get_invitation,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test successful login flow without invitation token.
|
||||
|
||||
Verifies that:
|
||||
- Valid credentials authenticate successfully
|
||||
- Tokens are generated and set in cookies
|
||||
- Rate limit is reset after successful login
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()] # Has at least one tenant
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
response = login_api.post()
|
||||
|
||||
# Assert
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!")
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
|
||||
def test_successful_login_with_valid_invitation(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_login,
|
||||
mock_get_tenants,
|
||||
mock_authenticate,
|
||||
mock_get_invitation,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""
|
||||
Test successful login with valid invitation token.
|
||||
|
||||
Verifies that:
|
||||
- Invitation token is validated
|
||||
- Email matches invitation email
|
||||
- Authentication proceeds with invitation token
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = {"data": {"email": "test@example.com"}}
|
||||
mock_authenticate.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
mock_login.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "password": "ValidPass123!", "invite_token": "valid_token"},
|
||||
):
|
||||
login_api = LoginApi()
|
||||
response = login_api.post()
|
||||
|
||||
# Assert
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token")
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test login rejection when rate limit is exceeded.
|
||||
|
||||
Verifies that:
|
||||
- Rate limit check is performed before authentication
|
||||
- EmailPasswordLoginLimitError is raised when limit exceeded
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = True
|
||||
mock_get_invitation.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "test@example.com", "password": "password"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(EmailPasswordLoginLimitError):
|
||||
login_api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True)
|
||||
@patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
|
||||
def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app):
|
||||
"""
|
||||
Test login rejection for frozen accounts.
|
||||
|
||||
Verifies that:
|
||||
- Billing freeze status is checked when billing enabled
|
||||
- AccountInFreezeError is raised for frozen accounts
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_frozen.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "frozen@example.com", "password": "password"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AccountInFreezeError):
|
||||
login_api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
def test_login_fails_with_invalid_credentials(
|
||||
self,
|
||||
mock_add_rate_limit,
|
||||
mock_authenticate,
|
||||
mock_get_invitation,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""
|
||||
Test login failure with invalid credentials.
|
||||
|
||||
Verifies that:
|
||||
- AuthenticationFailedError is raised for wrong password
|
||||
- Login error rate limit counter is incremented
|
||||
- Generic error message prevents user enumeration
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = AccountPasswordError("Invalid password")
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "test@example.com", "password": "WrongPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
login_api.post()
|
||||
|
||||
mock_add_rate_limit.assert_called_once_with("test@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
def test_login_fails_for_banned_account(
|
||||
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app
|
||||
):
|
||||
"""
|
||||
Test login rejection for banned accounts.
|
||||
|
||||
Verifies that:
|
||||
- AccountBannedError is raised for banned accounts
|
||||
- Login is prevented even with valid credentials
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = AccountLoginError("Account is banned")
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "banned@example.com", "password": "ValidPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(AccountBannedError):
|
||||
login_api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
def test_login_fails_when_no_workspace_and_limit_exceeded(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_get_tenants,
|
||||
mock_authenticate,
|
||||
mock_get_invitation,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test login failure when user has no workspace and workspace limit exceeded.
|
||||
|
||||
Verifies that:
|
||||
- WorkspacesLimitExceeded is raised when limit reached
|
||||
- User cannot login without an assigned workspace
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.return_value = mock_account
|
||||
mock_get_tenants.return_value = [] # No tenants
|
||||
|
||||
mock_features = MagicMock()
|
||||
mock_features.is_allow_create_workspace = True
|
||||
mock_features.license.workspaces.is_available.return_value = False
|
||||
mock_get_features.return_value = mock_features
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"}
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(WorkspacesLimitExceeded):
|
||||
login_api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test login failure when invitation email doesn't match login email.
|
||||
|
||||
Verifies that:
|
||||
- InvalidEmailError is raised for email mismatch
|
||||
- Security check prevents invitation token abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "different@example.com", "password": "ValidPass123!", "invite_token": "token"},
|
||||
):
|
||||
login_api = LoginApi()
|
||||
with pytest.raises(InvalidEmailError):
|
||||
login_api.post()
|
||||
|
||||
|
||||
class TestLogoutApi:
|
||||
"""Test cases for the LogoutApi endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create mock account object."""
|
||||
account = MagicMock()
|
||||
account.id = "test-account-id"
|
||||
account.email = "test@example.com"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.login.AccountService.logout")
|
||||
@patch("controllers.console.auth.login.flask_login.logout_user")
|
||||
def test_successful_logout(
|
||||
self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app, mock_account
|
||||
):
|
||||
"""
|
||||
Test successful logout flow.
|
||||
|
||||
Verifies that:
|
||||
- User session is terminated
|
||||
- AccountService.logout is called
|
||||
- All authentication cookies are cleared
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_current_account.return_value = (mock_account, MagicMock())
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/logout", method="POST"):
|
||||
logout_api = LogoutApi()
|
||||
response = logout_api.post()
|
||||
|
||||
# Assert
|
||||
mock_service_logout.assert_called_once_with(account=mock_account)
|
||||
mock_logout_user.assert_called_once()
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.login.flask_login")
|
||||
def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app):
|
||||
"""
|
||||
Test logout for anonymous (not logged in) user.
|
||||
|
||||
Verifies that:
|
||||
- Anonymous users can call logout endpoint
|
||||
- No errors are raised
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
# Create a mock anonymous user that will pass isinstance check
|
||||
anonymous_user = MagicMock()
|
||||
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})
|
||||
anonymous_user.__class__ = mock_flask_login.AnonymousUserMixin
|
||||
mock_current_account.return_value = (anonymous_user, None)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/logout", method="POST"):
|
||||
logout_api = LogoutApi()
|
||||
response = logout_api.post()
|
||||
|
||||
# Assert
|
||||
assert response.json["result"] == "success"
|
||||
500
dify/api/tests/unit_tests/controllers/console/auth/test_oauth.py
Normal file
500
dify/api/tests/unit_tests/controllers/console/auth/test_oauth.py
Normal file
@@ -0,0 +1,500 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.oauth import (
|
||||
OAuthCallback,
|
||||
OAuthLogin,
|
||||
_generate_account,
|
||||
_get_account_by_openid_or_email,
|
||||
get_oauth_providers,
|
||||
)
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from models.account import AccountStatus
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
|
||||
class TestGetOAuthProviders:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("github_config", "google_config", "expected_github", "expected_google"),
|
||||
[
|
||||
# Both providers configured
|
||||
(
|
||||
{"id": "github_id", "secret": "github_secret"},
|
||||
{"id": "google_id", "secret": "google_secret"},
|
||||
True,
|
||||
True,
|
||||
),
|
||||
# Only GitHub configured
|
||||
({"id": "github_id", "secret": "github_secret"}, {"id": None, "secret": None}, True, False),
|
||||
# Only Google configured
|
||||
({"id": None, "secret": None}, {"id": "google_id", "secret": "google_secret"}, False, True),
|
||||
# No providers configured
|
||||
({"id": None, "secret": None}, {"id": None, "secret": None}, False, False),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
def test_should_configure_oauth_providers_correctly(
|
||||
self, mock_config, app, github_config, google_config, expected_github, expected_google
|
||||
):
|
||||
mock_config.GITHUB_CLIENT_ID = github_config["id"]
|
||||
mock_config.GITHUB_CLIENT_SECRET = github_config["secret"]
|
||||
mock_config.GOOGLE_CLIENT_ID = google_config["id"]
|
||||
mock_config.GOOGLE_CLIENT_SECRET = google_config["secret"]
|
||||
mock_config.CONSOLE_API_URL = "http://localhost"
|
||||
|
||||
with app.app_context():
|
||||
providers = get_oauth_providers()
|
||||
|
||||
assert (providers["github"] is not None) == expected_github
|
||||
assert (providers["google"] is not None) == expected_google
|
||||
|
||||
|
||||
class TestOAuthLogin:
|
||||
@pytest.fixture
|
||||
def resource(self):
|
||||
return OAuthLogin()
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider(self):
|
||||
provider = MagicMock()
|
||||
provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?..."
|
||||
return provider
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invite_token", "expected_token"),
|
||||
[
|
||||
(None, None),
|
||||
("test_invite_token", "test_invite_token"),
|
||||
("", None),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_should_handle_oauth_login_with_various_tokens(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_get_providers,
|
||||
resource,
|
||||
app,
|
||||
mock_oauth_provider,
|
||||
invite_token,
|
||||
expected_token,
|
||||
):
|
||||
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
|
||||
|
||||
query_string = f"invite_token={invite_token}" if invite_token else ""
|
||||
with app.test_request_context(f"/auth/oauth/github?{query_string}"):
|
||||
resource.get("github")
|
||||
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "expected_error"),
|
||||
[
|
||||
("invalid_provider", "Invalid provider"),
|
||||
("github", "Invalid provider"), # When GitHub is not configured
|
||||
("google", "Invalid provider"), # When Google is not configured
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
def test_should_return_error_for_invalid_providers(
|
||||
self, mock_get_providers, resource, app, provider, expected_error
|
||||
):
|
||||
mock_get_providers.return_value = {"github": None, "google": None}
|
||||
|
||||
with app.test_request_context(f"/auth/oauth/{provider}"):
|
||||
response, status_code = resource.get(provider)
|
||||
|
||||
assert status_code == 400
|
||||
assert response["error"] == expected_error
|
||||
|
||||
|
||||
class TestOAuthCallback:
|
||||
@pytest.fixture
|
||||
def resource(self):
|
||||
return OAuthCallback()
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def oauth_setup(self):
|
||||
"""Common OAuth setup for callback tests"""
|
||||
oauth_provider = MagicMock()
|
||||
oauth_provider.get_access_token.return_value = "access_token"
|
||||
oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
||||
|
||||
account = MagicMock()
|
||||
account.status = AccountStatus.ACTIVE
|
||||
|
||||
token_pair = MagicMock()
|
||||
token_pair.access_token = "jwt_access_token"
|
||||
token_pair.refresh_token = "jwt_refresh_token"
|
||||
|
||||
return {"provider": oauth_provider, "account": account, "token_pair": token_pair}
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_should_handle_successful_oauth_callback(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
mock_generate_account.return_value = oauth_setup["account"]
|
||||
mock_account_service.login.return_value = oauth_setup["token_pair"]
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
|
||||
oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
|
||||
mock_redirect.assert_called_once_with("http://localhost:3000")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception", "expected_error"),
|
||||
[
|
||||
(Exception("OAuth error"), "OAuth process failed"),
|
||||
(ValueError("Invalid token"), "OAuth process failed"),
|
||||
(KeyError("Missing key"), "OAuth process failed"),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
def test_should_handle_oauth_exceptions(
|
||||
self, mock_get_providers, mock_db, resource, app, exception, expected_error
|
||||
):
|
||||
# Mock database session
|
||||
mock_db.session = MagicMock()
|
||||
mock_db.session.rollback = MagicMock()
|
||||
|
||||
# Import the real requests module to create a proper exception
|
||||
import httpx
|
||||
|
||||
request_exception = httpx.RequestError("OAuth error")
|
||||
request_exception.response = MagicMock()
|
||||
request_exception.response.text = str(exception)
|
||||
|
||||
mock_oauth_provider = MagicMock()
|
||||
mock_oauth_provider.get_access_token.side_effect = request_exception
|
||||
mock_get_providers.return_value = {"github": mock_oauth_provider}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
response, status_code = resource.get("github")
|
||||
|
||||
assert status_code == 400
|
||||
assert response["error"] == expected_error
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("account_status", "expected_redirect"),
|
||||
[
|
||||
(AccountStatus.BANNED, "http://localhost:3000/signin?message=Account is banned."),
|
||||
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
|
||||
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
|
||||
(
|
||||
AccountStatus.CLOSED.value,
|
||||
"http://localhost:3000",
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_should_redirect_based_on_account_status(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
account_status,
|
||||
expected_redirect,
|
||||
):
|
||||
# Mock database session
|
||||
mock_db.session = MagicMock()
|
||||
mock_db.session.rollback = MagicMock()
|
||||
mock_db.session.commit = MagicMock()
|
||||
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
|
||||
account = MagicMock()
|
||||
account.status = account_status
|
||||
account.id = "123"
|
||||
mock_generate_account.return_value = account
|
||||
|
||||
# Mock login for CLOSED status
|
||||
mock_token_pair = MagicMock()
|
||||
mock_token_pair.access_token = "jwt_access_token"
|
||||
mock_token_pair.refresh_token = "jwt_refresh_token"
|
||||
mock_token_pair.csrf_token = "csrf_token"
|
||||
mock_account_service.login.return_value = mock_token_pair
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
mock_redirect.assert_called_once_with(expected_redirect)
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
def test_should_activate_pending_account(
|
||||
self,
|
||||
mock_account_service,
|
||||
mock_tenant_service,
|
||||
mock_db,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = AccountStatus.PENDING
|
||||
mock_generate_account.return_value = mock_account
|
||||
|
||||
mock_token_pair = MagicMock()
|
||||
mock_token_pair.access_token = "jwt_access_token"
|
||||
mock_token_pair.refresh_token = "jwt_refresh_token"
|
||||
mock_token_pair.csrf_token = "csrf_token"
|
||||
mock_account_service.login.return_value = mock_token_pair
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
assert mock_account.status == AccountStatus.ACTIVE
|
||||
assert mock_account.initialized_at is not None
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_defensive_check_for_closed_account_status(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_account_service,
|
||||
mock_tenant_service,
|
||||
mock_db,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
):
|
||||
"""Defensive test for CLOSED account status handling in OAuth callback.
|
||||
|
||||
This is a defensive test documenting expected security behavior for CLOSED accounts.
|
||||
|
||||
Current behavior: CLOSED status is NOT checked, allowing closed accounts to login.
|
||||
Expected behavior: CLOSED accounts should be rejected like BANNED accounts.
|
||||
|
||||
Context:
|
||||
- AccountStatus.CLOSED is defined in the enum but never used in production
|
||||
- The close_account() method exists but is never called
|
||||
- Account deletion uses external service instead of status change
|
||||
- All authentication services (OAuth, password, email) don't check CLOSED status
|
||||
|
||||
TODO: If CLOSED status is implemented in the future:
|
||||
1. Update OAuth callback to check for CLOSED status
|
||||
2. Add similar checks to all authentication services for consistency
|
||||
3. Update this test to verify the rejection behavior
|
||||
|
||||
Security consideration: Until properly implemented, CLOSED status provides no protection.
|
||||
"""
|
||||
# Setup
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
|
||||
# Create account with CLOSED status
|
||||
closed_account = MagicMock()
|
||||
closed_account.status = AccountStatus.CLOSED
|
||||
closed_account.id = "123"
|
||||
closed_account.name = "Closed Account"
|
||||
mock_generate_account.return_value = closed_account
|
||||
|
||||
# Mock successful login (current behavior)
|
||||
mock_token_pair = MagicMock()
|
||||
mock_token_pair.access_token = "jwt_access_token"
|
||||
mock_token_pair.refresh_token = "jwt_refresh_token"
|
||||
mock_token_pair.csrf_token = "csrf_token"
|
||||
mock_account_service.login.return_value = mock_token_pair
|
||||
|
||||
# Execute OAuth callback
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
# Verify current behavior: login succeeds (this is NOT ideal)
|
||||
mock_redirect.assert_called_once_with("http://localhost:3000")
|
||||
mock_account_service.login.assert_called_once()
|
||||
|
||||
# Document expected behavior in comments:
|
||||
# Expected: mock_redirect.assert_called_once_with(
|
||||
# "http://localhost:3000/signin?message=Account is closed."
|
||||
# )
|
||||
# Expected: mock_account_service.login.assert_not_called()
|
||||
|
||||
|
||||
class TestAccountGeneration:
|
||||
@pytest.fixture
|
||||
def user_info(self):
|
||||
return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
account = MagicMock()
|
||||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.Account")
|
||||
@patch("controllers.console.auth.oauth.Session")
|
||||
@patch("controllers.console.auth.oauth.select")
|
||||
def test_should_get_account_by_openid_or_email(
|
||||
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
|
||||
):
|
||||
# Mock db.engine for Session creation
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Test OpenID found
|
||||
mock_account_model.get_by_openid.return_value = mock_account
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
|
||||
|
||||
# Test fallback to email
|
||||
mock_account_model.get_by_openid.return_value = None
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("allow_register", "existing_account", "should_create"),
|
||||
[
|
||||
(True, None, True), # New account creation allowed
|
||||
(True, "existing", False), # Existing account
|
||||
(False, None, False), # Registration not allowed
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_handle_account_generation_scenarios(
|
||||
self,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
user_info,
|
||||
mock_account,
|
||||
allow_register,
|
||||
existing_account,
|
||||
should_create,
|
||||
):
|
||||
mock_get_account.return_value = mock_account if existing_account else None
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = allow_register
|
||||
mock_register_service.register.return_value = mock_account
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
if not allow_register and not existing_account:
|
||||
with pytest.raises(AccountRegisterError):
|
||||
_generate_account("github", user_info)
|
||||
else:
|
||||
result = _generate_account("github", user_info)
|
||||
assert result == mock_account
|
||||
|
||||
if should_create:
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.tenant_was_created")
|
||||
def test_should_create_workspace_for_account_without_tenant(
|
||||
self,
|
||||
mock_event,
|
||||
mock_account_service,
|
||||
mock_feature_service,
|
||||
mock_tenant_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
user_info,
|
||||
mock_account,
|
||||
):
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_tenant_service.get_join_tenants.return_value = []
|
||||
mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
|
||||
|
||||
mock_new_tenant = MagicMock()
|
||||
mock_tenant_service.create_tenant.return_value = mock_new_tenant
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
result = _generate_account("github", user_info)
|
||||
|
||||
assert result == mock_account
|
||||
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
|
||||
mock_tenant_service.create_tenant_member.assert_called_once_with(
|
||||
mock_new_tenant, mock_account, role="owner"
|
||||
)
|
||||
mock_event.send.assert_called_once_with(mock_new_tenant)
|
||||
@@ -0,0 +1,508 @@
|
||||
"""
|
||||
Test suite for password reset authentication flows.
|
||||
|
||||
This module tests the password reset mechanism including:
|
||||
- Password reset email sending
|
||||
- Verification code validation
|
||||
- Password reset with token
|
||||
- Rate limiting and security checks
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.auth.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
ForgotPasswordResetApi,
|
||||
ForgotPasswordSendEmailApi,
|
||||
)
|
||||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
"""Test cases for sending password reset emails."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create mock account object."""
|
||||
account = MagicMock()
|
||||
account.email = "test@example.com"
|
||||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
|
||||
def test_send_reset_email_success(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_send_email,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_is_ip_limit,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test successful password reset email sending.
|
||||
|
||||
Verifies that:
|
||||
- Email is sent to valid account
|
||||
- Reset token is generated and returned
|
||||
- IP rate limiting is checked
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_send_email.return_value = "reset_token_123"
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/forgot-password", method="POST", json={"email": "test@example.com", "language": "en-US"}
|
||||
):
|
||||
api = ForgotPasswordSendEmailApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert response["result"] == "success"
|
||||
assert response["data"] == "reset_token_123"
|
||||
mock_send_email.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
|
||||
"""
|
||||
Test password reset email blocked by IP rate limit.
|
||||
|
||||
Verifies that:
|
||||
- EmailSendIpLimitError is raised when IP limit exceeded
|
||||
- No email is sent when rate limited
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_ip_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context("/forgot-password", method="POST", json={"email": "test@example.com"}):
|
||||
api = ForgotPasswordSendEmailApi()
|
||||
with pytest.raises(EmailSendIpLimitError):
|
||||
api.post()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("language_input", "expected_language"),
|
||||
[
|
||||
("zh-Hans", "zh-Hans"),
|
||||
("en-US", "en-US"),
|
||||
("fr-FR", "en-US"), # Defaults to en-US for unsupported
|
||||
(None, "en-US"), # Defaults to en-US when not provided
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
|
||||
def test_send_reset_email_language_handling(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_send_email,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_is_ip_limit,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
language_input,
|
||||
expected_language,
|
||||
):
|
||||
"""
|
||||
Test password reset email with different language preferences.
|
||||
|
||||
Verifies that:
|
||||
- Language parameter is correctly processed
|
||||
- Unsupported languages default to en-US
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_send_email.return_value = "token"
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/forgot-password", method="POST", json={"email": "test@example.com", "language": language_input}
|
||||
):
|
||||
api = ForgotPasswordSendEmailApi()
|
||||
api.post()
|
||||
|
||||
# Assert
|
||||
call_args = mock_send_email.call_args
|
||||
assert call_args.kwargs["language"] == expected_language
|
||||
|
||||
|
||||
class TestForgotPasswordCheckApi:
|
||||
"""Test cases for verifying password reset codes."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
def test_verify_code_success(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_generate_token,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
"""
|
||||
Test successful verification code validation.
|
||||
|
||||
Verifies that:
|
||||
- Valid code is accepted
|
||||
- Old token is revoked
|
||||
- New token is generated for reset phase
|
||||
- Rate limit is reset on success
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
mock_generate_token.return_value = (None, "new_token")
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "code": "123456", "token": "old_token"},
|
||||
):
|
||||
api = ForgotPasswordCheckApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert response["is_valid"] is True
|
||||
assert response["email"] == "test@example.com"
|
||||
assert response["token"] == "new_token"
|
||||
mock_revoke_token.assert_called_once_with("old_token")
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test code verification blocked by rate limit.
|
||||
|
||||
Verifies that:
|
||||
- EmailPasswordResetLimitError is raised when limit exceeded
|
||||
- Prevents brute force attacks on verification codes
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "code": "123456", "token": "token"},
|
||||
):
|
||||
api = ForgotPasswordCheckApi()
|
||||
with pytest.raises(EmailPasswordResetLimitError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test code verification with invalid token.
|
||||
|
||||
Verifies that:
|
||||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "code": "123456", "token": "invalid_token"},
|
||||
):
|
||||
api = ForgotPasswordCheckApi()
|
||||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test code verification with mismatched email.
|
||||
|
||||
Verifies that:
|
||||
- InvalidEmailError is raised when email doesn't match token
|
||||
- Prevents token abuse
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "original@example.com", "code": "123456"}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "different@example.com", "code": "123456", "token": "token"},
|
||||
):
|
||||
api = ForgotPasswordCheckApi()
|
||||
with pytest.raises(InvalidEmailError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test code verification with incorrect code.
|
||||
|
||||
Verifies that:
|
||||
- EmailCodeError is raised for wrong code
|
||||
- Rate limit counter is incremented
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "test@example.com", "code": "123456"}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "test@example.com", "code": "wrong_code", "token": "token"},
|
||||
):
|
||||
api = ForgotPasswordCheckApi()
|
||||
with pytest.raises(EmailCodeError):
|
||||
api.post()
|
||||
|
||||
mock_add_rate_limit.assert_called_once_with("test@example.com")
|
||||
|
||||
|
||||
class TestForgotPasswordResetApi:
|
||||
"""Test cases for resetting password with verified token."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create mock account object."""
|
||||
account = MagicMock()
|
||||
account.email = "test@example.com"
|
||||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
|
||||
def test_reset_password_success(
|
||||
self,
|
||||
mock_get_tenants,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
Test successful password reset.
|
||||
|
||||
Verifies that:
|
||||
- Password is updated with new hashed value
|
||||
- Token is revoked after use
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={"token": "valid_token", "new_password": "NewPass123!", "password_confirm": "NewPass123!"},
|
||||
):
|
||||
api = ForgotPasswordResetApi()
|
||||
response = api.post()
|
||||
|
||||
# Assert
|
||||
assert response["result"] == "success"
|
||||
mock_revoke_token.assert_called_once_with("valid_token")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_mismatch(self, mock_get_data, mock_db, app):
|
||||
"""
|
||||
Test password reset with mismatched passwords.
|
||||
|
||||
Verifies that:
|
||||
- PasswordMismatchError is raised when passwords don't match
|
||||
- No password update occurs
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={"token": "token", "new_password": "NewPass123!", "password_confirm": "DifferentPass123!"},
|
||||
):
|
||||
api = ForgotPasswordResetApi()
|
||||
with pytest.raises(PasswordMismatchError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_invalid_token(self, mock_get_data, mock_db, app):
|
||||
"""
|
||||
Test password reset with invalid token.
|
||||
|
||||
Verifies that:
|
||||
- InvalidTokenError is raised for invalid/expired tokens
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={"token": "invalid_token", "new_password": "NewPass123!", "password_confirm": "NewPass123!"},
|
||||
):
|
||||
api = ForgotPasswordResetApi()
|
||||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app):
|
||||
"""
|
||||
Test password reset with token not in reset phase.
|
||||
|
||||
Verifies that:
|
||||
- InvalidTokenError is raised when token is not in reset phase
|
||||
- Prevents use of verification-phase tokens for reset
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"}
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={"token": "token", "new_password": "NewPass123!", "password_confirm": "NewPass123!"},
|
||||
):
|
||||
api = ForgotPasswordResetApi()
|
||||
with pytest.raises(InvalidTokenError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
def test_reset_password_account_not_found(
|
||||
self, mock_select, mock_session, mock_revoke_token, mock_get_data, mock_forgot_db, mock_wraps_db, app
|
||||
):
|
||||
"""
|
||||
Test password reset for non-existent account.
|
||||
|
||||
Verifies that:
|
||||
- AccountNotFound is raised when account doesn't exist
|
||||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={"token": "token", "new_password": "NewPass123!", "password_confirm": "NewPass123!"},
|
||||
):
|
||||
api = ForgotPasswordResetApi()
|
||||
with pytest.raises(AccountNotFound):
|
||||
api.post()
|
||||
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Test suite for token refresh authentication flows.
|
||||
|
||||
This module tests the token refresh mechanism including:
|
||||
- Access token refresh using refresh token
|
||||
- Cookie-based token extraction and renewal
|
||||
- Token expiration and validation
|
||||
- Error handling for invalid tokens
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
from controllers.console.auth.login import RefreshTokenApi
|
||||
|
||||
|
||||
class TestRefreshTokenApi:
|
||||
"""Test cases for the RefreshTokenApi endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask test application."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def api(self, app):
|
||||
"""Create Flask-RESTX API instance."""
|
||||
return Api(app)
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, app, api):
|
||||
"""Create test client."""
|
||||
api.add_resource(RefreshTokenApi, "/refresh-token")
|
||||
return app.test_client()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_pair(self):
|
||||
"""Create mock token pair object."""
|
||||
token_pair = MagicMock()
|
||||
token_pair.access_token = "new_access_token"
|
||||
token_pair.refresh_token = "new_refresh_token"
|
||||
token_pair.csrf_token = "new_csrf_token"
|
||||
return token_pair
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token")
|
||||
def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
|
||||
"""
|
||||
Test successful token refresh flow.
|
||||
|
||||
Verifies that:
|
||||
- Refresh token is extracted from cookies
|
||||
- New token pair is generated
|
||||
- New tokens are set in response cookies
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_extract_token.return_value = "valid_refresh_token"
|
||||
mock_refresh_token.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/refresh-token", method="POST"):
|
||||
refresh_api = RefreshTokenApi()
|
||||
response = refresh_api.post()
|
||||
|
||||
# Assert
|
||||
mock_extract_token.assert_called_once()
|
||||
mock_refresh_token.assert_called_once_with("valid_refresh_token")
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
def test_refresh_fails_without_token(self, mock_extract_token, app):
|
||||
"""
|
||||
Test token refresh failure when no refresh token provided.
|
||||
|
||||
Verifies that:
|
||||
- Error is returned when refresh token is missing
|
||||
- 401 status code is returned
|
||||
- Appropriate error message is provided
|
||||
"""
|
||||
# Arrange
|
||||
mock_extract_token.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/refresh-token", method="POST"):
|
||||
refresh_api = RefreshTokenApi()
|
||||
response, status_code = refresh_api.post()
|
||||
|
||||
# Assert
|
||||
assert status_code == 401
|
||||
assert response["result"] == "fail"
|
||||
assert "No refresh token provided" in response["message"]
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token")
|
||||
def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app):
|
||||
"""
|
||||
Test token refresh failure with invalid refresh token.
|
||||
|
||||
Verifies that:
|
||||
- Exception is caught when token is invalid
|
||||
- 401 status code is returned
|
||||
- Error message is included in response
|
||||
"""
|
||||
# Arrange
|
||||
mock_extract_token.return_value = "invalid_refresh_token"
|
||||
mock_refresh_token.side_effect = Exception("Invalid refresh token")
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/refresh-token", method="POST"):
|
||||
refresh_api = RefreshTokenApi()
|
||||
response, status_code = refresh_api.post()
|
||||
|
||||
# Assert
|
||||
assert status_code == 401
|
||||
assert response["result"] == "fail"
|
||||
assert "Invalid refresh token" in response["message"]
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token")
|
||||
def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app):
|
||||
"""
|
||||
Test token refresh failure with expired refresh token.
|
||||
|
||||
Verifies that:
|
||||
- Expired tokens are rejected
|
||||
- 401 status code is returned
|
||||
- Appropriate error handling
|
||||
"""
|
||||
# Arrange
|
||||
mock_extract_token.return_value = "expired_refresh_token"
|
||||
mock_refresh_token.side_effect = Exception("Refresh token expired")
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/refresh-token", method="POST"):
|
||||
refresh_api = RefreshTokenApi()
|
||||
response, status_code = refresh_api.post()
|
||||
|
||||
# Assert
|
||||
assert status_code == 401
|
||||
assert response["result"] == "fail"
|
||||
assert "expired" in response["message"].lower()
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token")
|
||||
def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app):
|
||||
"""
|
||||
Test token refresh with empty string token.
|
||||
|
||||
Verifies that:
|
||||
- Empty string is treated as no token
|
||||
- 401 status code is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_extract_token.return_value = ""
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/refresh-token", method="POST"):
|
||||
refresh_api = RefreshTokenApi()
|
||||
response, status_code = refresh_api.post()
|
||||
|
||||
# Assert
|
||||
assert status_code == 401
|
||||
assert response["result"] == "fail"
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token")
|
||||
def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
|
||||
"""
|
||||
Test that token refresh updates all three tokens.
|
||||
|
||||
Verifies that:
|
||||
- Access token is updated
|
||||
- Refresh token is rotated
|
||||
- CSRF token is regenerated
|
||||
"""
|
||||
# Arrange
|
||||
mock_extract_token.return_value = "valid_refresh_token"
|
||||
mock_refresh_token.return_value = mock_token_pair
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/refresh-token", method="POST"):
|
||||
refresh_api = RefreshTokenApi()
|
||||
response = refresh_api.post()
|
||||
|
||||
# Assert
|
||||
assert response.json["result"] == "success"
|
||||
# Verify new token pair was generated
|
||||
mock_refresh_token.assert_called_once_with("valid_refresh_token")
|
||||
# In real implementation, cookies would be set with new values
|
||||
assert mock_token_pair.access_token == "new_access_token"
|
||||
assert mock_token_pair.refresh_token == "new_refresh_token"
|
||||
assert mock_token_pair.csrf_token == "new_csrf_token"
|
||||
Reference in New Issue
Block a user