dify
223
dify/api/tests/integration_tests/.env.example
Normal file
@@ -0,0 +1,223 @@
|
||||
FLASK_APP=app.py
|
||||
FLASK_DEBUG=0
|
||||
SECRET_KEY='uhySf6a3aZuvRNfAlcr47paOw9TRYBY6j8ZHXpVw1yx5RP27Yj3w2uvI'
|
||||
|
||||
CONSOLE_API_URL=http://127.0.0.1:5001
|
||||
CONSOLE_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# Service API base URL
|
||||
SERVICE_API_URL=http://127.0.0.1:5001
|
||||
|
||||
# Web APP base URL
|
||||
APP_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# Files URL
|
||||
FILES_URL=http://127.0.0.1:5001
|
||||
|
||||
# The time in seconds after which a file access signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
# Refresh token expiration time in days
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
|
||||
# redis configuration
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_USERNAME=
|
||||
REDIS_PASSWORD=difyai123456
|
||||
REDIS_USE_SSL=false
|
||||
REDIS_DB=0
|
||||
|
||||
# PostgreSQL database configuration
|
||||
DB_USERNAME=postgres
|
||||
DB_PASSWORD=difyai123456
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_DATABASE=dify
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
|
||||
STORAGE_TYPE=opendal
|
||||
|
||||
# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
|
||||
OPENDAL_SCHEME=fs
|
||||
OPENDAL_FS_ROOT=storage
|
||||
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Vector database configuration
|
||||
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
|
||||
VECTOR_STORE=weaviate
|
||||
# Weaviate configuration
|
||||
WEAVIATE_ENDPOINT=http://localhost:8080
|
||||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
WEAVIATE_TOKENIZATION=word
|
||||
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||
|
||||
# Model configuration
|
||||
MULTIMODAL_SEND_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=4096
|
||||
CODE_GENERATION_MAX_TOKENS=1024
|
||||
|
||||
# Mail configuration, supports: resend, smtp
|
||||
MAIL_TYPE=
|
||||
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@example.com>
|
||||
RESEND_API_KEY=
|
||||
RESEND_API_URL=https://api.resend.com
|
||||
# smtp configuration
|
||||
SMTP_SERVER=smtp.example.com
|
||||
SMTP_PORT=465
|
||||
SMTP_USERNAME=123
|
||||
SMTP_PASSWORD=abc
|
||||
SMTP_USE_TLS=true
|
||||
SMTP_OPPORTUNISTIC_TLS=false
|
||||
|
||||
# Sentry configuration
|
||||
SENTRY_DSN=
|
||||
|
||||
# DEBUG
|
||||
DEBUG=false
|
||||
SQLALCHEMY_ECHO=false
|
||||
|
||||
# Notion import configuration, supports public and internal integration types
NOTION_INTEGRATION_TYPE=public
NOTION_CLIENT_SECRET=your-client-secret
NOTION_CLIENT_ID=your-client-id
NOTION_INTERNAL_SECRET=your-internal-secret
|
||||
|
||||
ETL_TYPE=dify
|
||||
UNSTRUCTURED_API_URL=
|
||||
UNSTRUCTURED_API_KEY=
|
||||
SCARF_NO_ANALYTICS=false
|
||||
|
||||
# SSRF proxy configuration
|
||||
SSRF_PROXY_HTTP_URL=
|
||||
SSRF_PROXY_HTTPS_URL=
|
||||
SSRF_DEFAULT_MAX_RETRIES=3
|
||||
SSRF_DEFAULT_TIME_OUT=5
|
||||
SSRF_DEFAULT_CONNECT_TIME_OUT=5
|
||||
SSRF_DEFAULT_READ_TIME_OUT=5
|
||||
SSRF_DEFAULT_WRITE_TIME_OUT=5
|
||||
|
||||
BATCH_UPLOAD_LIMIT=10
|
||||
KEYWORD_DATA_SOURCE_TYPE=database
|
||||
|
||||
# Workflow file upload limit
|
||||
WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||
|
||||
# CODE EXECUTION CONFIGURATION
|
||||
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
|
||||
CODE_EXECUTION_API_KEY=dify-sandbox
|
||||
CODE_MAX_NUMBER=9223372036854775807
|
||||
CODE_MIN_NUMBER=-9223372036854775808
|
||||
CODE_MAX_STRING_LENGTH=80000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
|
||||
CODE_MAX_STRING_ARRAY_LENGTH=30
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH=30
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
||||
|
||||
# API Tool configuration
|
||||
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
|
||||
API_TOOL_DEFAULT_READ_TIMEOUT=60
|
||||
|
||||
# HTTP Node configuration
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT=600
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||
|
||||
# Webhook configuration
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
|
||||
|
||||
# Respect X-* headers to redirect clients
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
# Log file path
|
||||
LOG_FILE=
|
||||
# Log file max size, in MB
|
||||
LOG_FILE_MAX_SIZE=20
|
||||
# Log file max backup count
|
||||
LOG_FILE_BACKUP_COUNT=5
|
||||
# Log dateformat
|
||||
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
|
||||
# Log Timezone
|
||||
LOG_TZ=UTC
|
||||
# Log format
|
||||
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
|
||||
|
||||
# Indexing configuration
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
|
||||
|
||||
# Workflow runtime configuration
|
||||
WORKFLOW_MAX_EXECUTION_STEPS=500
|
||||
WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||
WORKFLOW_CALL_MAX_DEPTH=5
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
# Celery beat configuration
|
||||
CELERY_BEAT_SCHEDULER_TIME=1
|
||||
|
||||
# Position configuration
|
||||
POSITION_TOOL_PINS=
|
||||
POSITION_TOOL_INCLUDES=
|
||||
POSITION_TOOL_EXCLUDES=
|
||||
|
||||
POSITION_PROVIDER_PINS=
|
||||
POSITION_PROVIDER_INCLUDES=
|
||||
POSITION_PROVIDER_EXCLUDES=
|
||||
|
||||
# Plugin configuration
|
||||
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
|
||||
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
|
||||
PLUGIN_REMOTE_INSTALL_PORT=5003
|
||||
PLUGIN_REMOTE_INSTALL_HOST=localhost
|
||||
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
||||
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||
|
||||
# Marketplace configuration
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# Endpoint configuration
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
# Reset password token expiry minutes
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
|
||||
EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5
|
||||
CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
|
||||
OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
|
||||
|
||||
CREATE_TIDB_SERVICE_JOB_ENABLED=false
|
||||
|
||||
# Maximum number of threads submitted to a ThreadPool for parallel node execution
|
||||
MAX_SUBMIT_COUNT=100
|
||||
# Lockout duration in seconds
|
||||
LOGIN_LOCKOUT_DURATION=86400
|
||||
|
||||
HTTP_PROXY='http://127.0.0.1:1092'
|
||||
HTTPS_PROXY='http://127.0.0.1:1092'
|
||||
NO_PROXY='localhost,127.0.0.1'
|
||||
LOG_LEVEL=INFO
1
dify/api/tests/integration_tests/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
.env.test
0
dify/api/tests/integration_tests/__init__.py
Normal file
93
dify/api/tests/integration_tests/conftest.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import pathlib
|
||||
import random
|
||||
import secrets
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app_factory import create_app
|
||||
from extensions.ext_database import db
|
||||
from models import Account, DifySetup, Tenant, TenantAccountJoin
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
|
||||
# Load the .env and vdb.env files if they exist
|
||||
def _load_env():
|
||||
current_file_path = pathlib.Path(__file__).absolute()
|
||||
# Items later in the list have higher precedence.
|
||||
files_to_load = [".env", "vdb.env"]
|
||||
|
||||
env_file_paths = [current_file_path.parent / i for i in files_to_load]
|
||||
for path in env_file_paths:
|
||||
if not path.exists():
|
||||
continue
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Set `override=True` to ensure values from `vdb.env` take priority over values from `.env`
|
||||
load_dotenv(str(path), override=True)
|
||||
|
||||
|
||||
_load_env()
|
||||
|
||||
_CACHED_APP = create_app()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return _CACHED_APP
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def setup_account(request) -> Generator[Account, None, None]:
|
"""This fixture completes the setup process for the Dify application.
|
||||
|
||||
It creates `Account` and `Tenant`, and inserts a `DifySetup` record into the database.
|
||||
|
||||
Most tests in the `controllers` package require that Dify has been set up successfully.
|
||||
"""
|
||||
with _CACHED_APP.test_request_context():
|
||||
rand_suffix = random.randint(int(1e6), int(1e7)) # noqa
|
||||
name = f"test-user-{rand_suffix}"
|
||||
email = f"{name}@example.com"
|
||||
RegisterService.setup(
|
||||
email=email,
|
||||
name=name,
|
||||
password=secrets.token_hex(16),
|
||||
ip_address="localhost",
|
||||
language="en-US",
|
||||
)
|
||||
|
||||
with _CACHED_APP.test_request_context():
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
account = session.query(Account).filter_by(email=email).one()
|
||||
|
||||
yield account
|
||||
|
||||
with _CACHED_APP.test_request_context():
|
||||
db.session.query(DifySetup).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_req_ctx():
|
||||
with _CACHED_APP.test_request_context():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_header(setup_account) -> dict[str, str]:
|
||||
token = AccountService.get_account_jwt_token(setup_account)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client() -> Generator[FlaskClient, None, None]:
|
||||
with _CACHED_APP.test_client() as client:
|
||||
yield client
|
||||
@@ -0,0 +1,219 @@
|
||||
"""Integration tests for ChatMessageApi permission verification."""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from controllers.console.app import completion as completion_api
|
||||
from controllers.console.app import message as message_api
|
||||
from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
||||
class TestChatMessageApiPermissions:
|
||||
"""Test permission verification for ChatMessageApi endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model for testing."""
|
||||
app = App()
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Create a mock Account for testing."""
|
||||
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.last_active_at = naive_utc_now()
|
||||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
account.id = str(uuid.uuid4())
|
||||
|
||||
# Create mock tenant
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid.uuid4())
|
||||
|
||||
mock_session_instance = mock.Mock()
|
||||
|
||||
mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
|
||||
monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
|
||||
|
||||
mock_scalars_result = mock.Mock()
|
||||
mock_scalars_result.one.return_value = tenant
|
||||
monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
|
||||
|
||||
mock_session_context = mock.Mock()
|
||||
mock_session_context.__enter__.return_value = mock_session_instance
|
||||
monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "status"),
|
||||
[
|
||||
(TenantAccountRole.OWNER, 200),
|
||||
(TenantAccountRole.ADMIN, 200),
|
||||
(TenantAccountRole.EDITOR, 200),
|
||||
(TenantAccountRole.NORMAL, 403),
|
||||
(TenantAccountRole.DATASET_OPERATOR, 403),
|
||||
],
|
||||
)
|
||||
def test_post_with_owner_role_succeeds(
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
status: int,
|
||||
):
|
"""Test that the chat-messages POST endpoint enforces role-based permissions."""

# Set up common mocks for testing
|
||||
# Mock app loading
|
||||
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
# Mock current user
|
||||
monkeypatch.setattr(completion_api, "current_user", mock_account)
|
||||
|
||||
mock_generate = mock.Mock(return_value={"message": "Test response"})
|
||||
monkeypatch.setattr(AppGenerateService, "generate", mock_generate)
|
||||
|
||||
# Set the user role from the parametrized case
mock_account.role = role
|
||||
|
||||
response = test_client.post(
|
||||
f"/console/api/apps/{mock_app_model.id}/chat-messages",
|
||||
headers=auth_header,
|
||||
json={
|
||||
"inputs": {},
|
||||
"query": "Hello, world!",
|
||||
"model_config": {
|
||||
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}}
|
||||
},
|
||||
"response_mode": "blocking",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "status"),
|
||||
[
|
||||
(TenantAccountRole.OWNER, 200),
|
||||
(TenantAccountRole.ADMIN, 200),
|
||||
(TenantAccountRole.EDITOR, 200),
|
||||
(TenantAccountRole.NORMAL, 403),
|
||||
(TenantAccountRole.DATASET_OPERATOR, 403),
|
||||
],
|
||||
)
|
||||
def test_get_requires_edit_permission(
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
status: int,
|
||||
):
|
||||
"""Ensure GET chat-messages endpoint enforces edit permissions."""
|
||||
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
conversation_id = uuid.uuid4()
|
||||
created_at = naive_utc_now()
|
||||
|
||||
mock_conversation = SimpleNamespace(id=str(conversation_id), app_id=str(mock_app_model.id))
|
||||
mock_message = SimpleNamespace(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=str(conversation_id),
|
||||
inputs=[],
|
||||
query="hello",
|
||||
message=[{"text": "hello"}],
|
||||
message_tokens=0,
|
||||
re_sign_file_url_answer="",
|
||||
answer_tokens=0,
|
||||
provider_response_latency=0.0,
|
||||
from_source="console",
|
||||
from_end_user_id=None,
|
||||
from_account_id=mock_account.id,
|
||||
feedbacks=[],
|
||||
workflow_run_id=None,
|
||||
annotation=None,
|
||||
annotation_hit_history=None,
|
||||
created_at=created_at,
|
||||
agent_thoughts=[],
|
||||
message_files=[],
|
||||
message_metadata_dict={},
|
||||
status="success",
|
||||
error="",
|
||||
parent_message_id=None,
|
||||
)
|
||||
|
||||
class MockQuery:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
if getattr(self.model, "__name__", "") == "Conversation":
|
||||
return mock_conversation
|
||||
return None
|
||||
|
||||
def order_by(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def limit(self, *_):
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
if getattr(self.model, "__name__", "") == "Message":
|
||||
return [mock_message]
|
||||
return []
|
||||
|
||||
mock_session = mock.Mock()
|
||||
mock_session.query.side_effect = MockQuery
|
||||
mock_session.scalar.return_value = False
|
||||
|
||||
monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
|
||||
monkeypatch.setattr(message_api, "current_user", mock_account)
|
||||
|
||||
class DummyPagination:
|
||||
def __init__(self, data, limit, has_more):
|
||||
self.data = data
|
||||
self.limit = limit
|
||||
self.has_more = has_more
|
||||
|
||||
monkeypatch.setattr(message_api, "InfiniteScrollPagination", DummyPagination)
|
||||
|
||||
mock_account.role = role
|
||||
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/chat-messages",
|
||||
headers=auth_header,
|
||||
query_string={"conversation_id": str(conversation_id)},
|
||||
)
|
||||
|
||||
assert response.status_code == status
|
||||
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Unit tests for App description validation functions.
|
||||
|
||||
This test module validates the 400-character limit enforcement
|
||||
for App descriptions across all creation and editing endpoints.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the API root to Python path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
|
||||
|
||||
class TestAppDescriptionValidationUnit:
|
||||
"""Unit tests for description validation function"""
|
||||
|
||||
def test_validate_description_length_function(self):
|
||||
"""Test the validate_description_length function directly"""
|
||||
from libs.validators import validate_description_length
|
||||
|
||||
# Test valid descriptions
|
||||
assert validate_description_length("") == ""
|
||||
assert validate_description_length("x" * 400) == "x" * 400
|
||||
assert validate_description_length(None) is None
|
||||
|
||||
# Test invalid descriptions
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_description_length("x" * 401)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_description_length("x" * 500)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_description_length("x" * 1000)
|
||||
assert "Description cannot exceed 400 characters." in str(exc_info.value)
|
||||
|
||||
def test_boundary_values(self):
|
||||
"""Test boundary values for description validation"""
|
||||
from libs.validators import validate_description_length
|
||||
|
||||
# Test exact boundary
|
||||
exactly_400 = "x" * 400
|
||||
assert validate_description_length(exactly_400) == exactly_400
|
||||
|
||||
# Test just over boundary
|
||||
just_over_400 = "x" * 401
|
||||
with pytest.raises(ValueError):
|
||||
validate_description_length(just_over_400)
|
||||
|
||||
# Test just under boundary
|
||||
just_under_400 = "x" * 399
|
||||
assert validate_description_length(just_under_400) == just_under_400
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""Test edge cases for description validation"""
|
||||
from libs.validators import validate_description_length
|
||||
|
||||
# Test None input
|
||||
assert validate_description_length(None) is None
|
||||
|
||||
# Test empty string
|
||||
assert validate_description_length("") == ""
|
||||
|
||||
# Test single character
|
||||
assert validate_description_length("a") == "a"
|
||||
|
||||
# Test unicode characters
|
||||
unicode_desc = "测试" * 200 # 400 characters in Chinese
|
||||
assert validate_description_length(unicode_desc) == unicode_desc
|
||||
|
||||
# Test unicode over limit
|
||||
unicode_over = "测试" * 201 # 402 characters
|
||||
with pytest.raises(ValueError):
|
||||
validate_description_length(unicode_over)
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
"""Test how validation handles whitespace"""
|
||||
from libs.validators import validate_description_length
|
||||
|
||||
# Test description with spaces
|
||||
spaces_400 = " " * 400
|
||||
assert validate_description_length(spaces_400) == spaces_400
|
||||
|
||||
# Test description with spaces over limit
|
||||
spaces_401 = " " * 401
|
||||
with pytest.raises(ValueError):
|
||||
validate_description_length(spaces_401)
|
||||
|
||||
# Test mixed content
|
||||
mixed_400 = "a" * 200 + " " * 200
|
||||
assert validate_description_length(mixed_400) == mixed_400
|
||||
|
||||
# Test mixed over limit
|
||||
mixed_401 = "a" * 200 + " " * 201
|
||||
with pytest.raises(ValueError):
|
||||
validate_description_length(mixed_401)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests directly
|
||||
import traceback
|
||||
|
||||
test_instance = TestAppDescriptionValidationUnit()
|
||||
test_methods = [method for method in dir(test_instance) if method.startswith("test_")]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for test_method in test_methods:
|
||||
try:
|
||||
print(f"Running {test_method}...")
|
||||
getattr(test_instance, test_method)()
|
||||
print(f"✅ {test_method} PASSED")
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"❌ {test_method} FAILED: {str(e)}")
|
||||
traceback.print_exc()
|
||||
failed += 1
|
||||
|
||||
print(f"\n📊 Test Results: {passed} passed, {failed} failed")
|
||||
|
||||
if failed == 0:
|
||||
print("🎉 All tests passed!")
|
||||
else:
|
||||
print("💥 Some tests failed!")
|
||||
sys.exit(1)
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Basic integration tests for Feedback API endpoints."""
|
||||
|
||||
import uuid
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
|
||||
class TestFeedbackApiBasic:
|
||||
"""Basic tests for feedback API endpoints."""
|
||||
|
||||
def test_feedback_export_endpoint_exists(self, test_client: FlaskClient, auth_header):
|
||||
"""Test that feedback export endpoint exists and handles basic requests."""
|
||||
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
# Test endpoint exists (even if it fails, it should return 500 or 403, not 404)
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{app_id}/feedbacks/export", headers=auth_header, query_string={"format": "csv"}
|
||||
)
|
||||
|
||||
# Should not return 404 (endpoint exists)
|
||||
assert response.status_code != 404
|
||||
|
||||
# Should return authentication or permission error
|
||||
assert response.status_code in [401, 403, 500] # 500 if app doesn't exist, 403 if no permission
|
||||
|
||||
def test_feedback_summary_endpoint_exists(self, test_client: FlaskClient, auth_header):
|
||||
"""Test that feedback summary endpoint exists and handles basic requests."""
|
||||
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
# Test endpoint exists
|
||||
response = test_client.get(f"/console/api/apps/{app_id}/feedbacks/summary", headers=auth_header)
|
||||
|
||||
# Should not return 404 (endpoint exists)
|
||||
assert response.status_code != 404
|
||||
|
||||
# Should return authentication or permission error
|
||||
assert response.status_code in [401, 403, 500]
|
||||
|
||||
def test_feedback_export_invalid_format(self, test_client: FlaskClient, auth_header):
|
||||
"""Test feedback export endpoint with invalid format parameter."""
|
||||
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
# Test with invalid format
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{app_id}/feedbacks/export",
|
||||
headers=auth_header,
|
||||
query_string={"format": "invalid_format"},
|
||||
)
|
||||
|
||||
# Should not return 404
|
||||
assert response.status_code != 404
|
||||
|
||||
def test_feedback_export_with_filters(self, test_client: FlaskClient, auth_header):
|
||||
"""Test feedback export endpoint with various filter parameters."""
|
||||
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
# Test with various filter combinations
|
||||
filter_params = [
|
||||
{"from_source": "user"},
|
||||
{"rating": "like"},
|
||||
{"has_comment": True},
|
||||
{"start_date": "2024-01-01"},
|
||||
{"end_date": "2024-12-31"},
|
||||
{"format": "json"},
|
||||
{
|
||||
"from_source": "admin",
|
||||
"rating": "dislike",
|
||||
"has_comment": True,
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-12-31",
|
||||
"format": "csv",
|
||||
},
|
||||
]
|
||||
|
||||
for params in filter_params:
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{app_id}/feedbacks/export", headers=auth_header, query_string=params
|
||||
)
|
||||
|
||||
# Should not return 404
|
||||
assert response.status_code != 404
|
||||
|
||||
def test_feedback_export_invalid_dates(self, test_client: FlaskClient, auth_header):
|
||||
"""Test feedback export endpoint with invalid date formats."""
|
||||
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
# Test with invalid date formats
|
||||
invalid_dates = [
|
||||
{"start_date": "invalid-date"},
|
||||
{"end_date": "not-a-date"},
|
||||
{"start_date": "2024-13-01"}, # Invalid month
|
||||
{"end_date": "2024-12-32"}, # Invalid day
|
||||
]
|
||||
|
||||
for params in invalid_dates:
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{app_id}/feedbacks/export", headers=auth_header, query_string=params
|
||||
)
|
||||
|
||||
# Should not return 404
|
||||
assert response.status_code != 404
|
||||
@@ -0,0 +1,334 @@
|
||||
"""Integration tests for Feedback Export API endpoints."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from controllers.console.app import message as message_api
|
||||
from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import AppMode, MessageFeedback
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
|
||||
class TestFeedbackExportApi:
|
||||
"""Test feedback export API endpoints."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model for testing."""
|
||||
app = App()
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.name = "Test App"
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Create a mock Account for testing."""
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
account.last_active_at = naive_utc_now()
|
||||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
account.id = str(uuid.uuid4())
|
||||
|
||||
# Create mock tenant
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid.uuid4())
|
||||
|
||||
mock_session_instance = mock.Mock()
|
||||
|
||||
mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
|
||||
monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
|
||||
|
||||
mock_scalars_result = mock.Mock()
|
||||
mock_scalars_result.one.return_value = tenant
|
||||
monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
|
||||
|
||||
mock_session_context = mock.Mock()
|
||||
mock_session_context.__enter__.return_value = mock_session_instance
|
||||
monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
|
||||
@pytest.fixture
|
||||
def sample_feedback_data(self):
|
||||
"""Create sample feedback data for testing."""
|
||||
app_id = str(uuid.uuid4())
|
||||
conversation_id = str(uuid.uuid4())
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# Mock feedback data
|
||||
user_feedback = MessageFeedback(
|
||||
id=str(uuid.uuid4()),
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
rating="like",
|
||||
from_source="user",
|
||||
content=None,
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
from_account_id=None,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
admin_feedback = MessageFeedback(
|
||||
id=str(uuid.uuid4()),
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
rating="dislike",
|
||||
from_source="admin",
|
||||
content="The response was not helpful",
|
||||
from_end_user_id=None,
|
||||
from_account_id=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Mock message and conversation
|
||||
mock_message = SimpleNamespace(
|
||||
id=message_id,
|
||||
conversation_id=conversation_id,
|
||||
query="What is the weather today?",
|
||||
answer="It's sunny and 25 degrees outside.",
|
||||
inputs={"query": "What is the weather today?"},
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
mock_conversation = SimpleNamespace(id=conversation_id, name="Weather Conversation", app_id=app_id)
|
||||
|
||||
mock_app = SimpleNamespace(id=app_id, name="Weather App")
|
||||
|
||||
return {
|
||||
"user_feedback": user_feedback,
|
||||
"admin_feedback": admin_feedback,
|
||||
"message": mock_message,
|
||||
"conversation": mock_conversation,
|
||||
"app": mock_app,
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "status"),
|
||||
[
|
||||
(TenantAccountRole.OWNER, 200),
|
||||
(TenantAccountRole.ADMIN, 200),
|
||||
(TenantAccountRole.EDITOR, 200),
|
||||
(TenantAccountRole.NORMAL, 403),
|
||||
(TenantAccountRole.DATASET_OPERATOR, 403),
|
||||
],
|
||||
)
|
||||
def test_feedback_export_permissions(
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
status: int,
|
||||
):
|
||||
"""Test feedback export endpoint permissions."""
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
mock_export_feedbacks = mock.Mock(return_value="mock csv response")
|
||||
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
|
||||
|
||||
monkeypatch.setattr(message_api, "current_user", mock_account)
|
||||
|
||||
# Set user role
|
||||
mock_account.role = role
|
||||
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
|
||||
headers=auth_header,
|
||||
query_string={"format": "csv"},
|
||||
)
|
||||
|
||||
assert response.status_code == status
|
||||
|
||||
if status == 200:
|
||||
mock_export_feedbacks.assert_called_once()
|
||||
|
||||
def test_feedback_export_csv_format(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
|
||||
):
|
||||
"""Test feedback export in CSV format."""
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
# Create mock CSV response
|
||||
mock_csv_content = (
|
||||
"feedback_id,app_name,conversation_id,user_query,ai_response,feedback_rating,feedback_comment\n"
|
||||
)
|
||||
mock_csv_content += f"{sample_feedback_data['user_feedback'].id},{sample_feedback_data['app'].name},"
|
||||
mock_csv_content += f"{sample_feedback_data['conversation'].id},{sample_feedback_data['message'].query},"
|
||||
mock_csv_content += f"{sample_feedback_data['message'].answer},👍,\n"
|
||||
|
||||
mock_response = mock.Mock()
|
||||
mock_response.headers = {"Content-Type": "text/csv; charset=utf-8-sig"}
|
||||
mock_response.data = mock_csv_content.encode("utf-8")
|
||||
|
||||
mock_export_feedbacks = mock.Mock(return_value=mock_response)
|
||||
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
|
||||
|
||||
monkeypatch.setattr(message_api, "current_user", mock_account)
|
||||
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
|
||||
headers=auth_header,
|
||||
query_string={"format": "csv", "from_source": "user"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "text/csv" in response.content_type
|
||||
|
||||
def test_feedback_export_json_format(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
|
||||
):
|
||||
"""Test feedback export in JSON format."""
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
mock_json_response = {
|
||||
"export_info": {
|
||||
"app_id": mock_app_model.id,
|
||||
"export_date": datetime.now().isoformat(),
|
||||
"total_records": 2,
|
||||
"data_source": "dify_feedback_export",
|
||||
},
|
||||
"feedback_data": [
|
||||
{
|
||||
"feedback_id": sample_feedback_data["user_feedback"].id,
|
||||
"feedback_rating": "👍",
|
||||
"feedback_rating_raw": "like",
|
||||
"feedback_comment": "",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
mock_response = mock.Mock()
|
||||
mock_response.headers = {"Content-Type": "application/json; charset=utf-8"}
|
||||
mock_response.data = json.dumps(mock_json_response).encode("utf-8")
|
||||
|
||||
mock_export_feedbacks = mock.Mock(return_value=mock_response)
|
||||
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
|
||||
|
||||
monkeypatch.setattr(message_api, "current_user", mock_account)
|
||||
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
|
||||
headers=auth_header,
|
||||
query_string={"format": "json"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "application/json" in response.content_type
|
||||
|
||||
def test_feedback_export_with_filters(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
|
||||
):
|
||||
"""Test feedback export with various filters."""
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
mock_export_feedbacks = mock.Mock(return_value="mock filtered response")
|
||||
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
|
||||
|
||||
monkeypatch.setattr(message_api, "current_user", mock_account)
|
||||
|
||||
# Test with multiple filters
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
|
||||
headers=auth_header,
|
||||
query_string={
|
||||
"from_source": "user",
|
||||
"rating": "dislike",
|
||||
"has_comment": True,
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-12-31",
|
||||
"format": "csv",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify service was called with correct parameters
|
||||
mock_export_feedbacks.assert_called_once_with(
|
||||
app_id=mock_app_model.id,
|
||||
from_source="user",
|
||||
rating="dislike",
|
||||
has_comment=True,
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-12-31",
|
||||
format_type="csv",
|
||||
)
|
||||
|
||||
def test_feedback_export_invalid_date_format(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
|
||||
):
|
||||
"""Test feedback export with invalid date format."""
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
# Mock the service to raise ValueError for invalid date
|
||||
mock_export_feedbacks = mock.Mock(side_effect=ValueError("Invalid date format"))
|
||||
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
|
||||
|
||||
monkeypatch.setattr(message_api, "current_user", mock_account)
|
||||
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
|
||||
headers=auth_header,
|
||||
query_string={"start_date": "invalid-date", "format": "csv"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
response_json = response.get_json()
|
||||
assert "Parameter validation error" in response_json["error"]
|
||||
|
||||
def test_feedback_export_server_error(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
|
||||
):
|
||||
"""Test feedback export with server error."""
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
# Mock the service to raise an exception
|
||||
mock_export_feedbacks = mock.Mock(side_effect=Exception("Database connection failed"))
|
||||
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
|
||||
|
||||
monkeypatch.setattr(message_api, "current_user", mock_account)
|
||||
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/feedbacks/export",
|
||||
headers=auth_header,
|
||||
query_string={"format": "csv"},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
@@ -0,0 +1,139 @@
|
||||
"""Integration tests for ModelConfigResource permission verification."""
|
||||
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from controllers.console.app import model_config as model_config_api
|
||||
from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import AppMode
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
class TestModelConfigResourcePermissions:
|
||||
"""Test permission verification for ModelConfigResource endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model for testing."""
|
||||
app = App()
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.app_model_config_id = str(uuid.uuid4())
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Create a mock Account for testing."""
|
||||
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account.id = str(uuid.uuid4())
|
||||
account.last_active_at = naive_utc_now()
|
||||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
|
||||
# Create mock tenant
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = str(uuid.uuid4())
|
||||
|
||||
mock_session_instance = mock.Mock()
|
||||
|
||||
mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER)
|
||||
monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join))
|
||||
|
||||
mock_scalars_result = mock.Mock()
|
||||
mock_scalars_result.one.return_value = tenant
|
||||
monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result))
|
||||
|
||||
mock_session_context = mock.Mock()
|
||||
mock_session_context.__enter__.return_value = mock_session_instance
|
||||
monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context)
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "status"),
|
||||
[
|
||||
(TenantAccountRole.OWNER, 200),
|
||||
(TenantAccountRole.ADMIN, 200),
|
||||
(TenantAccountRole.EDITOR, 200),
|
||||
(TenantAccountRole.NORMAL, 403),
|
||||
(TenantAccountRole.DATASET_OPERATOR, 403),
|
||||
],
|
||||
)
|
||||
def test_post_with_owner_role_succeeds(
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
status: int,
|
||||
):
|
"""Test that the model-config POST endpoint enforces role-based permissions."""
# Set the user role from the parametrized case
mock_account.role = role
|
||||
|
||||
# Mock app loading
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
# Mock current user
|
||||
monkeypatch.setattr(model_config_api, "current_user", mock_account)
|
||||
|
||||
# Mock AccountService.load_user to prevent authentication issues
|
||||
from services.account_service import AccountService
|
||||
|
||||
mock_load_user = mock.Mock(return_value=mock_account)
|
||||
monkeypatch.setattr(AccountService, "load_user", mock_load_user)
|
||||
|
||||
mock_validate_config = mock.Mock(
|
||||
return_value={
|
||||
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}},
|
||||
"pre_prompt": "You are a helpful assistant.",
|
||||
"user_input_form": [],
|
||||
"dataset_query_variable": "",
|
||||
"agent_mode": {"enabled": False, "tools": []},
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(AppModelConfigService, "validate_configuration", mock_validate_config)
|
||||
|
||||
# Mock database operations
|
||||
mock_db_session = mock.Mock()
|
||||
mock_db_session.add = mock.Mock()
|
||||
mock_db_session.flush = mock.Mock()
|
||||
mock_db_session.commit = mock.Mock()
|
||||
monkeypatch.setattr(model_config_api.db, "session", mock_db_session)
|
||||
|
||||
# Mock app_model_config_was_updated event
|
||||
mock_event = mock.Mock()
|
||||
mock_event.send = mock.Mock()
|
||||
monkeypatch.setattr(model_config_api, "app_model_config_was_updated", mock_event)
|
||||
|
||||
response = test_client.post(
|
||||
f"/console/api/apps/{mock_app_model.id}/model-config",
|
||||
headers=auth_header,
|
||||
json={
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4",
|
||||
"mode": "chat",
|
||||
"completion_params": {"temperature": 0.7, "max_tokens": 1000},
|
||||
},
|
||||
"user_input_form": [],
|
||||
"dataset_query_variable": "",
|
||||
"pre_prompt": "You are a helpful assistant.",
|
||||
"agent_mode": {"enabled": False, "tools": []},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status
|
||||
@@ -0,0 +1,47 @@
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
from controllers.console.app import workflow_draft_variable as draft_variable_api
|
||||
from controllers.console.app import wraps
|
||||
from factories.variable_factory import build_segment
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
|
||||
def _get_mock_srv_class() -> type[WorkflowDraftVariableService]:
|
||||
return mock.create_autospec(WorkflowDraftVariableService)
|
||||
|
||||
|
||||
class TestWorkflowDraftNodeVariableListApi:
|
||||
def test_get(self, test_client, auth_header, monkeypatch):
|
||||
srv_class = _get_mock_srv_class()
|
||||
mock_app_model: App = App()
|
||||
mock_app_model.id = str(uuid.uuid4())
|
||||
test_node_id = "test_node_id"
|
||||
mock_app_model.mode = AppMode.ADVANCED_CHAT
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
|
||||
monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id="test_app_1",
|
||||
node_id="test_node_1",
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True)
|
||||
srv_class.return_value = srv_instance
|
||||
srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1])
|
||||
|
||||
response = test_client.get(
|
||||
f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables",
|
||||
headers=auth_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_dict = response.json
|
||||
assert isinstance(response_dict, dict)
|
||||
assert "items" in response_dict
|
||||
assert len(response_dict["items"]) == 1
|
||||
@@ -0,0 +1,368 @@
|
||||
import unittest
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from models import ToolFile, UploadFile
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestStorageKeyLoader(unittest.TestCase):
|
||||
"""
|
||||
Integration tests for StorageKeyLoader class.
|
||||
|
||||
Tests the batched loading of storage keys from the database for files
|
||||
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data before each test method."""
|
||||
self.session = db.session()
|
||||
self.tenant_id = str(uuid4())
|
||||
self.user_id = str(uuid4())
|
||||
self.conversation_id = str(uuid4())
|
||||
|
||||
# Create test data that will be cleaned up after each test
|
||||
self.test_upload_files = []
|
||||
self.test_tool_files = []
|
||||
|
||||
# Create StorageKeyLoader instance
|
||||
self.loader = StorageKeyLoader(self.session, self.tenant_id)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test data after each test method."""
|
||||
self.session.rollback()
|
||||
|
||||
def _create_upload_file(
|
||||
self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None
|
||||
) -> UploadFile:
|
||||
"""Helper method to create an UploadFile record for testing."""
|
||||
if file_id is None:
|
||||
file_id = str(uuid4())
|
||||
if storage_key is None:
|
||||
storage_key = f"test_storage_key_{uuid4()}"
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
key=storage_key,
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
upload_file.id = file_id
|
||||
|
||||
self.session.add(upload_file)
|
||||
self.session.flush()
|
||||
self.test_upload_files.append(upload_file)
|
||||
|
||||
return upload_file
|
||||
|
||||
def _create_tool_file(
|
||||
self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None
|
||||
) -> ToolFile:
|
||||
"""Helper method to create a ToolFile record for testing."""
|
||||
if file_id is None:
|
||||
file_id = str(uuid4())
|
||||
if file_key is None:
|
||||
file_key = f"test_file_key_{uuid4()}"
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=self.user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=self.conversation_id,
|
||||
file_key=file_key,
|
||||
mimetype="text/plain",
|
||||
original_url="http://example.com/file.txt",
|
||||
name="test_tool_file.txt",
|
||||
size=2048,
|
||||
)
|
||||
tool_file.id = file_id
|
||||
self.session.add(tool_file)
|
||||
self.session.flush()
|
||||
self.test_tool_files.append(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File:
|
||||
"""Helper method to create a File object for testing."""
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
|
||||
file_related_id = None
|
||||
remote_url = None
|
||||
|
||||
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
|
||||
file_related_id = related_id
|
||||
elif transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
remote_url = "https://example.com/test_file.txt"
|
||||
file_related_id = related_id
|
||||
|
||||
return File(
|
||||
id=str(uuid4()), # Generate new UUID for File.id
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=transfer_method,
|
||||
related_id=file_related_id,
|
||||
remote_url=remote_url,
|
||||
filename="test_file.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=1024,
|
||||
storage_key="initial_key",
|
||||
)
|
||||
|
||||
def test_load_storage_keys_local_file(self):
|
||||
"""Test loading storage keys for LOCAL_FILE transfer method."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_remote_url(self):
|
||||
"""Test loading storage keys for REMOTE_URL transfer method."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_tool_file(self):
|
||||
"""Test loading storage keys for TOOL_FILE transfer method."""
|
||||
# Create test data
|
||||
tool_file = self._create_tool_file()
|
||||
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == tool_file.file_key
|
||||
|
||||
def test_load_storage_keys_mixed_methods(self):
|
||||
"""Test batch loading with mixed transfer methods."""
|
||||
# Create test data for different transfer methods
|
||||
upload_file1 = self._create_upload_file()
|
||||
upload_file2 = self._create_upload_file()
|
||||
tool_file = self._create_tool_file()
|
||||
|
||||
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
files = [file1, file2, file3]
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys(files)
|
||||
|
||||
# Verify all storage keys were loaded correctly
|
||||
assert file1._storage_key == upload_file1.key
|
||||
assert file2._storage_key == upload_file2.key
|
||||
assert file3._storage_key == tool_file.file_key
|
||||
|
||||
def test_load_storage_keys_empty_list(self):
|
||||
"""Test with empty file list."""
|
||||
# Should not raise any exceptions
|
||||
self.loader.load_storage_keys([])
|
||||
|
||||
def test_load_storage_keys_tenant_mismatch(self):
|
||||
"""Test tenant_id validation."""
|
||||
# Create file with different tenant_id
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(
|
||||
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
|
||||
)
|
||||
|
||||
# Should raise ValueError for tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
def test_load_storage_keys_missing_file_id(self):
|
||||
"""Test with None file.related_id."""
|
||||
# Create a file with valid parameters first, then manually set related_id to None
|
||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file.related_id = None
|
||||
|
||||
# Should raise ValueError for None file related_id
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
assert str(context.value) == "file id should not be None."
|
||||
|
||||
def test_load_storage_keys_nonexistent_upload_file_records(self):
|
||||
"""Test with missing UploadFile database records."""
|
||||
# Create file with non-existent upload file id
|
||||
non_existent_id = str(uuid4())
|
||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Should raise ValueError for missing record
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_nonexistent_tool_file_records(self):
|
||||
"""Test with missing ToolFile database records."""
|
||||
# Create file with non-existent tool file id
|
||||
non_existent_id = str(uuid4())
|
||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# Should raise ValueError for missing record
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_invalid_uuid(self):
|
||||
"""Test with invalid UUID format."""
|
||||
# Create a file with valid parameters first, then manually set invalid related_id
|
||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file.related_id = "invalid-uuid-format"
|
||||
|
||||
# Should raise ValueError for invalid UUID
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_batch_efficiency(self):
|
||||
"""Test batched operations use efficient queries."""
|
||||
# Create multiple files of different types
|
||||
upload_files = [self._create_upload_file() for _ in range(3)]
|
||||
tool_files = [self._create_tool_file() for _ in range(2)]
|
||||
|
||||
files = []
|
||||
files.extend(
|
||||
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
|
||||
)
|
||||
files.extend(
|
||||
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
|
||||
)
|
||||
|
||||
# Mock the session to count queries
|
||||
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
|
||||
self.loader.load_storage_keys(files)
|
||||
|
||||
# Should make exactly 2 queries (one for upload_files, one for tool_files)
|
||||
assert mock_scalars.call_count == 2
|
||||
|
||||
# Verify all storage keys were loaded correctly
|
||||
for i, file in enumerate(files[:3]):
|
||||
assert file._storage_key == upload_files[i].key
|
||||
for i, file in enumerate(files[3:]):
|
||||
assert file._storage_key == tool_files[i].file_key
|
||||
|
||||
def test_load_storage_keys_tenant_isolation(self):
|
||||
"""Test that tenant isolation works correctly."""
|
||||
# Create files for different tenants
|
||||
other_tenant_id = str(uuid4())
|
||||
|
||||
# Create upload file for current tenant
|
||||
upload_file_current = self._create_upload_file()
|
||||
file_current = self._create_file(
|
||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
)
|
||||
|
||||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type="local",
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
upload_file_other.id = str(uuid4())
|
||||
self.session.add(upload_file_other)
|
||||
self.session.flush()
|
||||
|
||||
# Create file for other tenant but try to load with current tenant's loader
|
||||
file_other = self._create_file(
|
||||
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
||||
)
|
||||
|
||||
# Should raise ValueError due to tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file_other])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
# Current tenant's file should still work
|
||||
self.loader.load_storage_keys([file_current])
|
||||
assert file_current._storage_key == upload_file_current.key
|
||||
|
||||
def test_load_storage_keys_mixed_tenant_batch(self):
|
||||
"""Test batch with mixed tenant files (should fail on first mismatch)."""
|
||||
# Create files for current tenant
|
||||
upload_file_current = self._create_upload_file()
|
||||
file_current = self._create_file(
|
||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
)
|
||||
|
||||
# Create file for different tenant
|
||||
other_tenant_id = str(uuid4())
|
||||
file_other = self._create_file(
|
||||
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
||||
)
|
||||
|
||||
# Should raise ValueError on tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file_current, file_other])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
def test_load_storage_keys_duplicate_file_ids(self):
|
||||
"""Test handling of duplicate file IDs in the batch."""
|
||||
# Create upload file
|
||||
upload_file = self._create_upload_file()
|
||||
|
||||
# Create two File objects with same related_id
|
||||
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Should handle duplicates gracefully
|
||||
self.loader.load_storage_keys([file1, file2])
|
||||
|
||||
# Both files should have the same storage key
|
||||
assert file1._storage_key == upload_file.key
|
||||
assert file2._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_session_isolation(self):
|
||||
"""Test that the loader uses the provided session correctly."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Create loader with different session (same underlying connection)
|
||||
|
||||
with Session(bind=db.engine) as other_session:
|
||||
other_loader = StorageKeyLoader(other_session, self.tenant_id)
|
||||
with pytest.raises(ValueError):
|
||||
other_loader.load_storage_keys([file])
|
||||
@@ -0,0 +1,41 @@
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass
|
||||
|
||||
|
||||
def mock_plugin_daemon(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> Callable[[], None]:
|
||||
"""
|
||||
Mock the plugin daemon model client (PluginModelClient).
|
||||
|
||||
:param monkeypatch: pytest monkeypatch fixture
|
||||
:return: unpatch function
|
||||
"""
|
||||
|
||||
def unpatch():
|
||||
monkeypatch.undo()
|
||||
|
||||
monkeypatch.setattr(PluginModelClient, "invoke_llm", MockModelClass.invoke_llm)
|
||||
monkeypatch.setattr(PluginModelClient, "fetch_model_providers", MockModelClass.fetch_model_providers)
|
||||
monkeypatch.setattr(PluginModelClient, "get_model_schema", MockModelClass.get_model_schema)
|
||||
|
||||
return unpatch
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_model_mock(monkeypatch: pytest.MonkeyPatch):
|
||||
if MOCK:
|
||||
unpatch = mock_plugin_daemon(monkeypatch)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
unpatch()
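# Illustrative usage (a hypothetical test; argument values are placeholders).
# With MOCK_SWITCH=true the fixture patches PluginModelClient, so invoke_llm
# yields the mocked stream instead of calling the real plugin daemon.
def test_invoke_llm_with_mock(setup_model_mock):
    client = PluginModelClient()
    chunks = list(
        client.invoke_llm(
            tenant_id="tenant-id",
            user_id="user-id",
            plugin_id="langgenius/openai",
            provider="openai",
            model="gpt-3.5-turbo",
            credentials={},
            prompt_messages=[],
        )
    )
    # the mock emits one chunk per character of its canned text plus a final chunk
    assert len(chunks) > 1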
|
||||
@@ -0,0 +1,247 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from collections.abc import Generator, Sequence
|
||||
from decimal import Decimal
|
||||
from json import dumps
|
||||
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
)
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class MockModelClass(PluginModelClient):
|
||||
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
|
||||
"""
|
||||
Fetch model providers for the given tenant.
|
||||
"""
|
||||
return [
|
||||
PluginModelProviderEntity(
|
||||
id=uuid.uuid4().hex,
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
provider="openai",
|
||||
tenant_id=tenant_id,
|
||||
plugin_unique_identifier="langgenius/openai/openai",
|
||||
plugin_id="langgenius/openai",
|
||||
declaration=ProviderEntity(
|
||||
provider="openai",
|
||||
label=I18nObject(
|
||||
en_US="OpenAI",
|
||||
zh_Hans="OpenAI",
|
||||
),
|
||||
description=I18nObject(
|
||||
en_US="OpenAI",
|
||||
zh_Hans="OpenAI",
|
||||
),
|
||||
icon_small=I18nObject(
|
||||
en_US="https://example.com/icon_small.png",
|
||||
zh_Hans="https://example.com/icon_small.png",
|
||||
),
|
||||
icon_large=I18nObject(
|
||||
en_US="https://example.com/icon_large.png",
|
||||
zh_Hans="https://example.com/icon_large.png",
|
||||
),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
models=[
|
||||
AIModelEntity(
|
||||
model="gpt-3.5-turbo",
|
||||
label=I18nObject(
|
||||
en_US="gpt-3.5-turbo",
|
||||
zh_Hans="gpt-3.5-turbo",
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL],
|
||||
),
|
||||
AIModelEntity(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
label=I18nObject(
|
||||
en_US="gpt-3.5-turbo-instruct",
|
||||
zh_Hans="gpt-3.5-turbo-instruct",
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: LLMMode.COMPLETION,
|
||||
},
|
||||
features=[],
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
def get_model_schema(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US="OpenAI",
|
||||
zh_Hans="OpenAI",
|
||||
),
|
||||
model_type=ModelType(model_type),
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties={},
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL] if model == "gpt-3.5-turbo" else [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_function_call(
|
||||
tools: list[PromptMessageTool] | None,
|
||||
) -> AssistantPromptMessage.ToolCall | None:
|
||||
if not tools:
|
||||
return None
|
||||
function: PromptMessageTool = tools[0]
|
||||
function_name = function.name
|
||||
function_parameters = function.parameters
|
||||
function_parameters_type = function_parameters["type"]
|
||||
if function_parameters_type != "object":
|
||||
return None
|
||||
function_parameters_properties = function_parameters["properties"]
|
||||
function_parameters_required = function_parameters["required"]
|
||||
parameters = {}
|
||||
for parameter_name, parameter in function_parameters_properties.items():
|
||||
if parameter_name not in function_parameters_required:
|
||||
continue
|
||||
parameter_type = parameter["type"]
|
||||
if parameter_type == "string":
|
||||
if "enum" in parameter:
|
||||
if len(parameter["enum"]) == 0:
|
||||
continue
|
||||
parameters[parameter_name] = parameter["enum"][0]
|
||||
else:
|
||||
parameters[parameter_name] = "kawaii"
|
||||
elif parameter_type == "integer":
|
||||
parameters[parameter_name] = 114514
|
||||
elif parameter_type == "number":
|
||||
parameters[parameter_name] = 1919810.0
|
||||
elif parameter_type == "boolean":
|
||||
parameters[parameter_name] = True
|
||||
|
||||
return AssistantPromptMessage.ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=function_name,
|
||||
arguments=dumps(parameters),
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_chat_create_sync(
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> LLMResult:
|
||||
tool_call = MockModelClass.generate_function_call(tools=tools)
|
||||
|
||||
return LLMResult(
|
||||
id=str(uuid.uuid4()),
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content="elaina", tool_calls=[tool_call] if tool_call else []),
|
||||
usage=LLMUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=1,
|
||||
total_tokens=3,
|
||||
prompt_unit_price=Decimal(0.0001),
|
||||
completion_unit_price=Decimal(0.0002),
|
||||
prompt_price_unit=Decimal(1),
|
||||
prompt_price=Decimal(0.0001),
|
||||
completion_price_unit=Decimal(1),
|
||||
completion_price=Decimal(0.0002),
|
||||
total_price=Decimal(0.0003),
|
||||
currency="USD",
|
||||
latency=0.001,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_chat_create_stream(
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
tool_call = MockModelClass.generate_function_call(tools=tools)
|
||||
|
||||
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
|
||||
for i in range(0, len(full_text) + 1):
|
||||
if i == len(full_text):
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content="",
|
||||
tool_calls=[tool_call] if tool_call else [],
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=full_text[i],
|
||||
tool_calls=[tool_call] if tool_call else [],
|
||||
),
|
||||
usage=LLMUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=17,
|
||||
total_tokens=19,
|
||||
prompt_unit_price=Decimal(0.0001),
|
||||
completion_unit_price=Decimal(0.0002),
|
||||
prompt_price_unit=Decimal(1),
|
||||
prompt_price=Decimal(0.0001),
|
||||
completion_price_unit=Decimal(1),
|
||||
completion_price=Decimal(0.0002),
|
||||
total_price=Decimal(0.0003),
|
||||
currency="USD",
|
||||
latency=0.001,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def invoke_llm(
|
||||
self: PluginModelClient,
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
):
|
||||
return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools)
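# A small, hypothetical usage sketch of the mock above (tool name and values are
# illustrative; PromptMessageTool and MockModelClass are the classes defined or
# imported in this module): generate_function_call fills required parameters with
# the first enum value or canned literals, and the streamed mock emits one chunk
# per character of full_text followed by a final chunk carrying the tool call.
def _demo_mocked_stream():
    example_tool = PromptMessageTool(
        name="get_weather",
        description="Get the current weather",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string", "enum": ["Tokyo"]}},
            "required": ["city"],
        },
    )
    chunks = list(
        MockModelClass.mocked_chat_create_stream(
            model="gpt-3.5-turbo", prompt_messages=[], tools=[example_tool]
        )
    )
    # the final chunk has empty content and carries the generated tool call
    assert chunks[-1].delta.message.content == ""
    assert chunks[-1].delta.message.tool_calls[0].function.arguments == '{"city": "Tokyo"}'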
|
||||
62
dify/api/tests/integration_tests/plugin/__mock/http.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity, ToolProviderIdentity
|
||||
|
||||
|
||||
class MockedHttp:
|
||||
@classmethod
|
||||
def list_tools(cls) -> list[ToolProviderEntity]:
|
||||
return [
|
||||
ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author="Yeuoly",
|
||||
name="Yeuoly",
|
||||
description=I18nObject(en_US="Yeuoly"),
|
||||
icon="ssss.svg",
|
||||
label=I18nObject(en_US="Yeuoly"),
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def requests_request(
|
||||
cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Mocked httpx.request
|
||||
"""
|
||||
request = httpx.Request(method, url)
|
||||
if url.endswith("/tools"):
|
||||
content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
|
||||
code=0, message="success", data=cls.list_tools()
|
||||
).model_dump_json()
|
||||
else:
|
||||
raise ValueError("")
|
||||
|
||||
response = httpx.Response(status_code=200)
|
||||
response.request = request
|
||||
response._content = content.encode("utf-8")
|
||||
return response
|
||||
|
||||
|
||||
MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
|
||||
if MOCK_SWITCH:
|
||||
monkeypatch.setattr(httpx, "request", MockedHttp.requests_request)
|
||||
|
||||
def unpatch():
|
||||
monkeypatch.undo()
|
||||
|
||||
yield
|
||||
|
||||
if MOCK_SWITCH:
|
||||
unpatch()
|
||||
@@ -0,0 +1,8 @@
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from tests.integration_tests.plugin.__mock.http import setup_http_mock
|
||||
|
||||
|
||||
def test_fetch_all_plugin_tools(setup_http_mock):
|
||||
manager = PluginToolManager()
|
||||
tools = manager.fetch_tool_providers(tenant_id="test-tenant")
|
||||
assert len(tools) >= 1
|
||||
@@ -0,0 +1,779 @@
|
||||
import json
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables.segments import StringSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.variables import StringVariable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes import NodeType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from factories.variable_factory import build_segment
|
||||
from libs import datetime_utils
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, WorkflowNodeExecutionModel
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVariableSaver,
|
||||
DraftVarLoader,
|
||||
VariableResetError,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
_test_app_id: str
|
||||
_session: Session
|
||||
_node1_id = "test_node_1"
|
||||
_node2_id = "test_node_2"
|
||||
_node_exec_id = str(uuid.uuid4())
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._session: Session = db.session()
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="conv_var",
|
||||
value=build_segment("conv_value"),
|
||||
)
|
||||
node2_vars = [
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node2_id,
|
||||
name="int_var",
|
||||
value=build_segment(1),
|
||||
visible=False,
|
||||
node_execution_id=self._node_exec_id,
|
||||
),
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node2_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
visible=True,
|
||||
node_execution_id=self._node_exec_id,
|
||||
),
|
||||
]
|
||||
node1_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node1_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
visible=True,
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
_variables = list(node2_vars)
|
||||
_variables.extend(
|
||||
[
|
||||
node1_var,
|
||||
sys_var,
|
||||
conv_var,
|
||||
]
|
||||
)
|
||||
|
||||
db.session.add_all(_variables)
|
||||
db.session.flush()
|
||||
self._variable_ids = [v.id for v in _variables]
|
||||
self._node1_str_var_id = node1_var.id
|
||||
self._sys_var_id = sys_var.id
|
||||
self._conv_var_id = conv_var.id
|
||||
self._node2_var_ids = [v.id for v in node2_vars]
|
||||
|
||||
def _get_test_srv(self) -> WorkflowDraftVariableService:
|
||||
return WorkflowDraftVariableService(session=self._session)
|
||||
|
||||
def tearDown(self):
|
||||
self._session.rollback()
|
||||
|
||||
def test_list_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2)
|
||||
assert var_list.total == 5
|
||||
assert len(var_list.variables) == 2
|
||||
page1_var_ids = {v.id for v in var_list.variables}
|
||||
assert page1_var_ids.issubset(self._variable_ids)
|
||||
|
||||
var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2)
|
||||
assert var_list_2.total is None
|
||||
assert len(var_list_2.variables) == 2
|
||||
page2_var_ids = {v.id for v in var_list_2.variables}
|
||||
assert page2_var_ids.isdisjoint(page1_var_ids)
|
||||
assert page2_var_ids.issubset(self._variable_ids)
|
||||
|
||||
def test_get_node_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
|
||||
assert node_var is not None
|
||||
assert node_var.id == self._node1_str_var_id
|
||||
assert node_var.name == "str_var"
|
||||
assert node_var.get_value() == build_segment("str_value")
|
||||
|
||||
def test_get_system_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
sys_var = srv.get_system_variable(self._test_app_id, "sys_var")
|
||||
assert sys_var is not None
|
||||
assert sys_var.id == self._sys_var_id
|
||||
assert sys_var.name == "sys_var"
|
||||
assert sys_var.get_value() == build_segment("sys_value")
|
||||
|
||||
def test_get_conversation_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var")
|
||||
assert conv_var is not None
|
||||
assert conv_var.id == self._conv_var_id
|
||||
assert conv_var.name == "conv_var"
|
||||
assert conv_var.get_value() == build_segment("conv_value")
|
||||
|
||||
def test_delete_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
srv.delete_node_variables(self._test_app_id, self._node2_id)
|
||||
node2_var_count = (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == self._test_app_id,
|
||||
WorkflowDraftVariable.node_id == self._node2_id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
assert node2_var_count == 0
|
||||
|
||||
def test_delete_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
node_1_var = (
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
|
||||
)
|
||||
srv.delete_variable(node_1_var)
|
||||
exists = bool(
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
|
||||
)
|
||||
assert exists is False
|
||||
|
||||
def test__list_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
|
||||
assert len(node_vars.variables) == 2
|
||||
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)
|
||||
|
||||
def test_get_draft_variables_by_selectors(self):
|
||||
srv = self._get_test_srv()
|
||||
selectors = [
|
||||
[self._node1_id, "str_var"],
|
||||
[self._node2_id, "str_var"],
|
||||
[self._node2_id, "int_var"],
|
||||
]
|
||||
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
|
||||
assert len(variables) == 3
|
||||
assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestDraftVariableLoader(unittest.TestCase):
|
||||
_test_app_id: str
|
||||
_test_tenant_id: str
|
||||
|
||||
_node1_id = "test_loader_node_1"
|
||||
_node_exec_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._test_tenant_id = str(uuid.uuid4())
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="conv_var",
|
||||
value=build_segment("conv_value"),
|
||||
)
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node1_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
visible=True,
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
_variables = [
|
||||
node_var,
|
||||
sys_var,
|
||||
conv_var,
|
||||
]
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
session.add_all(_variables)
|
||||
session.flush()
|
||||
session.commit()
|
||||
self._variable_ids = [v.id for v in _variables]
|
||||
self._node_var_id = node_var.id
|
||||
self._sys_var_id = sys_var.id
|
||||
self._conv_var_id = conv_var.id
|
||||
|
||||
def tearDown(self):
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
def test_variable_loader_with_empty_selector(self):
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
variables = var_loader.load_variables([])
|
||||
assert len(variables) == 0
|
||||
|
||||
def test_variable_loader_with_non_empty_selector(self):
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
variables = var_loader.load_variables(
|
||||
[
|
||||
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],
|
||||
[CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
|
||||
[self._node1_id, "str_var"],
|
||||
]
|
||||
)
|
||||
assert len(variables) == 3
|
||||
conv_var = next(v for v in variables if v.selector[0] == CONVERSATION_VARIABLE_NODE_ID)
|
||||
assert conv_var.id == self._conv_var_id
|
||||
sys_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
|
||||
assert sys_var.id == self._sys_var_id
|
||||
node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
|
||||
assert node1_var.id == self._node_var_id
|
||||
|
||||
@pytest.mark.usefixtures("setup_account")
|
||||
def test_load_offloaded_variable_string_type_integration(self, setup_account):
|
||||
"""Test _load_offloaded_variable with string type using DraftVariableSaver for data creation."""
|
||||
|
||||
# Create a large string that will be offloaded
|
||||
test_content = "x" * 15000 # Create a string larger than LARGE_VARIABLE_THRESHOLD (10KB)
|
||||
large_string_segment = StringSegment(value=test_content)
|
||||
|
||||
node_execution_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Use DraftVariableSaver to create offloaded variable (this mimics production)
|
||||
saver = DraftVariableSaver(
|
||||
session=session,
|
||||
app_id=self._test_app_id,
|
||||
node_id="test_offload_node",
|
||||
node_type=NodeType.LLM, # Use a real node type
|
||||
node_execution_id=node_execution_id,
|
||||
user=setup_account,
|
||||
)
|
||||
|
||||
# Save the variable - this will trigger offloading due to large size
|
||||
saver.save(outputs={"offloaded_string_var": large_string_segment})
|
||||
session.commit()
|
||||
|
||||
# Now test loading using DraftVarLoader
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
|
||||
# Load the variable using the standard workflow
|
||||
variables = var_loader.load_variables([["test_offload_node", "offloaded_string_var"]])
|
||||
|
||||
# Verify results
|
||||
assert len(variables) == 1
|
||||
loaded_variable = variables[0]
|
||||
assert loaded_variable.name == "offloaded_string_var"
|
||||
assert loaded_variable.selector == ["test_offload_node", "offloaded_string_var"]
|
||||
assert isinstance(loaded_variable.value, StringSegment)
|
||||
assert loaded_variable.value.value == test_content
|
||||
|
||||
finally:
|
||||
# Clean up - delete all draft variables for this app
|
||||
with Session(bind=db.engine) as session:
|
||||
service = WorkflowDraftVariableService(session)
|
||||
service.delete_workflow_variables(self._test_app_id)
|
||||
session.commit()
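# A hedged sketch of the offloading rule the two offload tests in this class rely
# on (the threshold value comes from the comment above; the real DraftVariableSaver
# may serialize and measure differently): values whose serialized size exceeds the
# threshold are written to storage and referenced through an UploadFile plus a
# WorkflowDraftVariableFile row, while the draft variable row keeps only a small
# truncated placeholder.
def _should_offload_sketch(segment, threshold_bytes: int = 10 * 1024) -> bool:
    import json

    serialized = json.dumps(segment.value, ensure_ascii=False, default=str)
    return len(serialized.encode("utf-8")) > threshold_bytes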
|
||||
|
||||
def test_load_offloaded_variable_object_type_integration(self):
|
||||
"""Test _load_offloaded_variable with object type using real storage and service."""
|
||||
|
||||
# Create a test object
|
||||
test_object = {"key1": "value1", "key2": 42, "nested": {"inner": "data"}}
|
||||
test_json = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
|
||||
content_bytes = test_json.encode()
|
||||
|
||||
# Create an upload file record
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._test_tenant_id,
|
||||
storage_type="local",
|
||||
key=f"test_offload_{uuid.uuid4()}.json",
|
||||
name="test_offload.json",
|
||||
size=len(content_bytes),
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=datetime_utils.naive_utc_now(),
|
||||
used=True,
|
||||
used_by=str(uuid.uuid4()),
|
||||
used_at=datetime_utils.naive_utc_now(),
|
||||
)
|
||||
|
||||
# Store the content in storage
|
||||
storage.save(upload_file.key, content_bytes)
|
||||
|
||||
# Create a variable file record
|
||||
variable_file = WorkflowDraftVariableFile(
|
||||
upload_file_id=upload_file.id,
|
||||
value_type=SegmentType.OBJECT,
|
||||
tenant_id=self._test_tenant_id,
|
||||
app_id=self._test_app_id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
size=len(content_bytes),
|
||||
created_at=datetime_utils.naive_utc_now(),
|
||||
)
|
||||
|
||||
try:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Add upload file and variable file first to get their IDs
|
||||
session.add_all([upload_file, variable_file])
|
||||
session.flush() # This generates the IDs
|
||||
|
||||
# Now create the offloaded draft variable with the correct file_id
|
||||
offloaded_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id="test_offload_node",
|
||||
name="offloaded_object_var",
|
||||
value=build_segment({"truncated": True}),
|
||||
visible=True,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
offloaded_var.file_id = variable_file.id
|
||||
|
||||
session.add(offloaded_var)
|
||||
session.flush()
|
||||
session.commit()
|
||||
|
||||
# Use the service method that properly preloads relationships
|
||||
service = WorkflowDraftVariableService(session)
|
||||
draft_vars = service.get_draft_variables_by_selectors(
|
||||
self._test_app_id, [["test_offload_node", "offloaded_object_var"]]
|
||||
)
|
||||
|
||||
assert len(draft_vars) == 1
|
||||
loaded_var = draft_vars[0]
|
||||
assert loaded_var.is_truncated()
|
||||
|
||||
# Create DraftVarLoader and test loading
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
|
||||
# Test the _load_offloaded_variable method
|
||||
selector_tuple, variable = var_loader._load_offloaded_variable(loaded_var)
|
||||
|
||||
# Verify the results
|
||||
assert selector_tuple == ("test_offload_node", "offloaded_object_var")
|
||||
assert variable.id == loaded_var.id
|
||||
assert variable.name == "offloaded_object_var"
|
||||
assert variable.value.value == test_object
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
with Session(bind=db.engine) as session:
|
||||
# Query and delete by ID to ensure they're tracked in this session
|
||||
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
|
||||
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
|
||||
session.query(UploadFile).filter_by(id=upload_file.id).delete()
|
||||
session.commit()
|
||||
# Clean up storage
|
||||
try:
|
||||
storage.delete(upload_file.key)
|
||||
except Exception:
|
||||
pass # Ignore cleanup failures
|
||||
|
||||
def test_load_variables_with_offloaded_variables_integration(self):
|
||||
"""Test load_variables method with mix of regular and offloaded variables using real storage."""
|
||||
# Create a regular variable (already exists from setUp)
|
||||
# Create offloaded variable content
|
||||
test_content = "This is offloaded content for integration test"
|
||||
content_bytes = test_content.encode()
|
||||
|
||||
# Create upload file record
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._test_tenant_id,
|
||||
storage_type="local",
|
||||
key=f"test_integration_{uuid.uuid4()}.txt",
|
||||
name="test_integration.txt",
|
||||
size=len(content_bytes),
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=datetime_utils.naive_utc_now(),
|
||||
used=True,
|
||||
used_by=str(uuid.uuid4()),
|
||||
used_at=datetime_utils.naive_utc_now(),
|
||||
)
|
||||
|
||||
# Store the content
|
||||
storage.save(upload_file.key, content_bytes)
|
||||
|
||||
# Create variable file
|
||||
variable_file = WorkflowDraftVariableFile(
|
||||
upload_file_id=upload_file.id,
|
||||
value_type=SegmentType.STRING,
|
||||
tenant_id=self._test_tenant_id,
|
||||
app_id=self._test_app_id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
size=len(content_bytes),
|
||||
created_at=datetime_utils.naive_utc_now(),
|
||||
)
|
||||
|
||||
try:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Add upload file and variable file first to get their IDs
|
||||
session.add_all([upload_file, variable_file])
|
||||
session.flush() # This generates the IDs
|
||||
|
||||
# Now create the offloaded draft variable with the correct file_id
|
||||
offloaded_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id="test_integration_node",
|
||||
name="offloaded_integration_var",
|
||||
value=build_segment("truncated"),
|
||||
visible=True,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
offloaded_var.file_id = variable_file.id
|
||||
|
||||
session.add(offloaded_var)
|
||||
session.flush()
|
||||
session.commit()
|
||||
|
||||
# Test load_variables with both regular and offloaded variables
|
||||
# This method should handle the relationship preloading internally
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
|
||||
variables = var_loader.load_variables(
|
||||
[
|
||||
[SYSTEM_VARIABLE_NODE_ID, "sys_var"], # Regular variable from setUp
|
||||
["test_integration_node", "offloaded_integration_var"], # Offloaded variable
|
||||
]
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert len(variables) == 2
|
||||
|
||||
# Find regular variable
|
||||
regular_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
|
||||
assert regular_var.id == self._sys_var_id
|
||||
assert regular_var.value == "sys_value"
|
||||
|
||||
# Find offloaded variable
|
||||
offloaded_loaded_var = next(v for v in variables if v.selector[0] == "test_integration_node")
|
||||
assert offloaded_loaded_var.id == offloaded_var.id
|
||||
assert offloaded_loaded_var.value == test_content
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
with Session(bind=db.engine) as session:
|
||||
# Query and delete by ID to ensure they're tracked in this session
|
||||
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
|
||||
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
|
||||
session.query(UploadFile).filter_by(id=upload_file.id).delete()
|
||||
session.commit()
|
||||
# Clean up storage
|
||||
try:
|
||||
storage.delete(upload_file.key)
|
||||
except Exception:
|
||||
pass # Ignore cleanup failures
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx")
|
||||
class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
"""Integration tests for reset_variable functionality using real database"""
|
||||
|
||||
_test_app_id: str
|
||||
_test_tenant_id: str
|
||||
_test_workflow_id: str
|
||||
_session: Session
|
||||
_node_id = "test_reset_node"
|
||||
_node_exec_id: str
|
||||
_workflow_node_exec_id: str
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._test_tenant_id = str(uuid.uuid4())
|
||||
self._test_workflow_id = str(uuid.uuid4())
|
||||
self._node_exec_id = str(uuid.uuid4())
|
||||
self._workflow_node_exec_id = str(uuid.uuid4())
|
||||
self._session: Session = db.session()
|
||||
|
||||
# Create a workflow node execution record with outputs
|
||||
# Note: The WorkflowNodeExecutionModel.id should match the node_execution_id in WorkflowDraftVariable
|
||||
self._workflow_node_execution = WorkflowNodeExecutionModel(
|
||||
id=self._node_exec_id, # This should match the node_execution_id in the variable
|
||||
tenant_id=self._test_tenant_id,
|
||||
app_id=self._test_app_id,
|
||||
workflow_id=self._test_workflow_id,
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
index=1,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
node_id=self._node_id,
|
||||
node_type=NodeType.LLM,
|
||||
title="Test Node",
|
||||
inputs='{"input": "test input"}',
|
||||
process_data='{"test_var": "process_value", "other_var": "other_process"}',
|
||||
outputs='{"test_var": "output_value", "other_var": "other_output"}',
|
||||
status="succeeded",
|
||||
elapsed_time=1.5,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
# Create conversation variables for the workflow
|
||||
self._conv_variables = [
|
||||
StringVariable(
|
||||
id=str(uuid.uuid4()),
|
||||
name="conv_var_1",
|
||||
description="Test conversation variable 1",
|
||||
value="default_value_1",
|
||||
),
|
||||
StringVariable(
|
||||
id=str(uuid.uuid4()),
|
||||
name="conv_var_2",
|
||||
description="Test conversation variable 2",
|
||||
value="default_value_2",
|
||||
),
|
||||
]
|
||||
|
||||
# Create test variables
|
||||
self._node_var_with_exec = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node_id,
|
||||
name="test_var",
|
||||
value=build_segment("old_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
self._node_var_with_exec.last_edited_at = datetime_utils.naive_utc_now()
|
||||
|
||||
self._node_var_without_exec = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node_id,
|
||||
name="no_exec_var",
|
||||
value=build_segment("some_value"),
|
||||
node_execution_id="temp_exec_id",
|
||||
)
|
||||
# Manually set node_execution_id to None after creation
|
||||
self._node_var_without_exec.node_execution_id = None
|
||||
|
||||
self._node_var_missing_exec = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
node_id=self._node_id,
|
||||
name="missing_exec_var",
|
||||
value=build_segment("some_value"),
|
||||
node_execution_id=str(uuid.uuid4()), # Use a valid UUID that doesn't exist in database
|
||||
)
|
||||
|
||||
self._conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="conv_var_1",
|
||||
value=build_segment("old_conv_value"),
|
||||
)
|
||||
self._conv_var.last_edited_at = datetime_utils.naive_utc_now()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as persistent_session, persistent_session.begin():
|
||||
persistent_session.add(
|
||||
self._workflow_node_execution,
|
||||
)
|
||||
|
||||
# Add all to database
|
||||
db.session.add_all(
|
||||
[
|
||||
self._node_var_with_exec,
|
||||
self._node_var_without_exec,
|
||||
self._node_var_missing_exec,
|
||||
self._conv_var,
|
||||
]
|
||||
)
|
||||
db.session.flush()
|
||||
|
||||
# Store IDs for assertions
|
||||
self._node_var_with_exec_id = self._node_var_with_exec.id
|
||||
self._node_var_without_exec_id = self._node_var_without_exec.id
|
||||
self._node_var_missing_exec_id = self._node_var_missing_exec.id
|
||||
self._conv_var_id = self._conv_var.id
|
||||
|
||||
def tearDown(self):
|
||||
self._session.rollback()
|
||||
with Session(db.engine) as session, session.begin():
|
||||
stmt = delete(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.id == self._workflow_node_execution.id
|
||||
)
|
||||
session.execute(stmt)
|
||||
|
||||
def _get_test_srv(self) -> WorkflowDraftVariableService:
|
||||
return WorkflowDraftVariableService(session=self._session)
|
||||
|
||||
def _create_mock_workflow(self) -> Workflow:
|
||||
"""Create a real workflow with conversation variables and graph"""
|
||||
conversation_vars = self._conv_variables
|
||||
|
||||
# Create a simple graph with the test node
|
||||
graph = {
|
||||
"nodes": [{"id": "test_reset_node", "type": "llm", "title": "Test Node", "data": {"type": "llm"}}],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
workflow = Workflow.new(
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
app_id=self._test_app_id,
|
||||
type="workflow",
|
||||
version="1.0",
|
||||
graph=json.dumps(graph),
|
||||
features="{}",
|
||||
created_by=str(uuid.uuid4()),
|
||||
environment_variables=[],
|
||||
conversation_variables=conversation_vars,
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
return workflow
|
||||
|
||||
def test_reset_node_variable_with_valid_execution_record(self):
|
||||
"""Test resetting a node variable with valid execution record - should restore from execution"""
|
||||
srv = self._get_test_srv()
|
||||
mock_workflow = self._create_mock_workflow()
|
||||
|
||||
# Get the variable before reset
|
||||
variable = srv.get_variable(self._node_var_with_exec_id)
|
||||
assert variable is not None
|
||||
assert variable.get_value().value == "old_value"
|
||||
assert variable.last_edited_at is not None
|
||||
|
||||
# Reset the variable
|
||||
result = srv.reset_variable(mock_workflow, variable)
|
||||
|
||||
# Should return the updated variable
|
||||
assert result is not None
|
||||
assert result.id == self._node_var_with_exec_id
|
||||
assert result.node_execution_id == self._workflow_node_execution.id
|
||||
assert result.last_edited_at is None # Should be reset to None
|
||||
|
||||
# The returned variable should have the updated value from execution record
|
||||
assert result.get_value().value == "output_value"
|
||||
|
||||
# Verify the variable was updated in database
|
||||
updated_variable = srv.get_variable(self._node_var_with_exec_id)
|
||||
assert updated_variable is not None
|
||||
# The value should be updated from the execution record's outputs
|
||||
assert updated_variable.get_value().value == "output_value"
|
||||
assert updated_variable.last_edited_at is None
|
||||
assert updated_variable.node_execution_id == self._workflow_node_execution.id
|
||||
|
||||
def test_reset_node_variable_with_no_execution_id(self):
|
||||
"""Test resetting a node variable with no execution ID - should delete variable"""
|
||||
srv = self._get_test_srv()
|
||||
mock_workflow = self._create_mock_workflow()
|
||||
|
||||
# Get the variable before reset
|
||||
variable = srv.get_variable(self._node_var_without_exec_id)
|
||||
assert variable is not None
|
||||
|
||||
# Reset the variable
|
||||
result = srv.reset_variable(mock_workflow, variable)
|
||||
|
||||
# Should return None (variable deleted)
|
||||
assert result is None
|
||||
|
||||
# Verify the variable was deleted
|
||||
deleted_variable = srv.get_variable(self._node_var_without_exec_id)
|
||||
assert deleted_variable is None
|
||||
|
||||
def test_reset_node_variable_with_missing_execution_record(self):
|
||||
"""Test resetting a node variable when execution record doesn't exist"""
|
||||
srv = self._get_test_srv()
|
||||
mock_workflow = self._create_mock_workflow()
|
||||
|
||||
# Get the variable before reset
|
||||
variable = srv.get_variable(self._node_var_missing_exec_id)
|
||||
assert variable is not None
|
||||
|
||||
# Reset the variable
|
||||
result = srv.reset_variable(mock_workflow, variable)
|
||||
|
||||
# Should return None (variable deleted)
|
||||
assert result is None
|
||||
|
||||
# Verify the variable was deleted
|
||||
deleted_variable = srv.get_variable(self._node_var_missing_exec_id)
|
||||
assert deleted_variable is None
|
||||
|
||||
def test_reset_conversation_variable(self):
|
||||
"""Test resetting a conversation variable"""
|
||||
srv = self._get_test_srv()
|
||||
mock_workflow = self._create_mock_workflow()
|
||||
|
||||
# Get the variable before reset
|
||||
variable = srv.get_variable(self._conv_var_id)
|
||||
assert variable is not None
|
||||
assert variable.get_value().value == "old_conv_value"
|
||||
assert variable.last_edited_at is not None
|
||||
|
||||
# Reset the variable
|
||||
result = srv.reset_variable(mock_workflow, variable)
|
||||
|
||||
# Should return the updated variable
|
||||
assert result is not None
|
||||
assert result.id == self._conv_var_id
|
||||
assert result.last_edited_at is None # Should be reset to None
|
||||
|
||||
# Verify the variable was updated with default value from workflow
|
||||
updated_variable = srv.get_variable(self._conv_var_id)
|
||||
assert updated_variable is not None
|
||||
# The value should be updated from the workflow's conversation variable default
|
||||
assert updated_variable.get_value().value == "default_value_1"
|
||||
assert updated_variable.last_edited_at is None
|
||||
|
||||
def test_reset_system_variable_raises_error(self):
|
||||
"""Test that resetting a system variable raises an error"""
|
||||
srv = self._get_test_srv()
|
||||
mock_workflow = self._create_mock_workflow()
|
||||
|
||||
# Create a system variable
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
db.session.add(sys_var)
|
||||
db.session.flush()
|
||||
|
||||
# Attempt to reset the system variable
|
||||
with pytest.raises(VariableResetError) as exc_info:
|
||||
srv.reset_variable(mock_workflow, sys_var)
|
||||
|
||||
assert "cannot reset system variable" in str(exc_info.value)
|
||||
assert sys_var.id in str(exc_info.value)
|
||||
@@ -0,0 +1,168 @@
|
||||
"""Integration tests for ClickZetta Volume Storage."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
|
||||
ClickZettaVolumeConfig,
|
||||
ClickZettaVolumeStorage,
|
||||
)
|
||||
|
||||
|
||||
class TestClickZettaVolumeStorage(unittest.TestCase):
|
||||
"""Test cases for ClickZetta Volume Storage."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.config = ClickZettaVolumeConfig(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", "test_pass"),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "uat-api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
|
||||
schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"),
|
||||
volume_type="table",
|
||||
table_prefix="test_dataset_",
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
|
||||
def test_user_volume_operations(self):
|
||||
"""Test basic operations with User Volume."""
|
||||
config = self.config
|
||||
config.volume_type = "user"
|
||||
|
||||
storage = ClickZettaVolumeStorage(config)
|
||||
|
||||
# Test file operations
|
||||
test_filename = "test_file.txt"
|
||||
test_content = b"Hello, ClickZetta Volume!"
|
||||
|
||||
# Save file
|
||||
storage.save(test_filename, test_content)
|
||||
|
||||
# Check if file exists
|
||||
assert storage.exists(test_filename)
|
||||
|
||||
# Load file
|
||||
loaded_content = storage.load_once(test_filename)
|
||||
assert loaded_content == test_content
|
||||
|
||||
# Test streaming
|
||||
stream_content = b""
|
||||
for chunk in storage.load_stream(test_filename):
|
||||
stream_content += chunk
|
||||
assert stream_content == test_content
|
||||
|
||||
# Test download
|
||||
with tempfile.NamedTemporaryFile() as temp_file:
|
||||
storage.download(test_filename, temp_file.name)
|
||||
downloaded_content = Path(temp_file.name).read_bytes()
|
||||
assert downloaded_content == test_content
|
||||
|
||||
# Test scan
|
||||
files = storage.scan("", files=True, directories=False)
|
||||
assert test_filename in files
|
||||
|
||||
# Delete file
|
||||
storage.delete(test_filename)
|
||||
assert not storage.exists(test_filename)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
|
||||
def test_table_volume_operations(self):
|
||||
"""Test basic operations with Table Volume."""
|
||||
config = self.config
|
||||
config.volume_type = "table"
|
||||
|
||||
storage = ClickZettaVolumeStorage(config)
|
||||
|
||||
# Test file operations with dataset_id
|
||||
dataset_id = "12345"
|
||||
test_filename = f"{dataset_id}/test_file.txt"
|
||||
test_content = b"Hello, Table Volume!"
|
||||
|
||||
# Save file
|
||||
storage.save(test_filename, test_content)
|
||||
|
||||
# Check if file exists
|
||||
assert storage.exists(test_filename)
|
||||
|
||||
# Load file
|
||||
loaded_content = storage.load_once(test_filename)
|
||||
assert loaded_content == test_content
|
||||
|
||||
# Test scan for dataset
|
||||
files = storage.scan(dataset_id, files=True, directories=False)
|
||||
assert "test_file.txt" in files
|
||||
|
||||
# Delete file
|
||||
storage.delete(test_filename)
|
||||
assert not storage.exists(test_filename)
|
||||
|
||||
def test_config_validation(self):
|
||||
"""Test configuration validation."""
|
||||
# Test missing required fields
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(
|
||||
username="", # Empty username should fail
|
||||
password="pass",
|
||||
instance="instance",
|
||||
)
|
||||
|
||||
# Test invalid volume type
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type")
|
||||
|
||||
# Test external volume without volume_name
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(
|
||||
username="user",
|
||||
password="pass",
|
||||
instance="instance",
|
||||
volume_type="external",
|
||||
# Missing volume_name
|
||||
)
|
||||
|
||||
def test_volume_path_generation(self):
|
||||
"""Test volume path generation for different types."""
|
||||
storage = ClickZettaVolumeStorage(self.config)
|
||||
|
||||
# Test table volume path
|
||||
path = storage._get_volume_path("test.txt", "12345")
|
||||
assert path == "test_dataset_12345/test.txt"
|
||||
|
||||
# Test path with existing dataset_id prefix
|
||||
path = storage._get_volume_path("12345/test.txt")
|
||||
assert path == "12345/test.txt"
|
||||
|
||||
# Test user volume
|
||||
storage._config.volume_type = "user"
|
||||
path = storage._get_volume_path("test.txt")
|
||||
assert path == "test.txt"
|
||||
|
||||
def test_sql_prefix_generation(self):
|
||||
"""Test SQL prefix generation for different volume types."""
|
||||
storage = ClickZettaVolumeStorage(self.config)
|
||||
|
||||
# Test table volume SQL prefix
|
||||
prefix = storage._get_volume_sql_prefix("12345")
|
||||
assert prefix == "TABLE VOLUME test_dataset_12345"
|
||||
|
||||
# Test user volume SQL prefix
|
||||
storage._config.volume_type = "user"
|
||||
prefix = storage._get_volume_sql_prefix()
|
||||
assert prefix == "USER VOLUME"
|
||||
|
||||
# Test external volume SQL prefix
|
||||
storage._config.volume_type = "external"
|
||||
storage._config.volume_name = "my_external_volume"
|
||||
prefix = storage._get_volume_sql_prefix()
|
||||
assert prefix == "VOLUME my_external_volume"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
0
dify/api/tests/integration_tests/tasks/__init__.py
Normal file
@@ -0,0 +1,470 @@
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, UploadFile
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_and_tenant(flask_req_ctx):
|
||||
tenant_id = uuid.uuid4()
|
||||
tenant = Tenant(
|
||||
id=tenant_id,
|
||||
name="test_tenant",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant_id,  # tenant id generated above with uuid4()
|
||||
name=f"Test App for tenant {tenant.id}",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app)
|
||||
db.session.flush()
|
||||
yield (tenant, app)
|
||||
|
||||
# Cleanup with proper error handling
|
||||
db.session.delete(app)
|
||||
db.session.delete(tenant)
|
||||
|
||||
|
||||
class TestDeleteDraftVariablesIntegration:
|
||||
@pytest.fixture
|
||||
def setup_test_data(self, app_and_tenant):
|
||||
"""Create test data with apps and draft variables."""
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
# Create a second app for testing
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app2)
|
||||
db.session.commit()
|
||||
|
||||
# Create draft variables for both apps
|
||||
variables_app1 = []
|
||||
variables_app2 = []
|
||||
|
||||
for i in range(5):
|
||||
var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(var1)
|
||||
variables_app1.append(var1)
|
||||
|
||||
var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(var2)
|
||||
variables_app2.append(var2)
|
||||
|
||||
# Commit all the variables to the database
|
||||
db.session.commit()
|
||||
|
||||
yield {
|
||||
"app1": app,
|
||||
"app2": app2,
|
||||
"tenant": tenant,
|
||||
"variables_app1": variables_app1,
|
||||
"variables_app2": variables_app2,
|
||||
}
|
||||
|
||||
# Cleanup - refresh session and check if objects still exist
|
||||
db.session.rollback() # Clear any pending changes
|
||||
|
||||
# Clean up remaining variables
|
||||
cleanup_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id.in_([app.id, app2.id]),
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(cleanup_query)
|
||||
|
||||
# Clean up app2
|
||||
app2_obj = db.session.get(App, app2.id)
|
||||
if app2_obj:
|
||||
db.session.delete(app2_obj)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data):
|
||||
"""Test that batch deletion only removes variables for the specified app."""
|
||||
data = setup_test_data
|
||||
app1_id = data["app1"].id
|
||||
app2_id = data["app2"].id
|
||||
|
||||
# Verify initial state
|
||||
app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
assert app1_vars_before == 5
|
||||
assert app2_vars_before == 5
|
||||
|
||||
# Delete app1 variables
|
||||
deleted_count = delete_draft_variables_batch(app1_id, batch_size=10)
|
||||
|
||||
# Verify results
|
||||
assert deleted_count == 5
|
||||
|
||||
app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
|
||||
assert app1_vars_after == 0 # All app1 variables deleted
|
||||
assert app2_vars_after == 5 # App2 variables unchanged
|
||||
|
||||
def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data):
|
||||
"""Test batch deletion with small batch size processes all records."""
|
||||
data = setup_test_data
|
||||
app1_id = data["app1"].id
|
||||
|
||||
# Use small batch size to force multiple batches
|
||||
deleted_count = delete_draft_variables_batch(app1_id, batch_size=2)
|
||||
|
||||
assert deleted_count == 5
|
||||
|
||||
# Verify all variables are deleted
|
||||
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
assert remaining_vars == 0
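    # A minimal sketch of the batched-delete loop these batch-size tests exercise
    # (illustrative; WorkflowDraftVariable is the model imported at the top of this
    # file, and the real delete_draft_variables_batch in
    # tasks.remove_app_and_related_data_task may differ in ordering and logging):
def delete_in_batches_sketch(session, app_id: str, batch_size: int) -> int:
    from sqlalchemy import delete, select

    deleted = 0
    while True:
        ids = session.scalars(
            select(WorkflowDraftVariable.id)
            .where(WorkflowDraftVariable.app_id == app_id)
            .limit(batch_size)
        ).all()
        if not ids:
            return deleted
        session.execute(
            delete(WorkflowDraftVariable)
            .where(WorkflowDraftVariable.id.in_(ids))
            .execution_options(synchronize_session=False)
        )
        session.commit()
        deleted += len(ids)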
|
||||
|
||||
def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
|
||||
"""Test that deleting variables for nonexistent app returns 0."""
|
||||
nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format
|
||||
|
||||
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100)
|
||||
|
||||
assert deleted_count == 0
|
||||
|
||||
def test_delete_draft_variables_wrapper_function(self, setup_test_data):
|
||||
"""Test that _delete_draft_variables wrapper function works correctly."""
|
||||
data = setup_test_data
|
||||
app1_id = data["app1"].id
|
||||
|
||||
# Verify initial state
|
||||
vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
assert vars_before == 5
|
||||
|
||||
# Call wrapper function
|
||||
deleted_count = _delete_draft_variables(app1_id)
|
||||
|
||||
# Verify results
|
||||
assert deleted_count == 5
|
||||
|
||||
vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
assert vars_after == 0
|
||||
|
||||
def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
|
||||
"""Test batch deletion with larger dataset to verify batching logic."""
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
# Create many draft variables
|
||||
variables = []
|
||||
for i in range(25):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(var)
|
||||
variables.append(var)
|
||||
variable_ids = [i.id for i in variables]
|
||||
|
||||
# Commit the variables to the database
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Use small batch size to force multiple batches
|
||||
deleted_count = delete_draft_variables_batch(app.id, batch_size=8)
|
||||
|
||||
assert deleted_count == 25
|
||||
|
||||
# Verify all variables are deleted
|
||||
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
|
||||
assert remaining_vars == 0
|
||||
|
||||
finally:
|
||||
query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.id.in_(variable_ids),
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(query)
|
||||
|
||||
|
||||
class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
"""Integration tests for draft variable deletion with Offload data."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_offload_test_data(self, app_and_tenant):
|
||||
"""Create test data with draft variables that have associated Offload files."""
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
# Create UploadFile records
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
db.session.add(upload_file1)
|
||||
db.session.add(upload_file2)
|
||||
db.session.flush()
|
||||
|
||||
# Create WorkflowDraftVariableFile records
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
var_file1 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file1.id,
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.STRING,
|
||||
)
|
||||
var_file2 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file2.id,
|
||||
size=2048,
|
||||
length=20,
|
||||
value_type=SegmentType.OBJECT,
|
||||
)
|
||||
db.session.add(var_file1)
|
||||
db.session.add(var_file2)
|
||||
db.session.flush()
|
||||
|
||||
# Create WorkflowDraftVariable records with file associations
|
||||
draft_var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_1",
|
||||
name="large_var_1",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file1.id,
|
||||
)
|
||||
draft_var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_2",
|
||||
name="large_var_2",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file2.id,
|
||||
)
|
||||
# Create a regular variable without Offload data
|
||||
draft_var3 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_3",
|
||||
name="regular_var",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
db.session.add(draft_var1)
|
||||
db.session.add(draft_var2)
|
||||
db.session.add(draft_var3)
|
||||
db.session.commit()
|
||||
|
||||
yield {
|
||||
"app": app,
|
||||
"tenant": tenant,
|
||||
"upload_files": [upload_file1, upload_file2],
|
||||
"variable_files": [var_file1, var_file2],
|
||||
"draft_variables": [draft_var1, draft_var2, draft_var3],
|
||||
}
|
||||
|
||||
# Cleanup
|
||||
db.session.rollback()
|
||||
|
||||
# Clean up any remaining records
|
||||
for table, ids in [
|
||||
(WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
|
||||
(WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
|
||||
(UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
|
||||
]:
|
||||
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
|
||||
db.session.execute(cleanup_query)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that deleting draft variables also cleans up associated Offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Mock storage deletion to succeed
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Verify initial state
|
||||
draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_before = db.session.query(UploadFile).count()
|
||||
|
||||
assert draft_vars_before == 3 # 2 with files + 1 regular
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
|
||||
# Delete draft variables
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
|
||||
# Verify results
|
||||
assert deleted_count == 3
|
||||
|
||||
# Check that all draft variables are deleted
|
||||
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert draft_vars_after == 0
|
||||
|
||||
# Check that associated Offload data is cleaned up
|
||||
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = db.session.query(UploadFile).count()
|
||||
|
||||
assert var_files_after == 0 # All variable files should be deleted
|
||||
assert upload_files_after == 0 # All upload files should be deleted
|
||||
|
||||
# Verify storage deletion was called for both files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
|
||||
assert "test/file1.json" in storage_keys_deleted
|
||||
assert "test/file2.json" in storage_keys_deleted
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that database cleanup continues even when storage deletion fails."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Mock storage deletion to fail for first file, succeed for second
|
||||
mock_storage.delete.side_effect = [Exception("Storage error"), None]
|
||||
|
||||
# Delete draft variables
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
|
||||
# Verify that all draft variables are still deleted
|
||||
assert deleted_count == 3
|
||||
|
||||
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert draft_vars_after == 0
|
||||
|
||||
# Database cleanup should still succeed even with storage errors
|
||||
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = db.session.query(UploadFile).count()
|
||||
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
# Verify storage deletion was attempted for both files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
|
||||
"""Test deletion with mix of variables with and without Offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Create additional app with only regular variables (no offload data)
|
||||
tenant = data["tenant"]
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app2)
|
||||
db.session.flush()
|
||||
|
||||
# Add regular variables to app2
|
||||
regular_vars = []
|
||||
for i in range(3):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(var)
|
||||
regular_vars.append(var)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
# Mock storage deletion
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Delete variables for app2 (no offload data)
|
||||
deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
|
||||
assert deleted_count_app2 == 3
|
||||
|
||||
# Verify storage wasn't called for app2 (no offload files)
|
||||
mock_storage.delete.assert_not_called()
|
||||
|
||||
# Delete variables for original app (with offload data)
|
||||
deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
assert deleted_count_app1 == 3
|
||||
|
||||
# Now storage should be called for the offload files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
finally:
|
||||
# Cleanup app2 and its variables
|
||||
cleanup_vars_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.app_id == app2.id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(cleanup_vars_query)
|
||||
|
||||
app2_obj = db.session.get(App, app2.id)
|
||||
if app2_obj:
|
||||
db.session.delete(app2_obj)
|
||||
db.session.commit()
|
||||
0
dify/api/tests/integration_tests/tools/__init__.py
Normal file
35
dify/api/tests/integration_tests/tools/__mock/http.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import json
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
|
||||
|
||||
class MockedHttp:
|
||||
@staticmethod
|
||||
def httpx_request(
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Mocked httpx.request
|
||||
"""
|
||||
request = httpx.Request(
|
||||
method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies")
|
||||
)
|
||||
data = kwargs.get("data")
|
||||
resp = json.dumps(data).encode("utf-8") if data else b"OK"
|
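# The JSON body (if any) is echoed back as the response payload so callers can assert on
# exactly what was sent; requests without a body simply get b"OK".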
||||
response = httpx.Response(
|
||||
status_code=200,
|
||||
request=request,
|
||||
content=resp,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(ssrf_proxy, "make_request", MockedHttp.httpx_request)
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
@@ -0,0 +1,40 @@
|
||||
from flask import Flask, request
|
||||
from flask_restx import Api, Resource
|
||||
|
||||
app = Flask(__name__)
|
||||
api = Api(app)
|
||||
|
||||
# Mock data
|
||||
todos_data = {
|
||||
"global": ["Buy groceries", "Finish project"],
|
||||
"user1": ["Go for a run", "Read a book"],
|
||||
}
|
||||
|
||||
|
||||
class TodosResource(Resource):
|
||||
def get(self, username):
|
||||
todos = todos_data.get(username, [])
|
||||
return {"todos": todos}
|
||||
|
||||
def post(self, username):
|
||||
data = request.get_json()
|
||||
new_todo = data.get("todo")
|
||||
todos_data.setdefault(username, []).append(new_todo)
|
||||
return {"message": "Todo added successfully"}
|
||||
|
||||
def delete(self, username):
|
||||
data = request.get_json()
|
||||
todo_idx = data.get("todo_idx")
|
||||
todos = todos_data.get(username, [])
|
||||
|
||||
if 0 <= todo_idx < len(todos):
|
||||
del todos[todo_idx]
|
||||
return {"message": "Todo deleted successfully"}
|
||||
|
||||
return {"error": "Invalid todo index"}, 400
|
||||
|
||||
|
||||
api.add_resource(TodosResource, "/todos/<string:username>")
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(port=5003, debug=True)
|
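# Illustrative requests against this mock server (example only):
#   httpx.get("http://localhost:5003/todos/user1")                      -> {"todos": [...]}
#   httpx.post("http://localhost:5003/todos/user1", json={"todo": "x"}) -> {"message": ...}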
||||
@@ -0,0 +1,52 @@
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
|
||||
from tests.integration_tests.tools.__mock.http import setup_http_mock
|
||||
|
||||
tool_bundle = {
|
||||
"server_url": "http://www.example.com/{path_param}",
|
||||
"method": "post",
|
||||
"author": "",
|
||||
"openapi": {
|
||||
"parameters": [
|
||||
{"in": "path", "name": "path_param"},
|
||||
{"in": "query", "name": "query_param"},
|
||||
{"in": "cookie", "name": "cookie_param"},
|
||||
{"in": "header", "name": "header_param"},
|
||||
],
|
||||
"requestBody": {
|
||||
"content": {"application/json": {"schema": {"properties": {"body_param": {"type": "string"}}}}}
|
||||
},
|
||||
},
|
||||
"parameters": [],
|
||||
}
|
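# Each key below corresponds to a parameter declared in the bundle above; ApiTool is
# expected to route them by their "in" location: path_param into the URL path, query_param
# into the query string, header_param/cookie_param into request headers, and body_param
# into the JSON request body (see the assertions in test_api_tool).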
||||
parameters = {
|
||||
"path_param": "p_param",
|
||||
"query_param": "q_param",
|
||||
"cookie_param": "c_param",
|
||||
"header_param": "h_param",
|
||||
"body_param": "b_param",
|
||||
}
|
||||
|
||||
|
||||
def test_api_tool(setup_http_mock):
|
||||
tool = ApiTool(
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(provider="", author="", name="", label=I18nObject(en_US="test tool")),
|
||||
),
|
||||
api_bundle=ApiToolBundle.model_validate(tool_bundle),
|
||||
runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}),
|
||||
provider_id="test_tool",
|
||||
)
|
||||
headers = tool.assembling_request(parameters)
|
||||
response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.request.url.path == "/p_param"
|
||||
assert response.request.url.query == b"query_param=q_param"
|
||||
assert response.request.headers.get("header_param") == "h_param"
|
||||
assert response.request.headers.get("content-type") == "application/json"
|
||||
assert response.request.headers.get("cookie") == "cookie_param=c_param"
|
||||
assert "b_param" in response.content.decode()
|
||||
7
dify/api/tests/integration_tests/utils/child_class.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
class ChildClass(ParentClass):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.name = name
|
||||
@@ -0,0 +1,7 @@
|
||||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
class LazyLoadChildClass(ParentClass):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.name = name
|
||||
6
dify/api/tests/integration_tests/utils/parent_class.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class ParentClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
|
||||
from core.helper.module_import_helper import import_module_from_source, load_single_subclass_from_source
|
||||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
def test_loading_subclass_from_source():
|
||||
current_path = os.getcwd()
|
||||
module = load_single_subclass_from_source(
|
||||
module_name="ChildClass", script_path=os.path.join(current_path, "child_class.py"), parent_type=ParentClass
|
||||
)
|
||||
assert module
|
||||
assert module.__name__ == "ChildClass"
|
||||
|
||||
|
||||
def test_load_import_module_from_source():
|
||||
current_path = os.getcwd()
|
||||
module = import_module_from_source(
|
||||
module_name="ChildClass", py_file_path=os.path.join(current_path, "child_class.py")
|
||||
)
|
||||
assert module
|
||||
assert module.__name__ == "ChildClass"
|
||||
|
||||
|
||||
def test_lazy_loading_subclass_from_source():
|
||||
current_path = os.getcwd()
|
||||
clz = load_single_subclass_from_source(
|
||||
module_name="LazyLoadChildClass",
|
||||
script_path=os.path.join(current_path, "lazy_load_class.py"),
|
||||
parent_type=ParentClass,
|
||||
use_lazy_loader=True,
|
||||
)
|
||||
instance = clz("dify")
|
||||
assert instance.get_name() == "dify"
|
||||
0
dify/api/tests/integration_tests/vdb/__init__.py
Normal file
166
dify/api/tests/integration_tests/vdb/__mock/baiduvectordb.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import os
|
||||
from collections import UserDict
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from pymochow import MochowClient
|
||||
from pymochow.model.database import Database
|
||||
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
|
||||
from pymochow.model.schema import HNSWParams, VectorIndex
|
||||
from pymochow.model.table import Table
|
||||
|
||||
|
||||
class AttrDict(UserDict):
|
||||
def __getattr__(self, item):
|
||||
return self.get(item)
|
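# AttrDict lets the mocked query/search results below be read via attribute access
# (e.g. result.rows, result.code), so code written against the real SDK's response
# objects keeps working with the mock.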
||||
|
||||
|
||||
class MockBaiduVectorDBClass:
|
||||
def mock_vector_db_client(
|
||||
self,
|
||||
config=None,
|
||||
adapter: Any | None = None,
|
||||
):
|
||||
self.conn = MagicMock()
|
||||
self._config = MagicMock()
|
||||
|
||||
def list_databases(self, config=None) -> list[Database]:
|
||||
return [
|
||||
Database(
|
||||
conn=self.conn,
|
||||
database_name="dify",
|
||||
config=self._config,
|
||||
)
|
||||
]
|
||||
|
||||
def create_database(self, database_name: str, config=None) -> Database:
|
||||
return Database(conn=self.conn, database_name=database_name, config=config)
|
||||
|
||||
def list_table(self, config=None) -> list[Table]:
|
||||
return []
|
||||
|
||||
def drop_table(self, table_name: str, config=None):
|
||||
return {"code": 0, "msg": "Success"}
|
||||
|
||||
def create_table(
|
||||
self,
|
||||
table_name: str,
|
||||
replication: int,
|
||||
partition: int,
|
||||
schema,
|
||||
enable_dynamic_field=False,
|
||||
description: str = "",
|
||||
config=None,
|
||||
) -> Table:
|
||||
return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)
|
||||
|
||||
def describe_table(self, table_name: str, config=None) -> Table:
|
||||
return Table(
|
||||
self,
|
||||
table_name,
|
||||
3,
|
||||
1,
|
||||
None,
|
||||
enable_dynamic_field=False,
|
||||
description="table for dify",
|
||||
config=config,
|
||||
state=TableState.NORMAL,
|
||||
)
|
||||
|
||||
def upsert(self, rows, config=None):
|
||||
return {"code": 0, "msg": "operation success", "affectedCount": 1}
|
||||
|
||||
def rebuild_index(self, index_name: str, config=None):
|
||||
return {"code": 0, "msg": "Success"}
|
||||
|
||||
def describe_index(self, index_name: str, config=None):
|
||||
return VectorIndex(
|
||||
index_name=index_name,
|
||||
index_type=IndexType.HNSW,
|
||||
field="vector",
|
||||
metric_type=MetricType.L2,
|
||||
params=HNSWParams(m=16, efconstruction=200),
|
||||
auto_build=False,
|
||||
state=IndexState.NORMAL,
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
primary_key,
|
||||
partition_key=None,
|
||||
projections=None,
|
||||
retrieve_vector=False,
|
||||
read_consistency=ReadConsistency.EVENTUAL,
|
||||
config=None,
|
||||
):
|
||||
return AttrDict(
|
||||
{
|
||||
"row": {
|
||||
"id": primary_key.get("id"),
|
||||
"vector": [0.23432432, 0.8923744, 0.89238432],
|
||||
"page_content": "text",
|
||||
"metadata": {"doc_id": "doc_id_001"},
|
||||
},
|
||||
"code": 0,
|
||||
"msg": "Success",
|
||||
}
|
||||
)
|
||||
|
||||
def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
|
||||
return {"code": 0, "msg": "Success"}
|
||||
|
||||
def search(
|
||||
self,
|
||||
anns,
|
||||
partition_key=None,
|
||||
projections=None,
|
||||
retrieve_vector=False,
|
||||
read_consistency=ReadConsistency.EVENTUAL,
|
||||
config=None,
|
||||
):
|
||||
return AttrDict(
|
||||
{
|
||||
"rows": [
|
||||
{
|
||||
"row": {
|
||||
"id": "doc_id_001",
|
||||
"vector": [0.23432432, 0.8923744, 0.89238432],
|
||||
"page_content": "text",
|
||||
"metadata": {"doc_id": "doc_id_001"},
|
||||
},
|
||||
"distance": 0.1,
|
||||
"score": 0.5,
|
||||
}
|
||||
],
|
||||
"code": 0,
|
||||
"msg": "Success",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
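# When MOCK_SWITCH=true, the fixture below monkeypatches the real pymochow client with
# the mock above; otherwise the test runs against a live Baidu VectorDB instance.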
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client)
|
||||
monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases)
|
||||
monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database)
|
||||
monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table)
|
||||
monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table)
|
||||
monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table)
|
||||
monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table)
|
||||
monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table)
|
||||
monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
|
||||
monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
|
||||
monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
|
||||
monkeypatch.setattr(Table, "query", MockBaiduVectorDBClass.query)
|
||||
monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
|
||||
|
||||
class MockIndicesClient:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def create(self, index, mappings, settings):
|
||||
return {"acknowledge": True}
|
||||
|
||||
def refresh(self, index):
|
||||
return {"acknowledge": True}
|
||||
|
||||
def delete(self, index):
|
||||
return {"acknowledge": True}
|
||||
|
||||
def exists(self, index):
|
||||
return True
|
||||
|
||||
|
||||
class MockClient:
|
||||
def __init__(self, **kwargs):
|
||||
self.indices = MockIndicesClient()
|
||||
|
||||
def index(self, **kwargs):
|
||||
return {"acknowledge": True}
|
||||
|
||||
def exists(self, **kwargs):
|
||||
return True
|
||||
|
||||
def delete(self, **kwargs):
|
||||
return {"acknowledge": True}
|
||||
|
||||
def search(self, **kwargs):
|
||||
return {
|
||||
"took": 1,
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {
|
||||
Field.CONTENT_KEY: "abcdef",
|
||||
Field.VECTOR: [1, 2],
|
||||
Field.METADATA_KEY: {},
|
||||
},
|
||||
"_score": 1.0,
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
Field.CONTENT_KEY: "123456",
|
||||
Field.VECTOR: [2, 2],
|
||||
Field.METADATA_KEY: {},
|
||||
},
|
||||
"_score": 0.9,
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
Field.CONTENT_KEY: "a1b2c3",
|
||||
Field.VECTOR: [3, 2],
|
||||
Field.METADATA_KEY: {},
|
||||
},
|
||||
"_score": 0.8,
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_client_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(Elasticsearch, "__init__", MockClient.__init__)
|
||||
monkeypatch.setattr(Elasticsearch, "index", MockClient.index)
|
||||
monkeypatch.setattr(Elasticsearch, "exists", MockClient.exists)
|
||||
monkeypatch.setattr(Elasticsearch, "delete", MockClient.delete)
|
||||
monkeypatch.setattr(Elasticsearch, "search", MockClient.search)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
192
dify/api/tests/integration_tests/vdb/__mock/tcvectordb.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
from typing import Any, Union
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from tcvectordb import RPCVectorDBClient
|
||||
from tcvectordb.model import enum
|
||||
from tcvectordb.model.collection import FilterIndexConfig
|
||||
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank
|
||||
from tcvectordb.model.enum import ReadConsistency
|
||||
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex
|
||||
from tcvectordb.rpc.model.collection import RPCCollection
|
||||
from tcvectordb.rpc.model.database import RPCDatabase
|
||||
from xinference_client.types import Embedding
|
||||
|
||||
|
||||
class MockTcvectordbClass:
|
||||
def mock_vector_db_client(
|
||||
self,
|
||||
url: str,
|
||||
username="",
|
||||
key="",
|
||||
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
|
||||
timeout=10,
|
||||
adapter: Any | None = None,
|
||||
pool_size: int = 2,
|
||||
proxies: dict | None = None,
|
||||
password: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._conn = None
|
||||
self._read_consistency = read_consistency
|
||||
|
||||
def create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase:
|
||||
return RPCDatabase(
|
||||
name="dify",
|
||||
read_consistency=self._read_consistency,
|
||||
)
|
||||
|
||||
def exists_collection(self, database_name: str, collection_name: str) -> bool:
|
||||
return True
|
||||
|
||||
def describe_collection(
|
||||
self, database_name: str, collection_name: str, timeout: float | None = None
|
||||
) -> RPCCollection:
|
||||
index = Index(
|
||||
FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
|
||||
VectorIndex(
|
||||
"vector",
|
||||
128,
|
||||
enum.IndexType.HNSW,
|
||||
enum.MetricType.IP,
|
||||
HNSWParams(m=16, efconstruction=200),
|
||||
),
|
||||
FilterIndex("text", enum.FieldType.String, enum.IndexType.FILTER),
|
||||
FilterIndex("metadata", enum.FieldType.String, enum.IndexType.FILTER),
|
||||
)
|
||||
return RPCCollection(
|
||||
RPCDatabase(
|
||||
name=database_name,
|
||||
read_consistency=self._read_consistency,
|
||||
),
|
||||
collection_name,
|
||||
index=index,
|
||||
)
|
||||
|
||||
def create_collection(
|
||||
self,
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
shard: int,
|
||||
replicas: int,
|
||||
description: str | None = None,
|
||||
index: Index | None = None,
|
||||
embedding: Embedding | None = None,
|
||||
timeout: float | None = None,
|
||||
ttl_config: dict | None = None,
|
||||
filter_index_config: FilterIndexConfig | None = None,
|
||||
indexes: list[IndexField] | None = None,
|
||||
) -> RPCCollection:
|
||||
return RPCCollection(
|
||||
RPCDatabase(
|
||||
name="dify",
|
||||
read_consistency=self._read_consistency,
|
||||
),
|
||||
collection_name,
|
||||
shard,
|
||||
replicas,
|
||||
description,
|
||||
index,
|
||||
embedding=embedding,
|
||||
read_consistency=self._read_consistency,
|
||||
timeout=timeout,
|
||||
ttl_config=ttl_config,
|
||||
filter_index_config=filter_index_config,
|
||||
indexes=indexes,
|
||||
)
|
||||
|
||||
def collection_upsert(
|
||||
self,
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
documents: list[Union[Document, dict]],
|
||||
timeout: float | None = None,
|
||||
build_index: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
return {"code": 0, "msg": "operation success"}
|
||||
|
||||
def collection_search(
|
||||
self,
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
vectors: list[list[float]],
|
||||
filter: Filter | None = None,
|
||||
params=None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: int = 10,
|
||||
output_fields: list[str] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> list[list[dict]]:
|
||||
return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
|
||||
|
||||
def collection_hybrid_search(
|
||||
self,
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
ann: Union[list[AnnSearch], AnnSearch] | None = None,
|
||||
match: Union[list[KeywordSearch], KeywordSearch] | None = None,
|
||||
filter: Union[Filter, str] | None = None,
|
||||
rerank: Rerank | None = None,
|
||||
retrieve_vector: bool | None = None,
|
||||
output_fields: list[str] | None = None,
|
||||
limit: int | None = None,
|
||||
timeout: float | None = None,
|
||||
return_pd_object=False,
|
||||
**kwargs,
|
||||
) -> list[list[dict]]:
|
||||
return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
|
||||
|
||||
def collection_query(
|
||||
self,
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
document_ids: list | None = None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
filter: Filter | None = None,
|
||||
output_fields: list[str] | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
|
||||
|
||||
def collection_delete(
|
||||
self,
|
||||
database_name: str,
|
||||
collection_name: str,
|
||||
document_ids: list[str] | None = None,
|
||||
filter: Filter | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
return {"code": 0, "msg": "operation success"}
|
||||
|
||||
def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None):
|
||||
return {"code": 0, "msg": "operation success"}
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(RPCVectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client)
|
||||
monkeypatch.setattr(
|
||||
RPCVectorDBClient, "create_database_if_not_exists", MockTcvectordbClass.create_database_if_not_exists
|
||||
)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "exists_collection", MockTcvectordbClass.exists_collection)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "create_collection", MockTcvectordbClass.create_collection)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "describe_collection", MockTcvectordbClass.describe_collection)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "upsert", MockTcvectordbClass.collection_upsert)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "search", MockTcvectordbClass.collection_search)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "hybrid_search", MockTcvectordbClass.collection_hybrid_search)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "query", MockTcvectordbClass.collection_query)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "delete", MockTcvectordbClass.collection_delete)
|
||||
monkeypatch.setattr(RPCVectorDBClient, "drop_collection", MockTcvectordbClass.drop_collection)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
from collections import UserDict
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from upstash_vector import Index
|
||||
|
||||
|
||||
# Mocking the Index class from upstash_vector
|
||||
class MockIndex:
|
||||
def __init__(self, url="", token=""):
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.vectors = []
|
||||
|
||||
def upsert(self, vectors):
|
||||
for vector in vectors:
|
||||
vector.score = 0.5
|
||||
self.vectors.append(vector)
|
||||
return {"code": 0, "msg": "operation success", "affectedCount": len(vectors)}
|
||||
|
||||
def fetch(self, ids):
|
||||
return [vector for vector in self.vectors if vector.id in ids]
|
||||
|
||||
def delete(self, ids):
|
||||
self.vectors = [vector for vector in self.vectors if vector.id not in ids]
|
||||
return {"code": 0, "msg": "Success"}
|
||||
|
||||
def query(
|
||||
self,
|
||||
vector: None,
|
||||
top_k: int = 10,
|
||||
include_vectors: bool = False,
|
||||
include_metadata: bool = False,
|
||||
filter: str = "",
|
||||
data: str | None = None,
|
||||
namespace: str = "",
|
||||
include_data: bool = False,
|
||||
):
|
||||
# Simple mock query; in a real scenario you would compute similarity and rank the results
|
||||
mock_result = []
|
||||
for vector_data in self.vectors:
|
||||
mock_result.append(vector_data)
|
||||
return mock_result[:top_k]
|
||||
|
||||
def reset(self):
|
||||
self.vectors = []
|
||||
|
||||
def info(self):
|
||||
return AttrDict({"dimension": 1024})
|
||||
|
||||
|
||||
class AttrDict(UserDict):
|
||||
def __getattr__(self, item):
|
||||
return self.get(item)
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_upstashvector_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(Index, "__init__", MockIndex.__init__)
|
||||
monkeypatch.setattr(Index, "upsert", MockIndex.upsert)
|
||||
monkeypatch.setattr(Index, "fetch", MockIndex.fetch)
|
||||
monkeypatch.setattr(Index, "delete", MockIndex.delete)
|
||||
monkeypatch.setattr(Index, "query", MockIndex.query)
|
||||
monkeypatch.setattr(Index, "reset", MockIndex.reset)
|
||||
monkeypatch.setattr(Index, "info", MockIndex.info)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
215
dify/api/tests/integration_tests/vdb/__mock/vikingdb.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import os
|
||||
from typing import Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from volcengine.viking_db import (
|
||||
Collection,
|
||||
Data,
|
||||
DistanceType,
|
||||
Field,
|
||||
FieldType,
|
||||
Index,
|
||||
IndexType,
|
||||
QuantType,
|
||||
VectorIndexParams,
|
||||
VikingDBService,
|
||||
)
|
||||
|
||||
from core.rag.datasource.vdb.field import Field as vdb_Field
|
||||
|
||||
|
||||
class MockVikingDBClass:
|
||||
def __init__(
|
||||
self,
|
||||
host="api-vikingdb.volces.com",
|
||||
region="cn-north-1",
|
||||
ak="",
|
||||
sk="",
|
||||
scheme="http",
|
||||
connection_timeout=30,
|
||||
socket_timeout=30,
|
||||
proxy=None,
|
||||
):
|
||||
self._viking_db_service = MagicMock()
|
||||
self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')
|
||||
|
||||
def get_collection(self, collection_name) -> Collection:
|
||||
return Collection(
|
||||
collection_name=collection_name,
|
||||
description="Collection For Dify",
|
||||
viking_db_service=self._viking_db_service,
|
||||
primary_key=vdb_Field.PRIMARY_KEY,
|
||||
fields=[
|
||||
Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
|
||||
Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
|
||||
Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
|
||||
Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
|
||||
Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=768),
|
||||
],
|
||||
indexes=[
|
||||
Index(
|
||||
collection_name=collection_name,
|
||||
index_name=f"{collection_name}_idx",
|
||||
vector_index=VectorIndexParams(
|
||||
distance=DistanceType.L2,
|
||||
index_type=IndexType.HNSW,
|
||||
quant=QuantType.Float,
|
||||
),
|
||||
scalar_index=None,
|
||||
stat=None,
|
||||
viking_db_service=self._viking_db_service,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def drop_collection(self, collection_name):
|
||||
assert collection_name != ""
|
||||
|
||||
def create_collection(self, collection_name, fields, description="") -> Collection:
|
||||
return Collection(
|
||||
collection_name=collection_name,
|
||||
description=description,
|
||||
primary_key=vdb_Field.PRIMARY_KEY,
|
||||
viking_db_service=self._viking_db_service,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
def get_index(self, collection_name, index_name) -> Index:
|
||||
return Index(
|
||||
collection_name=collection_name,
|
||||
index_name=index_name,
|
||||
viking_db_service=self._viking_db_service,
|
||||
stat=None,
|
||||
scalar_index=None,
|
||||
vector_index=VectorIndexParams(
|
||||
distance=DistanceType.L2,
|
||||
index_type=IndexType.HNSW,
|
||||
quant=QuantType.Float,
|
||||
),
|
||||
)
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
collection_name,
|
||||
index_name,
|
||||
vector_index=None,
|
||||
cpu_quota=2,
|
||||
description="",
|
||||
partition_by="",
|
||||
scalar_index=None,
|
||||
shard_count=None,
|
||||
shard_policy=None,
|
||||
):
|
||||
return Index(
|
||||
collection_name=collection_name,
|
||||
index_name=index_name,
|
||||
vector_index=vector_index,
|
||||
cpu_quota=cpu_quota,
|
||||
description=description,
|
||||
partition_by=partition_by,
|
||||
scalar_index=scalar_index,
|
||||
shard_count=shard_count,
|
||||
shard_policy=shard_policy,
|
||||
viking_db_service=self._viking_db_service,
|
||||
stat=None,
|
||||
)
|
||||
|
||||
def drop_index(self, collection_name, index_name):
|
||||
assert collection_name != ""
|
||||
assert index_name != ""
|
||||
|
||||
def upsert_data(self, data: Union[Data, list[Data]]):
|
||||
assert data is not None
|
||||
|
||||
def fetch_data(self, id: Union[str, list[str], int, list[int]]):
|
||||
return Data(
|
||||
fields={
|
||||
vdb_Field.GROUP_KEY: "test_group",
|
||||
vdb_Field.METADATA_KEY: "{}",
|
||||
vdb_Field.CONTENT_KEY: "content",
|
||||
vdb_Field.PRIMARY_KEY: id,
|
||||
vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
|
||||
},
|
||||
id=id,
|
||||
)
|
||||
|
||||
def delete_data(self, id: Union[str, list[str], int, list[int]]):
|
||||
assert id is not None
|
||||
|
||||
def search_by_vector(
|
||||
self,
|
||||
vector,
|
||||
sparse_vectors=None,
|
||||
filter=None,
|
||||
limit=10,
|
||||
output_fields=None,
|
||||
partition="default",
|
||||
dense_weight=None,
|
||||
) -> list[Data]:
|
||||
return [
|
||||
Data(
|
||||
fields={
|
||||
vdb_Field.GROUP_KEY: "test_group",
|
||||
vdb_Field.METADATA_KEY: '\
|
||||
{"source": "/var/folders/ml/xxx/xxx.txt", \
|
||||
"document_id": "test_document_id", \
|
||||
"dataset_id": "test_dataset_id", \
|
||||
"doc_id": "test_id", \
|
||||
"doc_hash": "test_hash"}',
|
||||
vdb_Field.CONTENT_KEY: "content",
|
||||
vdb_Field.PRIMARY_KEY: "test_id",
|
||||
vdb_Field.VECTOR: vector,
|
||||
},
|
||||
id="test_id",
|
||||
score=0.10,
|
||||
)
|
||||
]
|
||||
|
||||
def search(
|
||||
self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
|
||||
) -> list[Data]:
|
||||
return [
|
||||
Data(
|
||||
fields={
|
||||
vdb_Field.GROUP_KEY: "test_group",
|
||||
vdb_Field.METADATA_KEY: '\
|
||||
{"source": "/var/folders/ml/xxx/xxx.txt", \
|
||||
"document_id": "test_document_id", \
|
||||
"dataset_id": "test_dataset_id", \
|
||||
"doc_id": "test_id", \
|
||||
"doc_hash": "test_hash"}',
|
||||
vdb_Field.CONTENT_KEY: "content",
|
||||
vdb_Field.PRIMARY_KEY: "test_id",
|
||||
vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
|
||||
},
|
||||
id="test_id",
|
||||
score=0.10,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(VikingDBService, "__init__", MockVikingDBClass.__init__)
|
||||
monkeypatch.setattr(VikingDBService, "get_collection", MockVikingDBClass.get_collection)
|
||||
monkeypatch.setattr(VikingDBService, "create_collection", MockVikingDBClass.create_collection)
|
||||
monkeypatch.setattr(VikingDBService, "drop_collection", MockVikingDBClass.drop_collection)
|
||||
monkeypatch.setattr(VikingDBService, "get_index", MockVikingDBClass.get_index)
|
||||
monkeypatch.setattr(VikingDBService, "create_index", MockVikingDBClass.create_index)
|
||||
monkeypatch.setattr(VikingDBService, "drop_index", MockVikingDBClass.drop_index)
|
||||
monkeypatch.setattr(Collection, "upsert_data", MockVikingDBClass.upsert_data)
|
||||
monkeypatch.setattr(Collection, "fetch_data", MockVikingDBClass.fetch_data)
|
||||
monkeypatch.setattr(Collection, "delete_data", MockVikingDBClass.delete_data)
|
||||
monkeypatch.setattr(Index, "search_by_vector", MockVikingDBClass.search_by_vector)
|
||||
monkeypatch.setattr(Index, "search", MockVikingDBClass.search)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
@@ -0,0 +1,49 @@
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
|
||||
|
||||
|
||||
class AnalyticdbVectorTest(AbstractVectorTest):
|
||||
def __init__(self, config_type: str):
|
||||
super().__init__()
|
||||
# AnalyticDB requires collection_name to be shorter than 60 characters,
|
||||
# which is fine for normal usage.
|
||||
self.collection_name = self.collection_name.replace("_test", "")
|
||||
if config_type == "sql":
|
||||
self.vector = AnalyticdbVector(
|
||||
collection_name=self.collection_name,
|
||||
sql_config=AnalyticdbVectorBySqlConfig(
|
||||
host="test_host",
|
||||
port=5432,
|
||||
account="test_account",
|
||||
account_password="test_passwd",
|
||||
namespace="difytest_namespace",
|
||||
),
|
||||
api_config=None,
|
||||
)
|
||||
else:
|
||||
self.vector = AnalyticdbVector(
|
||||
collection_name=self.collection_name,
|
||||
sql_config=None,
|
||||
api_config=AnalyticdbVectorOpenAPIConfig(
|
||||
access_key_id="test_key_id",
|
||||
access_key_secret="test_key_secret",
|
||||
region_id="test_region",
|
||||
instance_id="test_id",
|
||||
account="test_account",
|
||||
account_password="test_passwd",
|
||||
namespace="difytest_namespace",
|
||||
collection="difytest_collection",
|
||||
namespace_password="test_passwd",
|
||||
),
|
||||
)
|
||||
|
||||
def run_all_tests(self):
|
||||
self.vector.delete()
|
||||
return super().run_all_tests()
|
||||
|
||||
|
||||
def test_analyticdb_vector(setup_mock_redis):
|
||||
AnalyticdbVectorTest("api").run_all_tests()
|
||||
AnalyticdbVectorTest("sql").run_all_tests()
|
||||
31
dify/api/tests/integration_tests/vdb/baidu/test_baidu.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
|
||||
from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class BaiduVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vector = BaiduVector(
|
||||
"dify",
|
||||
BaiduConfig(
|
||||
endpoint="http://127.0.0.1:5287",
|
||||
account="root",
|
||||
api_key="dify",
|
||||
database="dify",
|
||||
shard=1,
|
||||
replicas=3,
|
||||
),
|
||||
)
|
||||
|
||||
def search_by_vector(self):
|
||||
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||
assert len(hits_by_vector) == 1
|
||||
|
||||
def search_by_full_text(self):
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
|
||||
def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock):
|
||||
BaiduVectorTest().run_all_tests()
|
||||
33
dify/api/tests/integration_tests/vdb/chroma/test_chroma.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import chromadb
|
||||
|
||||
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class ChromaVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vector = ChromaVector(
|
||||
collection_name=self.collection_name,
|
||||
config=ChromaConfig(
|
||||
host="localhost",
|
||||
port=8000,
|
||||
tenant=chromadb.DEFAULT_TENANT,
|
||||
database=chromadb.DEFAULT_DATABASE,
|
||||
auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
|
||||
auth_credentials="difyai123456",
|
||||
),
|
||||
)
|
||||
|
||||
def search_by_full_text(self):
|
||||
# Chroma does not support full-text search
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
|
||||
def test_chroma_vector(setup_mock_redis):
|
||||
ChromaVectorTest().run_all_tests()
|
||||
25
dify/api/tests/integration_tests/vdb/clickzetta/README.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# Clickzetta Integration Tests
|
||||
|
||||
## Running Tests
|
||||
|
||||
To run the Clickzetta integration tests, you need to set the following environment variables:
|
||||
|
||||
```bash
|
||||
export CLICKZETTA_USERNAME=your_username
|
||||
export CLICKZETTA_PASSWORD=your_password
|
||||
export CLICKZETTA_INSTANCE=your_instance
|
||||
export CLICKZETTA_SERVICE=api.clickzetta.com
|
||||
export CLICKZETTA_WORKSPACE=your_workspace
|
||||
export CLICKZETTA_VCLUSTER=your_vcluster
|
||||
export CLICKZETTA_SCHEMA=dify
|
||||
```
|
||||
|
||||
Then run the tests:
|
||||
|
||||
```bash
|
||||
pytest api/tests/integration_tests/vdb/clickzetta/
|
||||
```
|
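
To run a single test module, point pytest at the file directly (assuming the module is named `test_clickzetta.py`):

```bash
pytest api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py -v
```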
||||
|
||||
## Security Note
|
||||
|
||||
Never commit credentials to the repository. Always use environment variables or secure credential management systems.
|
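
One convenient way to follow this (a suggestion, not a project requirement) is to keep the exports in a local, git-ignored file and source it before running the tests:

```bash
# keep this file out of version control, e.g. by adding it to .gitignore
source .env.clickzetta.local
pytest api/tests/integration_tests/vdb/clickzetta/
```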
||||
@@ -0,0 +1,223 @@
|
||||
import contextlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class TestClickzettaVector(AbstractVectorTest):
|
||||
"""
|
||||
Test cases for Clickzetta vector database integration.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store(self):
|
||||
"""Create a Clickzetta vector store instance for testing."""
|
||||
# Skip test if Clickzetta credentials are not configured
|
||||
if not os.getenv("CLICKZETTA_USERNAME"):
|
||||
pytest.skip("CLICKZETTA_USERNAME is not configured")
|
||||
if not os.getenv("CLICKZETTA_PASSWORD"):
|
||||
pytest.skip("CLICKZETTA_PASSWORD is not configured")
|
||||
if not os.getenv("CLICKZETTA_INSTANCE"):
|
||||
pytest.skip("CLICKZETTA_INSTANCE is not configured")
|
||||
|
||||
config = ClickzettaConfig(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", ""),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", ""),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", ""),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
|
||||
schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
|
||||
batch_size=10, # Small batch size for testing
|
||||
enable_inverted_index=True,
|
||||
analyzer_type="chinese",
|
||||
analyzer_mode="smart",
|
||||
vector_distance_function="cosine_distance",
|
||||
)
|
||||
|
||||
with setup_mock_redis():
|
||||
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
|
||||
|
||||
yield vector
|
||||
|
||||
# Cleanup: delete the test collection
|
||||
with contextlib.suppress(Exception):
|
||||
vector.delete()
|
||||
|
||||
def test_clickzetta_vector_basic_operations(self, vector_store):
|
||||
"""Test basic CRUD operations on Clickzetta vector store."""
|
||||
# Prepare test data
|
||||
texts = [
|
||||
"这是第一个测试文档,包含一些中文内容。",
|
||||
"This is the second test document with English content.",
|
||||
"第三个文档混合了English和中文内容。",
|
||||
]
|
||||
embeddings = [
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[0.5, 0.6, 0.7, 0.8],
|
||||
[0.9, 1.0, 1.1, 1.2],
|
||||
]
|
||||
documents = [
|
||||
Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"})
|
||||
for i, text in enumerate(texts)
|
||||
]
|
||||
|
||||
# Test create (initial insert)
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test text_exists
|
||||
assert vector_store.text_exists("doc_0")
|
||||
assert not vector_store.text_exists("doc_999")
|
||||
|
||||
# Test search_by_vector
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
results = vector_store.search_by_vector(query_vector, top_k=2)
|
||||
assert len(results) > 0
|
||||
assert results[0].page_content == texts[0] # Should match the first document
|
||||
|
||||
# Test search_by_full_text (Chinese)
|
||||
results = vector_store.search_by_full_text("中文", top_k=3)
|
||||
assert len(results) >= 2 # Should find documents with Chinese content
|
||||
|
||||
# Test search_by_full_text (English)
|
||||
results = vector_store.search_by_full_text("English", top_k=3)
|
||||
assert len(results) >= 2 # Should find documents with English content
|
||||
|
||||
# Test delete_by_ids
|
||||
vector_store.delete_by_ids(["doc_0"])
|
||||
assert not vector_store.text_exists("doc_0")
|
||||
assert vector_store.text_exists("doc_1")
|
||||
|
||||
# Test delete_by_metadata_field
|
||||
vector_store.delete_by_metadata_field("source", "test")
|
||||
assert not vector_store.text_exists("doc_1")
|
||||
assert not vector_store.text_exists("doc_2")
|
||||
|
||||
def test_clickzetta_vector_advanced_search(self, vector_store):
|
||||
"""Test advanced search features of Clickzetta vector store."""
|
||||
# Prepare test data with more complex metadata
|
||||
documents = []
|
||||
embeddings = []
|
||||
for i in range(10):
|
||||
doc = Document(
|
||||
page_content=f"Document {i}: " + get_example_text(),
|
||||
metadata={
|
||||
"doc_id": f"adv_doc_{i}",
|
||||
"category": "technical" if i % 2 == 0 else "general",
|
||||
"document_id": f"doc_{i // 3}", # Group documents
|
||||
"importance": i,
|
||||
},
|
||||
)
|
||||
documents.append(doc)
|
||||
# Create varied embeddings
|
||||
embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i])
|
||||
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test vector search with document filter
|
||||
query_vector = [0.5, 1.0, 1.5, 2.0]
|
||||
results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"])
|
||||
assert len(results) > 0
|
||||
# All results should belong to doc_0 or doc_1 groups
|
||||
for result in results:
|
||||
assert result.metadata["document_id"] in ["doc_0", "doc_1"]
|
||||
|
||||
# Test score threshold
|
||||
results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5)
|
||||
# Check that all results have a score above threshold
|
||||
for result in results:
|
||||
assert result.metadata.get("score", 0) >= 0.5
|
||||
|
||||
    def test_clickzetta_batch_operations(self, vector_store):
        """Test batch insertion operations."""
        # Prepare large batch of documents
        batch_size = 25
        documents = []
        embeddings = []

        for i in range(batch_size):
            doc = Document(
                page_content=f"Batch document {i}: This is a test document for batch processing.",
                metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"},
            )
            documents.append(doc)
            embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])

        # Test batch insert
        vector_store.add_texts(documents=documents, embeddings=embeddings)

        # Verify all documents were inserted
        for i in range(batch_size):
            assert vector_store.text_exists(f"batch_doc_{i}")

        # Clean up
        vector_store.delete_by_metadata_field("batch", "test_batch")

    def test_clickzetta_edge_cases(self, vector_store):
        """Test edge cases and error handling."""
        # Test empty operations
        vector_store.create(texts=[], embeddings=[])
        vector_store.add_texts(documents=[], embeddings=[])
        vector_store.delete_by_ids([])

        # Test special characters in content
        special_doc = Document(
            page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
            metadata={"doc_id": "special_doc", "test": "edge_case"},
        )
        embeddings = [[0.1, 0.2, 0.3, 0.4]]

        vector_store.add_texts(documents=[special_doc], embeddings=embeddings)
        assert vector_store.text_exists("special_doc")

        # Test search with special characters
        results = vector_store.search_by_full_text("quotes", top_k=1)
        if results:  # Full-text search might not be available
            assert len(results) > 0

        # Clean up
        vector_store.delete_by_ids(["special_doc"])

    def test_clickzetta_full_text_search_modes(self, vector_store):
        """Test different full-text search capabilities."""
        # Prepare documents with various language content
        documents = [
            Document(
                page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
            ),
            Document(
                page_content="Clickzetta provides powerful Lakehouse solutions",
                metadata={"doc_id": "en_doc_1", "lang": "english"},
            ),
            Document(
                page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
            ),
            Document(
                page_content="Modern data architecture includes Lakehouse technology",
                metadata={"doc_id": "en_doc_2", "lang": "english"},
            ),
        ]

        embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents]

        vector_store.create(texts=documents, embeddings=embeddings)

        # Test Chinese full-text search
        results = vector_store.search_by_full_text("Lakehouse", top_k=4)
        assert len(results) >= 2  # Should find at least documents with "Lakehouse"

        # Test English full-text search
        results = vector_store.search_by_full_text("solutions", top_k=2)
        assert len(results) >= 1  # Should find English documents with "solutions"

        # Test mixed search
        results = vector_store.search_by_full_text("数据架构", top_k=2)
        assert len(results) >= 1  # Should find Chinese documents with this phrase

        # Clean up
        vector_store.delete_by_metadata_field("lang", "chinese")
        vector_store.delete_by_metadata_field("lang", "english")
@@ -0,0 +1,165 @@
#!/usr/bin/env python3
"""
Test Clickzetta integration in Docker environment
"""

import os
import time

import httpx
from clickzetta import connect


def test_clickzetta_connection():
    """Test direct connection to Clickzetta"""
    print("=== Testing direct Clickzetta connection ===")
    try:
        conn = connect(
            username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
            password=os.getenv("CLICKZETTA_PASSWORD", "test_password"),
            instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
            service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
            database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
        )

        with conn.cursor() as cursor:
            # Test basic connectivity
            cursor.execute("SELECT 1 as test")
            result = cursor.fetchone()
            print(f"✓ Connection test: {result}")

            # Check if our test table exists
            cursor.execute("SHOW TABLES IN dify")
            tables = cursor.fetchall()
            print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")

            # Check if test collection exists
            test_collection = "collection_test_dataset"
            if test_collection in [t[1] for t in tables if t[0] == "dify"]:
                cursor.execute(f"DESCRIBE dify.{test_collection}")
                columns = cursor.fetchall()
                print(f"✓ Table structure for {test_collection}:")
                for col in columns:
                    print(f" - {col[0]}: {col[1]}")

                # Check for indexes
                cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
                indexes = cursor.fetchall()
                print(f"✓ Indexes on {test_collection}:")
                for idx in indexes:
                    print(f" - {idx}")

        return True
    except Exception as e:
        print(f"✗ Connection test failed: {e}")
        return False


def test_dify_api():
    """Test Dify API with Clickzetta backend"""
    print("\n=== Testing Dify API ===")
    base_url = "http://localhost:5001"

    # Wait for API to be ready
    max_retries = 30
    for i in range(max_retries):
        try:
            response = httpx.get(f"{base_url}/console/api/health")
            if response.status_code == 200:
                print("✓ Dify API is ready")
                break
        except Exception:
            if i == max_retries - 1:
                print("✗ Dify API is not responding")
                return False
        time.sleep(2)

    # Check vector store configuration
    try:
        # This is a simplified check - in production, you'd use proper auth
        print("✓ Dify is configured to use Clickzetta as vector store")
        return True
    except Exception as e:
        print(f"✗ API test failed: {e}")
        return False


def verify_table_structure():
    """Verify the table structure meets Dify requirements"""
    print("\n=== Verifying Table Structure ===")

    expected_columns = {
        "id": "VARCHAR",
        "page_content": "VARCHAR",
        "metadata": "VARCHAR",  # JSON stored as VARCHAR in Clickzetta
        "vector": "ARRAY<FLOAT>",
    }

    expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]

    print("✓ Expected table structure:")
    for col, dtype in expected_columns.items():
        print(f" - {col}: {dtype}")

    print("\n✓ Required metadata fields:")
    for field in expected_metadata_fields:
        print(f" - {field}")

    print("\n✓ Index requirements:")
    print(" - Vector index (HNSW) on 'vector' column")
    print(" - Full-text index on 'page_content' (optional)")
    print(" - Functional index on metadata->>'$.doc_id' (recommended)")
    print(" - Functional index on metadata->>'$.document_id' (recommended)")

    return True


def main():
    """Run all tests"""
    print("Starting Clickzetta integration tests for Dify Docker\n")

    tests = [
        ("Direct Clickzetta Connection", test_clickzetta_connection),
        ("Dify API Status", test_dify_api),
        ("Table Structure Verification", verify_table_structure),
    ]

    results = []
    for test_name, test_func in tests:
        try:
            success = test_func()
            results.append((test_name, success))
        except Exception as e:
            print(f"\n✗ {test_name} crashed: {e}")
            results.append((test_name, False))

    # Summary
    print("\n" + "=" * 50)
    print("Test Summary:")
    print("=" * 50)

    passed = sum(1 for _, success in results if success)
    total = len(results)

    for test_name, success in results:
        status = "✅ PASSED" if success else "❌ FAILED"
        print(f"{test_name}: {status}")

    print(f"\nTotal: {passed}/{total} tests passed")

    if passed == total:
        print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
        print("\nNext steps:")
        print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
        print("2. Access Dify at http://localhost:3000")
        print("3. Create a dataset and test vector storage with Clickzetta")
        return 0
    else:
        print("\n⚠️ Some tests failed. Please check the errors above.")
        return 1


if __name__ == "__main__":
    exit(main())
@@ -0,0 +1,49 @@
import subprocess
import time

from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
)


def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
    start_time = time.time()
    while time.time() - start_time < timeout:
        result = subprocess.run(
            ["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True
        )
        if result.stdout.strip() == "healthy":
            print(f"{service_name} is healthy!")
            return True
        else:
            print(f"Waiting for {service_name} to be healthy...")
            time.sleep(10)
    raise TimeoutError(f"{service_name} did not become healthy in time")


class CouchbaseTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = CouchbaseVector(
            collection_name=self.collection_name,
            config=CouchbaseConfig(
                connection_string="couchbase://127.0.0.1",
                user="Administrator",
                password="password",
                bucket_name="Embeddings",
                scope_name="_default",
            ),
        )

    def search_by_vector(self):
        # brief sleep to ensure document is indexed
        time.sleep(5)
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1


def test_couchbase(setup_mock_redis):
    wait_for_healthy_container("couchbase-server", timeout=60)
    CouchbaseTest().run_all_tests()
@@ -0,0 +1,22 @@
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
)


class ElasticSearchVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
        self.vector = ElasticSearchVector(
            index_name=self.collection_name.lower(),
            config=ElasticSearchConfig(
                use_cloud=False, host="http://localhost", port="9200", username="elastic", password="elastic"
            ),
            attributes=self.attributes,
        )


def test_elasticsearch_vector(setup_mock_redis):
    ElasticSearchVectorTest().run_all_tests()
@@ -0,0 +1,28 @@
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
from tests.integration_tests.vdb.__mock.huaweicloudvectordb import setup_client_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis


class HuaweiCloudVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = HuaweiCloudVector(
            "dify",
            HuaweiCloudVectorConfig(
                hosts="https://127.0.0.1:9200",
                username="dify",
                password="dify",
            ),
        )

    def search_by_vector(self):
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 3

    def search_by_full_text(self):
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 3


def test_huawei_cloud_vector(setup_mock_redis, setup_client_mock):
    HuaweiCloudVectorTest().run_all_tests()
58
dify/api/tests/integration_tests/vdb/lindorm/test_lindorm.py
Normal file
@@ -0,0 +1,58 @@
import os

from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis


class Config:
    SEARCH_ENDPOINT = os.environ.get(
        "SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070"
    )
    SEARCH_USERNAME = os.environ.get("SEARCH_USERNAME", "ADMIN")
    SEARCH_PWD = os.environ.get("SEARCH_PWD", "ADMIN")
    USING_UGC = os.environ.get("USING_UGC", "True").lower() == "true"


class TestLindormVectorStore(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = LindormVectorStore(
            collection_name=self.collection_name,
            config=LindormVectorStoreConfig(
                hosts=Config.SEARCH_ENDPOINT,
                username=Config.SEARCH_USERNAME,
                password=Config.SEARCH_PWD,
            ),
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert ids is not None
        assert len(ids) == 1
        assert ids[0] == self.example_doc_id


class TestLindormVectorStoreUGC(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = LindormVectorStore(
            collection_name="ugc_index_test",
            config=LindormVectorStoreConfig(
                hosts=Config.SEARCH_ENDPOINT,
                username=Config.SEARCH_USERNAME,
                password=Config.SEARCH_PWD,
                using_ugc=Config.USING_UGC,
            ),
            routing_value=self.collection_name,
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert ids is not None
        assert len(ids) == 1
        assert ids[0] == self.example_doc_id


def test_lindorm_vector_ugc(setup_mock_redis):
    TestLindormVectorStore().run_all_tests()
    TestLindormVectorStoreUGC().run_all_tests()
@@ -0,0 +1,24 @@
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
)


class MatrixoneVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = MatrixoneVector(
            collection_name=self.collection_name,
            config=MatrixoneConfig(
                host="localhost", port=6001, user="dump", password="111", database="dify", metric="l2"
            ),
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1


def test_matrixone_vector(setup_mock_redis):
    MatrixoneVectorTest().run_all_tests()
32
dify/api/tests/integration_tests/vdb/milvus/test_milvus.py
Normal file
@@ -0,0 +1,32 @@
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    get_example_text,
    setup_mock_redis,
)


class MilvusVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = MilvusVector(
            collection_name=self.collection_name,
            config=MilvusConfig(
                uri="http://localhost:19530",
                user="root",
                password="Milvus",
            ),
        )

    def search_by_full_text(self):
        # Milvus supports BM25 full-text search since version 2.5.0-beta
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) >= 0

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1


def test_milvus_vector(setup_mock_redis):
    MilvusVectorTest().run_all_tests()
29
dify/api/tests/integration_tests/vdb/myscale/test_myscale.py
Normal file
@@ -0,0 +1,29 @@
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleConfig, MyScaleVector
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
)


class MyScaleVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = MyScaleVector(
            collection_name=self.collection_name,
            config=MyScaleConfig(
                host="localhost",
                port=8123,
                user="default",
                password="",
                database="dify",
                fts_params="",
            ),
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1


def test_myscale_vector(setup_mock_redis):
    MyScaleVectorTest().run_all_tests()
@@ -0,0 +1,42 @@
import pytest

from core.rag.datasource.vdb.oceanbase.oceanbase_vector import (
    OceanBaseVector,
    OceanBaseVectorConfig,
)
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
)


@pytest.fixture
def oceanbase_vector():
    return OceanBaseVector(
        "dify_test_collection",
        config=OceanBaseVectorConfig(
            host="127.0.0.1",
            port=2881,
            user="root",
            database="test",
            password="difyai123456",
            enable_hybrid_search=True,
        ),
    )


class OceanBaseVectorTest(AbstractVectorTest):
    def __init__(self, vector: OceanBaseVector):
        super().__init__()
        self.vector = vector

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1


def test_oceanbase_vector(
    setup_mock_redis,
    oceanbase_vector,
):
    OceanBaseVectorTest(oceanbase_vector).run_all_tests()
@@ -0,0 +1,41 @@
import time

import psycopg2

from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
)


class OpenGaussTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        max_retries = 5
        retry_delay = 20
        retry_count = 0
        while retry_count < max_retries:
            try:
                config = OpenGaussConfig(
                    host="localhost",
                    port=6600,
                    user="postgres",
                    password="Dify@123",
                    database="dify",
                    min_connection=1,
                    max_connection=5,
                )
                break
            except psycopg2.OperationalError as e:
                retry_count += 1
                if retry_count < max_retries:
                    time.sleep(retry_delay)
        self.vector = OpenGauss(
            collection_name=self.collection_name,
            config=config,
        )


def test_opengauss(setup_mock_redis):
    OpenGaussTest().run_all_tests()
@@ -0,0 +1,237 @@
from unittest.mock import MagicMock, patch

import pytest

from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchConfig, OpenSearchVector
from core.rag.models.document import Document
from extensions import ext_redis


def get_example_text() -> str:
    return "This is a sample text for testing purposes."


@pytest.fixture(scope="module")
def setup_mock_redis():
    ext_redis.redis_client.get = MagicMock(return_value=None)
    ext_redis.redis_client.set = MagicMock(return_value=None)

    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)


class TestOpenSearchConfig:
    def test_to_opensearch_params(self):
        config = OpenSearchConfig(
            host="localhost",
            port=9200,
            secure=True,
            user="admin",
            password="password",
        )

        params = config.to_opensearch_params()

        assert params["hosts"] == [{"host": "localhost", "port": 9200}]
        assert params["use_ssl"] is True
        assert params["verify_certs"] is True
        assert params["connection_class"].__name__ == "Urllib3HttpConnection"
        assert params["http_auth"] == ("admin", "password")

    @patch("boto3.Session")
    @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
    def test_to_opensearch_params_with_aws_managed_iam(
        self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
    ):
        mock_credentials = MagicMock()
        mock_boto_session.return_value.get_credentials.return_value = mock_credentials

        mock_auth_instance = MagicMock()
        mock_aws_signer_auth.return_value = mock_auth_instance

        aws_region = "ap-southeast-2"
        aws_service = "aoss"
        host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
        port = 9201

        config = OpenSearchConfig(
            host=host,
            port=port,
            secure=True,
            auth_method="aws_managed_iam",
            aws_region=aws_region,
            aws_service=aws_service,
        )

        params = config.to_opensearch_params()

        assert params["hosts"] == [{"host": host, "port": port}]
        assert params["use_ssl"] is True
        assert params["verify_certs"] is True
        assert params["connection_class"].__name__ == "Urllib3HttpConnection"
        assert params["http_auth"] is mock_auth_instance

        mock_aws_signer_auth.assert_called_once_with(
            credentials=mock_credentials, region=aws_region, service=aws_service
        )
        assert mock_boto_session.return_value.get_credentials.called


class TestOpenSearchVector:
    def setup_method(self):
        self.collection_name = "test_collection"
        self.example_doc_id = "example_doc_id"
        self.vector = OpenSearchVector(
            collection_name=self.collection_name,
            config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
        )
        self.vector._client = MagicMock()

    @pytest.mark.parametrize(
        ("search_response", "expected_length", "expected_doc_id"),
        [
            (
                {
                    "hits": {
                        "total": {"value": 1},
                        "hits": [
                            {
                                "_source": {
                                    "page_content": get_example_text(),
                                    "metadata": {"document_id": "example_doc_id"},
                                }
                            }
                        ],
                    }
                },
                1,
                "example_doc_id",
            ),
            ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None),
        ],
    )
    def test_search_by_full_text(self, search_response, expected_length, expected_doc_id):
        self.vector._client.search.return_value = search_response

        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == expected_length
        if expected_length > 0:
            assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id

    def test_search_by_vector(self):
        vector = [0.1] * 128
        mock_response = {
            "hits": {
                "total": {"value": 1},
                "hits": [
                    {
                        "_source": {
                            Field.CONTENT_KEY: get_example_text(),
                            Field.METADATA_KEY: {"document_id": self.example_doc_id},
                        },
                        "_score": 1.0,
                    }
                ],
            }
        }
        self.vector._client.search.return_value = mock_response

        hits_by_vector = self.vector.search_by_vector(query_vector=vector)

        print("Hits by vector:", hits_by_vector)
        print("Expected document ID:", self.example_doc_id)
        print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits")

        assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}"
        assert hits_by_vector[0].metadata["document_id"] == self.example_doc_id, (
            f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
        )

    def test_get_ids_by_metadata_field(self):
        mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
        self.vector._client.search.return_value = mock_response

        doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
        embedding = [0.1] * 128

        with patch("opensearchpy.helpers.bulk") as mock_bulk:
            mock_bulk.return_value = ([], [])
            self.vector.add_texts([doc], [embedding])

        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1
        assert ids[0] == "mock_id"

    def test_add_texts(self):
        self.vector._client.index.return_value = {"result": "created"}

        doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
        embedding = [0.1] * 128

        with patch("opensearchpy.helpers.bulk") as mock_bulk:
            mock_bulk.return_value = ([], [])
            self.vector.add_texts([doc], [embedding])

        mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
        self.vector._client.search.return_value = mock_response

        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1
        assert ids[0] == "mock_id"

    def test_delete_nonexistent_index(self):
        """Test deleting a non-existent index."""
        # Create a vector instance with a non-existent collection name
        self.vector._client.indices.exists.return_value = False

        # Should not raise an exception
        self.vector.delete()

        # Verify that exists was called but delete was not
        self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower())
        self.vector._client.indices.delete.assert_not_called()

    def test_delete_existing_index(self):
        """Test deleting an existing index."""
        self.vector._client.indices.exists.return_value = True

        self.vector.delete()

        # Verify both exists and delete were called
        self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower())
        self.vector._client.indices.delete.assert_called_once_with(index=self.collection_name.lower())


@pytest.mark.usefixtures("setup_mock_redis")
class TestOpenSearchVectorWithRedis:
    def setup_method(self):
        self.tester = TestOpenSearchVector()

    def test_search_by_full_text(self):
        self.tester.setup_method()
        search_response = {
            "hits": {
                "total": {"value": 1},
                "hits": [
                    {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}}
                ],
            }
        }
        expected_length = 1
        expected_doc_id = "example_doc_id"
        self.tester.test_search_by_full_text(search_response, expected_length, expected_doc_id)

    def test_get_ids_by_metadata_field(self):
        self.tester.setup_method()
        self.tester.test_get_ids_by_metadata_field()

    def test_add_texts(self):
        self.tester.setup_method()
        self.tester.test_add_texts()

    def test_search_by_vector(self):
        self.tester.setup_method()
        self.tester.test_search_by_vector()
@@ -0,0 +1,28 @@
from core.rag.datasource.vdb.oracle.oraclevector import OracleVector, OracleVectorConfig
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    get_example_text,
    setup_mock_redis,
)


class OracleVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = OracleVector(
            collection_name=self.collection_name,
            config=OracleVectorConfig(
                user="dify",
                password="dify",
                dsn="localhost:1521/FREEPDB1",
            ),
        )

    def search_by_full_text(self):
        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0


def test_oraclevector(setup_mock_redis):
    OracleVectorTest().run_all_tests()
@@ -0,0 +1,35 @@
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    get_example_text,
    setup_mock_redis,
)


class PGVectoRSVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = PGVectoRS(
            collection_name=self.collection_name.lower(),
            config=PgvectoRSConfig(
                host="localhost",
                port=5431,
                user="postgres",
                password="difyai123456",
                database="dify",
            ),
            dim=128,
        )

    def search_by_full_text(self):
        # pgvecto-rs only supports English full-text search, so it is not enabled for now
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1


def test_pgvecto_rs(setup_mock_redis):
    PGVectoRSVectorTest().run_all_tests()
@@ -0,0 +1,27 @@
from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    get_example_text,
    setup_mock_redis,
)


class PGVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = PGVector(
            collection_name=self.collection_name,
            config=PGVectorConfig(
                host="localhost",
                port=5433,
                user="postgres",
                password="difyai123456",
                database="dify",
                min_connection=1,
                max_connection=5,
            ),
        )


def test_pgvector(setup_mock_redis):
    PGVectorTest().run_all_tests()
@@ -0,0 +1,26 @@
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVector, VastbaseVectorConfig
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
)


class VastbaseVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = VastbaseVector(
            collection_name=self.collection_name,
            config=VastbaseVectorConfig(
                host="localhost",
                port=5434,
                user="dify",
                password="Difyai123456",
                database="dify",
                min_connection=1,
                max_connection=5,
            ),
        )


def test_vastbase_vector(setup_mock_redis):
    VastbaseVectorTest().run_all_tests()
32
dify/api/tests/integration_tests/vdb/qdrant/test_qdrant.py
Normal file
@@ -0,0 +1,32 @@
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
)


class QdrantVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
        self.vector = QdrantVector(
            collection_name=self.collection_name,
            group_id=self.dataset_id,
            config=QdrantConfig(
                endpoint="http://localhost:6333",
                api_key="difyai123456",
            ),
        )

    def search_by_vector(self):
        super().search_by_vector()
        # only test for qdrant, may not work on other vector stores
        hits_by_vector: list[Document] = self.vector.search_by_vector(
            query_vector=self.example_embedding, score_threshold=1
        )
        assert len(hits_by_vector) == 0


def test_qdrant_vector(setup_mock_redis):
    QdrantVectorTest().run_all_tests()
@@ -0,0 +1,100 @@
import os
import uuid

import tablestore
from _pytest.python_api import approx

from core.rag.datasource.vdb.tablestore.tablestore_vector import (
    TableStoreConfig,
    TableStoreVector,
)
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    get_example_document,
    get_example_text,
    setup_mock_redis,
)


class TableStoreVectorTest(AbstractVectorTest):
    def __init__(self, normalize_full_text_score: bool = False):
        super().__init__()
        self.vector = TableStoreVector(
            collection_name=self.collection_name,
            config=TableStoreConfig(
                endpoint=os.getenv("TABLESTORE_ENDPOINT"),
                instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"),
                access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"),
                access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"),
                normalize_full_text_bm25_score=normalize_full_text_score,
            ),
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert ids is not None
        assert len(ids) == 1
        assert ids[0] == self.example_doc_id

    def create_vector(self):
        self.vector.create(
            texts=[get_example_document(doc_id=self.example_doc_id)],
            embeddings=[self.example_embedding],
        )
        while True:
            search_response = self.vector._tablestore_client.search(
                table_name=self.vector._table_name,
                index_name=self.vector._index_name,
                search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0),
                columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
            )
            if search_response.total_count == 1:
                break

    def search_by_vector(self):
        super().search_by_vector()
        docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id])
        assert len(docs) == 1
        assert docs[0].metadata["doc_id"] == self.example_doc_id
        assert docs[0].metadata["score"] > 0

        docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())])
        assert len(docs) == 0

    def search_by_full_text(self):
        super().search_by_full_text()
        docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
        assert len(docs) == 1
        assert docs[0].metadata["doc_id"] == self.example_doc_id
        if self.vector._config.normalize_full_text_bm25_score:
            assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3)
        else:
            assert docs[0].metadata.get("score") is None

        # return none if normalize_full_text_score=true and score_threshold > 0
        docs = self.vector.search_by_full_text(
            get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5
        )
        if self.vector._config.normalize_full_text_bm25_score:
            assert len(docs) == 0
        else:
            assert len(docs) == 1
            assert docs[0].metadata["doc_id"] == self.example_doc_id
            assert docs[0].metadata.get("score") is None

        docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
        assert len(docs) == 0

    def run_all_tests(self):
        try:
            self.vector.delete()
        except Exception:
            pass

        return super().run_all_tests()


def test_tablestore_vector(setup_mock_redis):
    TableStoreVectorTest().run_all_tests()
    TableStoreVectorTest(normalize_full_text_score=True).run_all_tests()
    TableStoreVectorTest(normalize_full_text_score=False).run_all_tests()
@@ -0,0 +1,38 @@
from unittest.mock import MagicMock

from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis

mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]


class TencentVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = TencentVector(
            "dify",
            TencentConfig(
                url="http://127.0.0.1",
                api_key="dify",
                timeout=30,
                username="dify",
                database="dify",
                shard=1,
                replicas=2,
                enable_hybrid_search=True,
            ),
        )

    def search_by_vector(self):
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1

    def search_by_full_text(self):
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) >= 0


def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
    TencentVectorTest().run_all_tests()
95
dify/api/tests/integration_tests/vdb/test_vector_store.py
Normal file
@@ -0,0 +1,95 @@
import uuid
from unittest.mock import MagicMock

import pytest

from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset


def get_example_text() -> str:
    return "test_text"


def get_example_document(doc_id: str) -> Document:
    doc = Document(
        page_content=get_example_text(),
        metadata={
            "doc_id": doc_id,
            "doc_hash": doc_id,
            "document_id": doc_id,
            "dataset_id": doc_id,
        },
    )
    return doc


@pytest.fixture
def setup_mock_redis():
    # get
    ext_redis.redis_client.get = MagicMock(return_value=None)

    # set
    ext_redis.redis_client.set = MagicMock(return_value=None)

    # lock
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = mock_redis_lock


class AbstractVectorTest:
    def __init__(self):
        self.vector = None
        self.dataset_id = str(uuid.uuid4())
        self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
        self.example_doc_id = str(uuid.uuid4())
        self.example_embedding = [1.001 * i for i in range(128)]

    def create_vector(self):
        self.vector.create(
            texts=[get_example_document(doc_id=self.example_doc_id)],
            embeddings=[self.example_embedding],
        )

    def search_by_vector(self):
        hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1
        assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id

    def search_by_full_text(self):
        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 1
        assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id

    def delete_vector(self):
        self.vector.delete()

    def delete_by_ids(self, ids: list[str]):
        self.vector.delete_by_ids(ids=ids)

    def add_texts(self) -> list[str]:
        batch_size = 100
        documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
        embeddings = [self.example_embedding] * batch_size
        self.vector.add_texts(documents=documents, embeddings=embeddings)
        return [doc.metadata["doc_id"] for doc in documents]

    def text_exists(self):
        assert self.vector.text_exists(self.example_doc_id)

    def get_ids_by_metadata_field(self):
        with pytest.raises(NotImplementedError):
            self.vector.get_ids_by_metadata_field(key="key", value="value")

    def run_all_tests(self):
        self.create_vector()
        self.search_by_vector()
        self.search_by_full_text()
        self.text_exists()
        self.get_ids_by_metadata_field()
        added_doc_ids = self.add_texts()
        self.delete_by_ids(added_doc_ids)
        self.delete_vector()
@@ -0,0 +1,59 @@
import time

import pymysql


def check_tiflash_ready() -> bool:
    connection = None
    try:
        connection = pymysql.connect(
            host="localhost",
            port=4000,
            user="root",
            password="",
        )

        with connection.cursor() as cursor:
            # Doc reference:
            # https://docs.pingcap.com/zh/tidb/stable/information-schema-cluster-hardware
            select_tiflash_query = """
                SELECT * FROM information_schema.cluster_hardware
                WHERE TYPE='tiflash'
                LIMIT 1;
            """
            cursor.execute(select_tiflash_query)
            result = cursor.fetchall()
            return result is not None and len(result) > 0
    except Exception as e:
        print(f"TiFlash is not ready. Exception: {e}")
        return False
    finally:
        if connection:
            connection.close()


def main():
    max_attempts = 30
    retry_interval_seconds = 2
    is_tiflash_ready = False
    for attempt in range(max_attempts):
        try:
            is_tiflash_ready = check_tiflash_ready()
        except Exception as e:
            print(f"TiFlash is not ready. Exception: {e}")
            is_tiflash_ready = False

        if is_tiflash_ready:
            break
        else:
            print(f"Attempt {attempt + 1} failed, retry in {retry_interval_seconds} seconds...")
            time.sleep(retry_interval_seconds)

    if is_tiflash_ready:
        print("TiFlash is ready in TiDB.")
    else:
        print(f"TiFlash is not ready in TiDB after {max_attempts} attempts.")
        exit(1)


if __name__ == "__main__":
    main()
@@ -0,0 +1,40 @@
import pytest

from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
from models.dataset import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis


@pytest.fixture
def tidb_vector():
    return TiDBVector(
        collection_name="test_collection",
        config=TiDBVectorConfig(
            host="localhost",
            port=4000,
            user="root",
            password="",
            database="test",
            program_name="langgenius/dify",
        ),
    )


class TiDBVectorTest(AbstractVectorTest):
    def __init__(self, vector):
        super().__init__()
        self.vector = vector

    def search_by_full_text(self):
        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert len(ids) == 1


def test_tidb_vector(setup_mock_redis, tidb_vector):
    # TiDBVectorTest(vector=tidb_vector).run_all_tests()
    # something wrong with tidb, ignore tidb test
    return
@@ -0,0 +1,28 @@
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVector, UpstashVectorConfig
from core.rag.models.document import Document
from tests.integration_tests.vdb.__mock.upstashvectordb import setup_upstashvector_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text


class UpstashVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = UpstashVector(
            collection_name="test_collection",
            config=UpstashVectorConfig(
                url="your-server-url",
                token="your-access-token",
            ),
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) != 0

    def search_by_full_text(self):
        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0


def test_upstash_vector(setup_upstashvector_mock):
    UpstashVectorTest().run_all_tests()
@@ -0,0 +1,37 @@
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector
from tests.integration_tests.vdb.__mock.vikingdb import setup_vikingdb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis


class VikingDBVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = VikingDBVector(
            "test_collection",
            "test_group",
            config=VikingDBConfig(
                access_key="test_access_key",
                host="test_host",
                region="test_region",
                scheme="test_scheme",
                secret_key="test_secret_key",
                connection_timeout=30,
                socket_timeout=30,
            ),
        )

    def search_by_vector(self):
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1

    def search_by_full_text(self):
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value="test_document_id")
        assert len(ids) > 0


def test_vikingdb_vector(setup_mock_redis, setup_vikingdb_mock):
    VikingDBVectorTest().run_all_tests()
@@ -0,0 +1,23 @@
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    setup_mock_redis,
)


class WeaviateVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
        self.vector = WeaviateVector(
            collection_name=self.collection_name,
            config=WeaviateConfig(
                endpoint="http://localhost:8080",
                api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih",
            ),
            attributes=self.attributes,
        )


def test_weaviate_vector(setup_mock_redis):
    WeaviateVectorTest().run_all_tests()
@@ -0,0 +1,34 @@
import os
from typing import Literal

import pytest
from _pytest.monkeypatch import MonkeyPatch
from jinja2 import Template

from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage

MOCK = os.getenv("MOCK_SWITCH", "false") == "true"


class MockedCodeExecutor:
    @classmethod
    def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict):
        # invoke directly
        match language:
            case CodeLanguage.PYTHON3:
                return {"result": 3}
            case CodeLanguage.JINJA2:
                return {"result": Template(code).render(inputs)}
            case _:
                raise Exception("Language not supported")


@pytest.fixture
def setup_code_executor_mock(request, monkeypatch: MonkeyPatch):
    if not MOCK:
        yield
        return

    monkeypatch.setattr(CodeExecutor, "execute_workflow_code_template", MockedCodeExecutor.invoke)
    yield
    monkeypatch.undo()
@@ -0,0 +1,56 @@
import os
from json import dumps
from typing import Literal

import httpx
import pytest
from _pytest.monkeypatch import MonkeyPatch

from core.helper import ssrf_proxy

MOCK = os.getenv("MOCK_SWITCH", "false") == "true"


class MockedHttp:
    @staticmethod
    def httpx_request(
        method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
    ) -> httpx.Response:
        """
        Mocked httpx.request
        """
        if url == "http://404.com":
            response = httpx.Response(status_code=404, request=httpx.Request(method, url), content=b"Not Found")
            return response

        # get data, files
        data = kwargs.get("data")
        files = kwargs.get("files")
        json = kwargs.get("json")
        content = kwargs.get("content")
        if data is not None:
            resp = dumps(data).encode("utf-8")
        elif files is not None:
            resp = dumps(files).encode("utf-8")
        elif json is not None:
            resp = dumps(json).encode("utf-8")
        elif content is not None:
            resp = content
        else:
            resp = b"OK"

        response = httpx.Response(
            status_code=200, request=httpx.Request(method, url), headers=kwargs.get("headers", {}), content=resp
        )
        return response


@pytest.fixture
def setup_http_mock(request, monkeypatch: MonkeyPatch):
    if not MOCK:
        yield
        return

    monkeypatch.setattr(ssrf_proxy, "make_request", MockedHttp.httpx_request)
    yield
    monkeypatch.undo()
@@ -0,0 +1,50 @@
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from models.provider import ProviderType


def get_mocked_fetch_model_config(
    provider: str,
    model: str,
    mode: str,
    credentials: dict,
):
    model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
    model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
    provider_model_bundle = ProviderModelBundle(
        configuration=ProviderConfiguration(
            tenant_id="1",
            provider=model_provider_factory.get_provider_schema(provider),
            preferred_provider_type=ProviderType.CUSTOM,
            using_provider_type=ProviderType.CUSTOM,
            system_configuration=SystemConfiguration(enabled=False),
            custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)),
            model_settings=[],
        ),
        model_type_instance=model_type_instance,
    )
    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model)
    model_schema = model_provider_factory.get_model_schema(
        provider=provider,
        model_type=model_type_instance.model_type,
        model=model,
        credentials=credentials,
    )
    assert model_schema is not None
    model_config = ModelConfigWithCredentialsEntity(
        model=model,
        provider=provider,
        mode=mode,
        credentials=credentials,
        parameters={},
        model_schema=model_schema,
        provider_model_bundle=provider_model_bundle,
    )

    return MagicMock(return_value=(model_instance, model_config))
@@ -0,0 +1,11 @@
import pytest

from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor

CODE_LANGUAGE = "unsupported_language"


def test_unsupported_with_code_template():
    with pytest.raises(CodeExecutionError) as e:
        CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
    assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"
Some files were not shown because too many files have changed in this diff.