dify
This commit is contained in:
@@ -0,0 +1,271 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.supabase_storage import SupabaseStorage
|
||||
|
||||
|
||||
class TestSupabaseStorage:
|
||||
"""Test suite for SupabaseStorage class."""
|
||||
|
||||
def test_init_success_with_all_config(self):
|
||||
"""Test successful initialization when all required config is provided."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock bucket_exists to return True so create_bucket is not called
|
||||
with patch.object(SupabaseStorage, "bucket_exists", return_value=True):
|
||||
storage = SupabaseStorage()
|
||||
|
||||
assert storage.bucket_name == "test-bucket"
|
||||
mock_client_class.assert_called_once_with(
|
||||
supabase_url="https://test.supabase.co", supabase_key="test-api-key"
|
||||
)
|
||||
|
||||
def test_init_raises_error_when_url_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_URL is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = None
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with pytest.raises(ValueError, match="SUPABASE_URL is not set"):
|
||||
SupabaseStorage()
|
||||
|
||||
def test_init_raises_error_when_api_key_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_API_KEY is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = None
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with pytest.raises(ValueError, match="SUPABASE_API_KEY is not set"):
|
||||
SupabaseStorage()
|
||||
|
||||
def test_init_raises_error_when_bucket_name_missing(self):
|
||||
"""Test initialization raises ValueError when SUPABASE_BUCKET_NAME is None."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = None
|
||||
|
||||
with pytest.raises(ValueError, match="SUPABASE_BUCKET_NAME is not set"):
|
||||
SupabaseStorage()
|
||||
|
||||
def test_create_bucket_when_not_exists(self):
|
||||
"""Test create_bucket creates bucket when it doesn't exist."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
with patch.object(SupabaseStorage, "bucket_exists", return_value=False):
|
||||
storage = SupabaseStorage()
|
||||
|
||||
mock_client.storage.create_bucket.assert_called_once_with(id="test-bucket", name="test-bucket")
|
||||
|
||||
def test_create_bucket_when_exists(self):
|
||||
"""Test create_bucket does not create bucket when it already exists."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
with patch.object(SupabaseStorage, "bucket_exists", return_value=True):
|
||||
storage = SupabaseStorage()
|
||||
|
||||
mock_client.storage.create_bucket.assert_not_called()
|
||||
|
||||
@pytest.fixture
|
||||
def storage_with_mock_client(self):
|
||||
"""Fixture providing SupabaseStorage with mocked client."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
with patch.object(SupabaseStorage, "bucket_exists", return_value=True):
|
||||
storage = SupabaseStorage()
|
||||
# Create fresh mock for each test
|
||||
mock_client.reset_mock()
|
||||
yield storage, mock_client
|
||||
|
||||
def test_save(self, storage_with_mock_client):
|
||||
"""Test save calls client.storage.from_(bucket).upload(path, data)."""
|
||||
storage, mock_client = storage_with_mock_client
|
||||
|
||||
filename = "test.txt"
|
||||
data = b"test data"
|
||||
|
||||
storage.save(filename, data)
|
||||
|
||||
mock_client.storage.from_.assert_called_once_with("test-bucket")
|
||||
mock_client.storage.from_().upload.assert_called_once_with(filename, data)
|
||||
|
||||
def test_load_once_returns_bytes(self, storage_with_mock_client):
|
||||
"""Test load_once returns bytes."""
|
||||
storage, mock_client = storage_with_mock_client
|
||||
|
||||
expected_data = b"test content"
|
||||
mock_client.storage.from_().download.return_value = expected_data
|
||||
|
||||
result = storage.load_once("test.txt")
|
||||
|
||||
assert result == expected_data
|
||||
# Verify the correct calls were made
|
||||
assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]]
|
||||
mock_client.storage.from_().download.assert_called_with("test.txt")
|
||||
|
||||
def test_load_stream_yields_chunks(self, storage_with_mock_client):
|
||||
"""Test load_stream yields chunks."""
|
||||
storage, mock_client = storage_with_mock_client
|
||||
|
||||
test_data = b"test content for streaming"
|
||||
mock_client.storage.from_().download.return_value = test_data
|
||||
|
||||
result = storage.load_stream("test.txt")
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
# Collect all chunks
|
||||
chunks = list(result)
|
||||
|
||||
# Verify chunks contain the expected data
|
||||
assert b"".join(chunks) == test_data
|
||||
# Verify the correct calls were made
|
||||
assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]]
|
||||
mock_client.storage.from_().download.assert_called_with("test.txt")
|
||||
|
||||
def test_download_writes_bytes_to_disk(self, storage_with_mock_client, tmp_path):
|
||||
"""Test download writes expected bytes to disk."""
|
||||
storage, mock_client = storage_with_mock_client
|
||||
|
||||
test_data = b"test file content"
|
||||
mock_client.storage.from_().download.return_value = test_data
|
||||
|
||||
target_file = tmp_path / "downloaded_file.txt"
|
||||
|
||||
storage.download("test.txt", str(target_file))
|
||||
|
||||
# Verify file was written with correct content
|
||||
assert target_file.read_bytes() == test_data
|
||||
# Verify the correct calls were made
|
||||
assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]]
|
||||
mock_client.storage.from_().download.assert_called_with("test.txt")
|
||||
|
||||
def test_exists_returns_true_when_file_found(self, storage_with_mock_client):
|
||||
"""Test exists returns True when list() returns items."""
|
||||
storage, mock_client = storage_with_mock_client
|
||||
|
||||
mock_client.storage.from_().list.return_value = [{"name": "test.txt"}]
|
||||
|
||||
result = storage.exists("test.txt")
|
||||
|
||||
assert result is True
|
||||
assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]]
|
||||
mock_client.storage.from_().list.assert_called_with(path="test.txt")
|
||||
|
||||
def test_exists_returns_false_when_file_not_found(self, storage_with_mock_client):
|
||||
"""Test exists returns False when list() returns an empty list."""
|
||||
storage, mock_client = storage_with_mock_client
|
||||
|
||||
mock_client.storage.from_().list.return_value = []
|
||||
|
||||
result = storage.exists("test.txt")
|
||||
|
||||
assert result is False
|
||||
assert "test-bucket" in [call[0][0] for call in mock_client.storage.from_.call_args_list if call[0]]
|
||||
mock_client.storage.from_().list.assert_called_with(path="test.txt")
|
||||
|
||||
def test_delete_calls_remove_with_filename_in_list(self, storage_with_mock_client):
|
||||
"""Test delete calls remove([...]) (some client versions require a list)."""
|
||||
storage, mock_client = storage_with_mock_client
|
||||
|
||||
filename = "test.txt"
|
||||
|
||||
storage.delete(filename)
|
||||
|
||||
mock_client.storage.from_.assert_called_once_with("test-bucket")
|
||||
mock_client.storage.from_().remove.assert_called_once_with([filename])
|
||||
|
||||
def test_bucket_exists_returns_true_when_bucket_found(self):
|
||||
"""Test bucket_exists returns True when bucket is found in list."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_bucket = Mock()
|
||||
mock_bucket.name = "test-bucket"
|
||||
mock_client.storage.list_buckets.return_value = [mock_bucket]
|
||||
storage = SupabaseStorage()
|
||||
result = storage.bucket_exists()
|
||||
|
||||
assert result is True
|
||||
assert mock_client.storage.list_buckets.call_count >= 1
|
||||
|
||||
def test_bucket_exists_returns_false_when_bucket_not_found(self):
|
||||
"""Test bucket_exists returns False when bucket is not found in list."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock different bucket
|
||||
mock_bucket = Mock()
|
||||
mock_bucket.name = "different-bucket"
|
||||
mock_client.storage.list_buckets.return_value = [mock_bucket]
|
||||
mock_client.storage.create_bucket = Mock()
|
||||
|
||||
storage = SupabaseStorage()
|
||||
result = storage.bucket_exists()
|
||||
|
||||
assert result is False
|
||||
assert mock_client.storage.list_buckets.call_count >= 1
|
||||
|
||||
def test_bucket_exists_returns_false_when_no_buckets(self):
|
||||
"""Test bucket_exists returns False when no buckets exist."""
|
||||
with patch("extensions.storage.supabase_storage.dify_config", autospec=True) as mock_config:
|
||||
mock_config.SUPABASE_URL = "https://test.supabase.co"
|
||||
mock_config.SUPABASE_API_KEY = "test-api-key"
|
||||
mock_config.SUPABASE_BUCKET_NAME = "test-bucket"
|
||||
|
||||
with patch("extensions.storage.supabase_storage.Client", autospec=True) as mock_client_class:
|
||||
mock_client = Mock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_client.storage.list_buckets.return_value = []
|
||||
mock_client.storage.create_bucket = Mock()
|
||||
|
||||
storage = SupabaseStorage()
|
||||
result = storage.bucket_exists()
|
||||
|
||||
assert result is False
|
||||
assert mock_client.storage.list_buckets.call_count >= 1
|
||||
155
dify/api/tests/unit_tests/extensions/test_celery_ssl.py
Normal file
155
dify/api/tests/unit_tests/extensions/test_celery_ssl.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Tests for Celery SSL configuration."""
|
||||
|
||||
import ssl
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestCelerySSLConfiguration:
|
||||
"""Test suite for Celery SSL configuration."""
|
||||
|
||||
def test_get_celery_ssl_options_when_ssl_disabled(self):
|
||||
"""Test SSL options when REDIS_USE_SSL is False."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.REDIS_USE_SSL = False
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import _get_celery_ssl_options
|
||||
|
||||
result = _get_celery_ssl_options()
|
||||
assert result is None
|
||||
|
||||
def test_get_celery_ssl_options_when_broker_not_redis(self):
|
||||
"""Test SSL options when broker is not Redis."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.REDIS_USE_SSL = True
|
||||
mock_config.CELERY_BROKER_URL = "amqp://localhost:5672"
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import _get_celery_ssl_options
|
||||
|
||||
result = _get_celery_ssl_options()
|
||||
assert result is None
|
||||
|
||||
def test_get_celery_ssl_options_with_cert_none(self):
|
||||
"""Test SSL options with CERT_NONE requirement."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.REDIS_USE_SSL = True
|
||||
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||
mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE"
|
||||
mock_config.REDIS_SSL_CA_CERTS = None
|
||||
mock_config.REDIS_SSL_CERTFILE = None
|
||||
mock_config.REDIS_SSL_KEYFILE = None
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import _get_celery_ssl_options
|
||||
|
||||
result = _get_celery_ssl_options()
|
||||
assert result is not None
|
||||
assert result["ssl_cert_reqs"] == ssl.CERT_NONE
|
||||
assert result["ssl_ca_certs"] is None
|
||||
assert result["ssl_certfile"] is None
|
||||
assert result["ssl_keyfile"] is None
|
||||
|
||||
def test_get_celery_ssl_options_with_cert_required(self):
|
||||
"""Test SSL options with CERT_REQUIRED and certificates."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.REDIS_USE_SSL = True
|
||||
mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0"
|
||||
mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED"
|
||||
mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt"
|
||||
mock_config.REDIS_SSL_CERTFILE = "/path/to/client.crt"
|
||||
mock_config.REDIS_SSL_KEYFILE = "/path/to/client.key"
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import _get_celery_ssl_options
|
||||
|
||||
result = _get_celery_ssl_options()
|
||||
assert result is not None
|
||||
assert result["ssl_cert_reqs"] == ssl.CERT_REQUIRED
|
||||
assert result["ssl_ca_certs"] == "/path/to/ca.crt"
|
||||
assert result["ssl_certfile"] == "/path/to/client.crt"
|
||||
assert result["ssl_keyfile"] == "/path/to/client.key"
|
||||
|
||||
def test_get_celery_ssl_options_with_cert_optional(self):
|
||||
"""Test SSL options with CERT_OPTIONAL requirement."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.REDIS_USE_SSL = True
|
||||
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||
mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL"
|
||||
mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt"
|
||||
mock_config.REDIS_SSL_CERTFILE = None
|
||||
mock_config.REDIS_SSL_KEYFILE = None
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import _get_celery_ssl_options
|
||||
|
||||
result = _get_celery_ssl_options()
|
||||
assert result is not None
|
||||
assert result["ssl_cert_reqs"] == ssl.CERT_OPTIONAL
|
||||
assert result["ssl_ca_certs"] == "/path/to/ca.crt"
|
||||
|
||||
def test_get_celery_ssl_options_with_invalid_cert_reqs(self):
|
||||
"""Test SSL options with invalid cert requirement defaults to CERT_NONE."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.REDIS_USE_SSL = True
|
||||
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||
mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE"
|
||||
mock_config.REDIS_SSL_CA_CERTS = None
|
||||
mock_config.REDIS_SSL_CERTFILE = None
|
||||
mock_config.REDIS_SSL_KEYFILE = None
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import _get_celery_ssl_options
|
||||
|
||||
result = _get_celery_ssl_options()
|
||||
assert result is not None
|
||||
assert result["ssl_cert_reqs"] == ssl.CERT_NONE # Should default to CERT_NONE
|
||||
|
||||
def test_celery_init_applies_ssl_to_broker_and_backend(self):
|
||||
"""Test that SSL options are applied to both broker and backend when using Redis."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.REDIS_USE_SSL = True
|
||||
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||
mock_config.CELERY_BACKEND = "redis"
|
||||
mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"
|
||||
mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE"
|
||||
mock_config.REDIS_SSL_CA_CERTS = None
|
||||
mock_config.REDIS_SSL_CERTFILE = None
|
||||
mock_config.REDIS_SSL_KEYFILE = None
|
||||
mock_config.CELERY_USE_SENTINEL = False
|
||||
mock_config.LOG_FORMAT = "%(message)s"
|
||||
mock_config.LOG_TZ = "UTC"
|
||||
mock_config.LOG_FILE = None
|
||||
|
||||
# Mock all the scheduler configs
|
||||
mock_config.CELERY_BEAT_SCHEDULER_TIME = 1
|
||||
mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False
|
||||
mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False
|
||||
mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False
|
||||
mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False
|
||||
mock_config.ENABLE_CLEAN_MESSAGES = False
|
||||
mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False
|
||||
mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False
|
||||
mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False
|
||||
mock_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK = False
|
||||
mock_config.WORKFLOW_SCHEDULE_POLLER_INTERVAL = 1
|
||||
mock_config.WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE = 100
|
||||
mock_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK = 0
|
||||
mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False
|
||||
mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_celery import init_app
|
||||
|
||||
app = DifyApp(__name__)
|
||||
celery_app = init_app(app)
|
||||
|
||||
# Check that SSL options were applied
|
||||
assert "broker_use_ssl" in celery_app.conf
|
||||
assert celery_app.conf["broker_use_ssl"] is not None
|
||||
assert celery_app.conf["broker_use_ssl"]["ssl_cert_reqs"] == ssl.CERT_NONE
|
||||
|
||||
# Check that SSL is also applied to Redis backend
|
||||
assert "redis_backend_use_ssl" in celery_app.conf
|
||||
assert celery_app.conf["redis_backend_use_ssl"] is not None
|
||||
265
dify/api/tests/unit_tests/extensions/test_ext_request_logging.py
Normal file
265
dify/api/tests/unit_tests/extensions/test_ext_request_logging.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import json
|
||||
import logging
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask, Response
|
||||
|
||||
from configs import dify_config
|
||||
from extensions import ext_request_logging
|
||||
from extensions.ext_request_logging import _is_content_type_json, _log_request_finished, init_app
|
||||
|
||||
|
||||
def test_is_content_type_json():
|
||||
"""
|
||||
Test the _is_content_type_json function.
|
||||
"""
|
||||
|
||||
assert _is_content_type_json("application/json") is True
|
||||
# content type header with charset option.
|
||||
assert _is_content_type_json("application/json; charset=utf-8") is True
|
||||
# content type header with charset option, in uppercase.
|
||||
assert _is_content_type_json("APPLICATION/JSON; CHARSET=UTF-8") is True
|
||||
assert _is_content_type_json("text/html") is False
|
||||
assert _is_content_type_json("") is False
|
||||
|
||||
|
||||
_KEY_NEEDLE = "needle"
|
||||
_VALUE_NEEDLE = _KEY_NEEDLE[::-1]
|
||||
_RESPONSE_NEEDLE = "response"
|
||||
|
||||
|
||||
def _get_test_app():
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/", methods=["GET", "POST"])
|
||||
def handler():
|
||||
return _RESPONSE_NEEDLE
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# NOTE(QuantumGhost): Due to the design of Flask, we need to use monkey patch to write tests.
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_receiver(monkeypatch: pytest.MonkeyPatch) -> mock.Mock:
|
||||
mock_log_request_started = mock.Mock()
|
||||
monkeypatch.setattr(ext_request_logging, "_log_request_started", mock_log_request_started)
|
||||
return mock_log_request_started
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response_receiver(monkeypatch: pytest.MonkeyPatch) -> mock.Mock:
|
||||
mock_log_request_finished = mock.Mock()
|
||||
monkeypatch.setattr(ext_request_logging, "_log_request_finished", mock_log_request_finished)
|
||||
return mock_log_request_finished
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(monkeypatch: pytest.MonkeyPatch) -> logging.Logger:
|
||||
_logger = mock.MagicMock(spec=logging.Logger)
|
||||
monkeypatch.setattr(ext_request_logging, "logger", _logger)
|
||||
return _logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_request_logging(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", True)
|
||||
|
||||
|
||||
class TestRequestLoggingExtension:
|
||||
def test_receiver_should_not_be_invoked_if_configuration_is_disabled(
|
||||
self,
|
||||
monkeypatch,
|
||||
mock_request_receiver,
|
||||
mock_response_receiver,
|
||||
):
|
||||
monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", False)
|
||||
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.get("/")
|
||||
|
||||
mock_request_receiver.assert_not_called()
|
||||
mock_response_receiver.assert_not_called()
|
||||
|
||||
def test_receiver_should_be_called_if_enabled(
|
||||
self,
|
||||
enable_request_logging,
|
||||
mock_request_receiver,
|
||||
mock_response_receiver,
|
||||
):
|
||||
"""
|
||||
Test the request logging extension with JSON data.
|
||||
"""
|
||||
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
|
||||
|
||||
mock_request_receiver.assert_called_once()
|
||||
mock_response_receiver.assert_called_once()
|
||||
|
||||
|
||||
class TestLoggingLevel:
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_logging_should_be_skipped_if_level_is_above_debug(self, enable_request_logging, mock_logger):
|
||||
mock_logger.isEnabledFor.return_value = False
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
|
||||
mock_logger.debug.assert_not_called()
|
||||
|
||||
|
||||
class TestRequestReceiverLogging:
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_non_json_request(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post("/", data="plain text")
|
||||
assert mock_logger.debug.call_count == 1
|
||||
call_args = mock_logger.debug.call_args[0]
|
||||
assert "Received Request" in call_args[0]
|
||||
assert call_args[1] == "POST"
|
||||
assert call_args[2] == "/"
|
||||
assert "Request Body" not in call_args[0]
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_json_request(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
|
||||
assert mock_logger.debug.call_count == 1
|
||||
call_args = mock_logger.debug.call_args[0]
|
||||
assert "Received Request" in call_args[0]
|
||||
assert "Request Body" in call_args[0]
|
||||
assert call_args[1] == "POST"
|
||||
assert call_args[2] == "/"
|
||||
assert _KEY_NEEDLE in call_args[3]
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_json_request_with_empty_body(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post("/", headers={"Content-Type": "application/json"})
|
||||
|
||||
assert mock_logger.debug.call_count == 1
|
||||
call_args = mock_logger.debug.call_args[0]
|
||||
assert "Received Request" in call_args[0]
|
||||
assert "Request Body" not in call_args[0]
|
||||
assert call_args[1] == "POST"
|
||||
assert call_args[2] == "/"
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post(
|
||||
"/",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data="{",
|
||||
)
|
||||
assert mock_logger.debug.call_count == 0
|
||||
assert mock_logger.exception.call_count == 1
|
||||
|
||||
exception_call_args = mock_logger.exception.call_args[0]
|
||||
assert exception_call_args[0] == "Failed to parse JSON request"
|
||||
|
||||
|
||||
class TestResponseReceiverLogging:
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_non_json_response(self, enable_request_logging, mock_logger):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
response = Response(
|
||||
"OK",
|
||||
headers={"Content-Type": "text/plain"},
|
||||
)
|
||||
_log_request_finished(app, response)
|
||||
assert mock_logger.debug.call_count == 1
|
||||
call_args = mock_logger.debug.call_args[0]
|
||||
assert "Response" in call_args[0]
|
||||
assert "200" in call_args[1]
|
||||
assert call_args[2] == "text/plain"
|
||||
assert "Response Body" not in call_args[0]
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_json_response(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
response = Response(
|
||||
json.dumps({_KEY_NEEDLE: _VALUE_NEEDLE}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
_log_request_finished(app, response)
|
||||
assert mock_logger.debug.call_count == 1
|
||||
call_args = mock_logger.debug.call_args[0]
|
||||
assert "Response" in call_args[0]
|
||||
assert "Response Body" in call_args[0]
|
||||
assert "200" in call_args[1]
|
||||
assert call_args[2] == "application/json"
|
||||
assert _KEY_NEEDLE in call_args[3]
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
|
||||
response = Response(
|
||||
"{",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
_log_request_finished(app, response)
|
||||
assert mock_logger.debug.call_count == 0
|
||||
assert mock_logger.exception.call_count == 1
|
||||
|
||||
exception_call_args = mock_logger.exception.call_args[0]
|
||||
assert exception_call_args[0] == "Failed to parse JSON response"
|
||||
|
||||
|
||||
class TestResponseUnmodified:
|
||||
def test_when_request_logging_disabled(self):
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
response = client.post(
|
||||
"/",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data="{",
|
||||
)
|
||||
assert response.text == _RESPONSE_NEEDLE
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_when_request_logging_enabled(self, enable_request_logging):
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
response = client.post(
|
||||
"/",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data="{",
|
||||
)
|
||||
assert response.text == _RESPONSE_NEEDLE
|
||||
assert response.status_code == 200
|
||||
53
dify/api/tests/unit_tests/extensions/test_redis.py
Normal file
53
dify/api/tests/unit_tests/extensions/test_redis.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from redis import RedisError
|
||||
|
||||
from extensions.ext_redis import redis_fallback
|
||||
|
||||
|
||||
def test_redis_fallback_success():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
assert test_func() == "success"
|
||||
|
||||
|
||||
def test_redis_fallback_error():
|
||||
@redis_fallback(default_return="fallback")
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func() == "fallback"
|
||||
|
||||
|
||||
def test_redis_fallback_none_default():
|
||||
@redis_fallback()
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func() is None
|
||||
|
||||
|
||||
def test_redis_fallback_with_args():
|
||||
@redis_fallback(default_return=0)
|
||||
def test_func(x, y):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func(1, 2) == 0
|
||||
|
||||
|
||||
def test_redis_fallback_with_kwargs():
|
||||
@redis_fallback(default_return={})
|
||||
def test_func(x=None, y=None):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func(x=1, y=2) == {}
|
||||
|
||||
|
||||
def test_redis_fallback_preserves_function_metadata():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
"""Test function docstring"""
|
||||
pass
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__doc__ == "Test function docstring"
|
||||
Reference in New Issue
Block a user