This commit is contained in:
2025-12-01 17:21:38 +08:00
parent 32fee2b8ab
commit fab8c13cb3
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,166 @@
import os
from collections import UserDict
from typing import Any
from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from pymochow import MochowClient
from pymochow.model.database import Database
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
from pymochow.model.schema import HNSWParams, VectorIndex
from pymochow.model.table import Table
class AttrDict(UserDict):
    """Dictionary whose keys can also be read as attributes.

    Missing keys resolve to ``None`` instead of raising AttributeError, which
    keeps the mocked SDK responses tolerant of optional fields.
    """

    def __getattr__(self, item):
        # Only invoked when normal attribute lookup fails, so real attributes
        # (including UserDict's ``data``) are unaffected.
        return self.data.get(item)
class MockBaiduVectorDBClass:
    """Mock implementations for the pymochow (Baidu VectorDB) SDK.

    The methods are patched onto MochowClient / Database / Table as unbound
    functions by the ``setup_baiduvectordb_mock`` fixture, so ``self`` is the
    patched SDK instance rather than a MockBaiduVectorDBClass.
    """

    def mock_vector_db_client(
        self,
        config=None,
        adapter: Any | None = None,
    ):
        # Replacement for MochowClient.__init__: no network connection is made.
        self.conn = MagicMock()
        self._config = MagicMock()

    def list_databases(self, config=None) -> list[Database]:
        """Report a single pre-existing database named "dify"."""
        return [
            Database(
                conn=self.conn,
                database_name="dify",
                config=self._config,
            )
        ]

    def create_database(self, database_name: str, config=None) -> Database:
        """Return a Database handle without creating anything server-side."""
        return Database(conn=self.conn, database_name=database_name, config=config)

    def list_table(self, config=None) -> list[Table]:
        # No tables exist initially, which forces callers down the create path.
        return []

    def drop_table(self, table_name: str, config=None):
        return {"code": 0, "msg": "Success"}

    def create_table(
        self,
        table_name: str,
        replication: int,
        partition: int,
        schema,
        enable_dynamic_field=False,
        description: str = "",
        config=None,
    ) -> Table:
        """Echo the requested settings back as a Table handle."""
        return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)

    def describe_table(self, table_name: str, config=None) -> Table:
        """Return a ready (NORMAL-state) table so callers do not wait on creation."""
        return Table(
            self,
            table_name,
            3,
            1,
            None,
            enable_dynamic_field=False,
            description="table for dify",
            config=config,
            state=TableState.NORMAL,
        )

    def upsert(self, rows, config=None):
        return {"code": 0, "msg": "operation success", "affectedCount": 1}

    def rebuild_index(self, index_name: str, config=None):
        return {"code": 0, "msg": "Success"}

    def describe_index(self, index_name: str, config=None):
        """Describe a vector index that is already built (NORMAL state)."""
        return VectorIndex(
            index_name=index_name,
            index_type=IndexType.HNSW,
            field="vector",
            metric_type=MetricType.L2,
            params=HNSWParams(m=16, efconstruction=200),
            auto_build=False,
            state=IndexState.NORMAL,
        )

    def query(
        self,
        primary_key,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        """Return one canned row, echoing the requested primary-key id."""
        return AttrDict(
            {
                "row": {
                    "id": primary_key.get("id"),
                    "vector": [0.23432432, 0.8923744, 0.89238432],
                    "page_content": "text",
                    "metadata": {"doc_id": "doc_id_001"},
                },
                "code": 0,
                "msg": "Success",
            }
        )

    def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
        return {"code": 0, "msg": "Success"}

    def search(
        self,
        anns,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        """Return a single canned search hit regardless of the query."""
        return AttrDict(
            {
                "rows": [
                    {
                        "row": {
                            "id": "doc_id_001",
                            "vector": [0.23432432, 0.8923744, 0.89238432],
                            "page_content": "text",
                            "metadata": {"doc_id": "doc_id_001"},
                        },
                        "distance": 0.1,
                        "score": 0.5,
                    }
                ],
                "code": 0,
                "msg": "Success",
            }
        )
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Swap pymochow's client/database/table entry points for mocks.

    Patching only happens when the MOCK_SWITCH environment variable is
    "true"; otherwise the fixture is a no-op and tests hit a real server.
    """
    if MOCK:
        replacements = [
            (MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client),
            (MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases),
            (MochowClient, "create_database", MockBaiduVectorDBClass.create_database),
            (Database, "table", MockBaiduVectorDBClass.describe_table),
            (Database, "list_table", MockBaiduVectorDBClass.list_table),
            (Database, "create_table", MockBaiduVectorDBClass.create_table),
            (Database, "drop_table", MockBaiduVectorDBClass.drop_table),
            (Database, "describe_table", MockBaiduVectorDBClass.describe_table),
            (Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index),
            (Table, "describe_index", MockBaiduVectorDBClass.describe_index),
            (Table, "delete", MockBaiduVectorDBClass.delete),
            (Table, "query", MockBaiduVectorDBClass.query),
            (Table, "search", MockBaiduVectorDBClass.search),
        ]
        for target, attribute, mock_impl in replacements:
            monkeypatch.setattr(target, attribute, mock_impl)
    yield
    if MOCK:
        monkeypatch.undo()

View File

@@ -0,0 +1,89 @@
import os
import pytest
from _pytest.monkeypatch import MonkeyPatch
from elasticsearch import Elasticsearch
from core.rag.datasource.vdb.field import Field
class MockIndicesClient:
    """Minimal stub of Elasticsearch's indices client: every call succeeds."""

    def __init__(self):
        pass

    def create(self, index, mappings, settings):
        # Pretend the index was created; arguments are accepted and ignored.
        return dict(acknowledge=True)

    def refresh(self, index):
        return dict(acknowledge=True)

    def delete(self, index):
        return dict(acknowledge=True)

    def exists(self, index):
        # Any index name is reported as present.
        return True
class MockClient:
    """Stub of the Elasticsearch client used when MOCK_SWITCH is enabled.

    Write operations acknowledge unconditionally; ``search`` returns three
    canned hits in descending score order.
    """

    def __init__(self, **kwargs):
        self.indices = MockIndicesClient()

    def index(self, **kwargs):
        return dict(acknowledge=True)

    def exists(self, **kwargs):
        return True

    def delete(self, **kwargs):
        return dict(acknowledge=True)

    def search(self, **kwargs):
        # (content, vector, score) triples for the canned result set.
        canned = [
            ("abcdef", [1, 2], 1.0),
            ("123456", [2, 2], 0.9),
            ("a1b2c3", [3, 2], 0.8),
        ]
        hits = [
            {
                "_source": {
                    Field.CONTENT_KEY: content,
                    Field.VECTOR: vector,
                    Field.METADATA_KEY: {},
                },
                "_score": score,
            }
            for content, vector, score in canned
        ]
        return {"took": 1, "hits": {"hits": hits}}
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_client_mock(request, monkeypatch: MonkeyPatch):
    """Patch Elasticsearch entry points with MockClient when MOCK_SWITCH=true."""
    if MOCK:
        # MockClient defines methods with the same names as the real client,
        # so the patch table is just the attribute names.
        for attribute in ("__init__", "index", "exists", "delete", "search"):
            monkeypatch.setattr(Elasticsearch, attribute, getattr(MockClient, attribute))
    yield
    if MOCK:
        monkeypatch.undo()

View File

@@ -0,0 +1,192 @@
import os
from typing import Any, Union
import pytest
from _pytest.monkeypatch import MonkeyPatch
from tcvectordb import RPCVectorDBClient
from tcvectordb.model import enum
from tcvectordb.model.collection import FilterIndexConfig
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank
from tcvectordb.model.enum import ReadConsistency
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex
from tcvectordb.rpc.model.collection import RPCCollection
from tcvectordb.rpc.model.database import RPCDatabase
from xinference_client.types import Embedding
class MockTcvectordbClass:
    """Mock implementations for the tcvectordb (Tencent VectorDB) RPC client.

    The methods are patched onto RPCVectorDBClient as unbound functions by the
    ``setup_tcvectordb_mock`` fixture, so ``self`` is the patched client
    instance rather than a MockTcvectordbClass.
    """

    def mock_vector_db_client(
        self,
        url: str,
        username="",
        key="",
        read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
        timeout=10,
        adapter: Any | None = None,
        pool_size: int = 2,
        proxies: dict | None = None,
        password: str | None = None,
        **kwargs,
    ):
        # Replacement for RPCVectorDBClient.__init__: no RPC channel is opened.
        self._conn = None
        self._read_consistency = read_consistency

    def create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase:
        """Always return a handle to a database named "dify"."""
        return RPCDatabase(
            name="dify",
            read_consistency=self._read_consistency,
        )

    def exists_collection(self, database_name: str, collection_name: str) -> bool:
        # Every collection is reported as existing.
        return True

    def describe_collection(
        self, database_name: str, collection_name: str, timeout: float | None = None
    ) -> RPCCollection:
        """Return a collection whose index layout matches what Dify creates."""
        index = Index(
            FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
            VectorIndex(
                "vector",
                128,
                enum.IndexType.HNSW,
                enum.MetricType.IP,
                HNSWParams(m=16, efconstruction=200),
            ),
            FilterIndex("text", enum.FieldType.String, enum.IndexType.FILTER),
            FilterIndex("metadata", enum.FieldType.String, enum.IndexType.FILTER),
        )
        return RPCCollection(
            RPCDatabase(
                name=database_name,
                read_consistency=self._read_consistency,
            ),
            collection_name,
            index=index,
        )

    def create_collection(
        self,
        database_name: str,
        collection_name: str,
        shard: int,
        replicas: int,
        description: str | None = None,
        index: Index | None = None,
        embedding: Embedding | None = None,
        timeout: float | None = None,
        ttl_config: dict | None = None,
        filter_index_config: FilterIndexConfig | None = None,
        indexes: list[IndexField] | None = None,
    ) -> RPCCollection:
        """Echo the requested settings back as an RPCCollection handle."""
        return RPCCollection(
            RPCDatabase(
                name="dify",
                read_consistency=self._read_consistency,
            ),
            collection_name,
            shard,
            replicas,
            description,
            index,
            embedding=embedding,
            read_consistency=self._read_consistency,
            timeout=timeout,
            ttl_config=ttl_config,
            filter_index_config=filter_index_config,
            indexes=indexes,
        )

    def collection_upsert(
        self,
        database_name: str,
        collection_name: str,
        documents: list[Union[Document, dict]],
        timeout: float | None = None,
        build_index: bool = True,
        **kwargs,
    ):
        return {"code": 0, "msg": "operation success"}

    def collection_search(
        self,
        database_name: str,
        collection_name: str,
        vectors: list[list[float]],
        filter: Filter | None = None,
        params=None,
        retrieve_vector: bool = False,
        limit: int = 10,
        output_fields: list[str] | None = None,
        timeout: float | None = None,
    ) -> list[list[dict]]:
        """Return one canned hit per query batch (metadata is a dict here)."""
        return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_hybrid_search(
        self,
        database_name: str,
        collection_name: str,
        ann: Union[list[AnnSearch], AnnSearch] | None = None,
        match: Union[list[KeywordSearch], KeywordSearch] | None = None,
        filter: Union[Filter, str] | None = None,
        rerank: Rerank | None = None,
        retrieve_vector: bool | None = None,
        output_fields: list[str] | None = None,
        limit: int | None = None,
        timeout: float | None = None,
        return_pd_object=False,
        **kwargs,
    ) -> list[list[dict]]:
        """Return the same canned hit as collection_search."""
        return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_query(
        self,
        database_name: str,
        collection_name: str,
        document_ids: list | None = None,
        retrieve_vector: bool = False,
        limit: int | None = None,
        offset: int | None = None,
        filter: Filter | None = None,
        output_fields: list[str] | None = None,
        timeout: float | None = None,
    ):
        # NOTE: metadata here is a JSON *string*, unlike the dict returned by
        # the search methods above — callers are expected to parse it.
        return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]

    def collection_delete(
        self,
        database_name: str,
        collection_name: str,
        document_ids: list[str] | None = None,
        filter: Filter | None = None,
        timeout: float | None = None,
    ):
        return {"code": 0, "msg": "operation success"}

    def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None):
        return {"code": 0, "msg": "operation success"}
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Route RPCVectorDBClient calls to MockTcvectordbClass when MOCK_SWITCH=true."""
    if MOCK:
        # Client attribute name -> mock implementation.
        replacements = {
            "__init__": MockTcvectordbClass.mock_vector_db_client,
            "create_database_if_not_exists": MockTcvectordbClass.create_database_if_not_exists,
            "exists_collection": MockTcvectordbClass.exists_collection,
            "create_collection": MockTcvectordbClass.create_collection,
            "describe_collection": MockTcvectordbClass.describe_collection,
            "upsert": MockTcvectordbClass.collection_upsert,
            "search": MockTcvectordbClass.collection_search,
            "hybrid_search": MockTcvectordbClass.collection_hybrid_search,
            "query": MockTcvectordbClass.collection_query,
            "delete": MockTcvectordbClass.collection_delete,
            "drop_collection": MockTcvectordbClass.drop_collection,
        }
        for attribute, mock_impl in replacements.items():
            monkeypatch.setattr(RPCVectorDBClient, attribute, mock_impl)
    yield
    if MOCK:
        monkeypatch.undo()

View File

@@ -0,0 +1,75 @@
import os
from collections import UserDict
import pytest
from _pytest.monkeypatch import MonkeyPatch
from upstash_vector import Index
# Mocking the Index class from upstash_vector
class MockIndex:
def __init__(self, url="", token=""):
self.url = url
self.token = token
self.vectors = []
def upsert(self, vectors):
for vector in vectors:
vector.score = 0.5
self.vectors.append(vector)
return {"code": 0, "msg": "operation success", "affectedCount": len(vectors)}
def fetch(self, ids):
return [vector for vector in self.vectors if vector.id in ids]
def delete(self, ids):
self.vectors = [vector for vector in self.vectors if vector.id not in ids]
return {"code": 0, "msg": "Success"}
def query(
self,
vector: None,
top_k: int = 10,
include_vectors: bool = False,
include_metadata: bool = False,
filter: str = "",
data: str | None = None,
namespace: str = "",
include_data: bool = False,
):
# Simple mock query, in real scenario you would calculate similarity
mock_result = []
for vector_data in self.vectors:
mock_result.append(vector_data)
return mock_result[:top_k]
def reset(self):
self.vectors = []
def info(self):
return AttrDict({"dimension": 1024})
class AttrDict(UserDict):
    """UserDict variant exposing keys as attributes; absent keys yield ``None``."""

    def __getattr__(self, name):
        # Reached only when regular attribute lookup fails; fall back to the
        # underlying mapping without raising for unknown keys.
        return self.data.get(name)
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_upstashvector_mock(request, monkeypatch: MonkeyPatch):
    """Route upstash_vector.Index calls to MockIndex when MOCK_SWITCH=true."""
    if MOCK:
        # MockIndex mirrors the real Index method names exactly.
        for attribute in ("__init__", "upsert", "fetch", "delete", "query", "reset", "info"):
            monkeypatch.setattr(Index, attribute, getattr(MockIndex, attribute))
    yield
    if MOCK:
        monkeypatch.undo()

View File

@@ -0,0 +1,215 @@
import os
from typing import Union
from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from volcengine.viking_db import (
Collection,
Data,
DistanceType,
Field,
FieldType,
Index,
IndexType,
QuantType,
VectorIndexParams,
VikingDBService,
)
from core.rag.datasource.vdb.field import Field as vdb_Field
class MockVikingDBClass:
    """Mock implementations for the volcengine VikingDB SDK.

    ``__init__`` replaces ``VikingDBService.__init__``; the remaining methods
    are patched onto VikingDBService / Collection / Index as unbound functions
    by the ``setup_vikingdb_mock`` fixture.
    """

    def __init__(
        self,
        host="api-vikingdb.volces.com",
        region="cn-north-1",
        ak="",
        sk="",
        scheme="http",
        connection_timeout=30,
        socket_timeout=30,
        proxy=None,
    ):
        # No real service connection; get_exception yields a canned JSON payload.
        self._viking_db_service = MagicMock()
        self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')

    def get_collection(self, collection_name) -> Collection:
        """Return a fully-populated collection with Dify's field/index layout."""
        return Collection(
            collection_name=collection_name,
            description="Collection For Dify",
            viking_db_service=self._viking_db_service,
            primary_key=vdb_Field.PRIMARY_KEY,
            fields=[
                Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
                Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
                Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
                Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
                Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=768),
            ],
            indexes=[
                Index(
                    collection_name=collection_name,
                    index_name=f"{collection_name}_idx",
                    vector_index=VectorIndexParams(
                        distance=DistanceType.L2,
                        index_type=IndexType.HNSW,
                        quant=QuantType.Float,
                    ),
                    scalar_index=None,
                    stat=None,
                    viking_db_service=self._viking_db_service,
                )
            ],
        )

    def drop_collection(self, collection_name):
        # The mock only validates that a non-empty collection name was given.
        assert collection_name != ""

    def create_collection(self, collection_name, fields, description="") -> Collection:
        """Echo the requested fields back as a Collection handle."""
        return Collection(
            collection_name=collection_name,
            description=description,
            primary_key=vdb_Field.PRIMARY_KEY,
            viking_db_service=self._viking_db_service,
            fields=fields,
        )

    def get_index(self, collection_name, index_name) -> Index:
        """Return an HNSW/L2 vector index handle."""
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            viking_db_service=self._viking_db_service,
            stat=None,
            scalar_index=None,
            vector_index=VectorIndexParams(
                distance=DistanceType.L2,
                index_type=IndexType.HNSW,
                quant=QuantType.Float,
            ),
        )

    def create_index(
        self,
        collection_name,
        index_name,
        vector_index=None,
        cpu_quota=2,
        description="",
        partition_by="",
        scalar_index=None,
        shard_count=None,
        shard_policy=None,
    ):
        """Echo the requested settings back as an Index handle."""
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            vector_index=vector_index,
            cpu_quota=cpu_quota,
            description=description,
            partition_by=partition_by,
            scalar_index=scalar_index,
            shard_count=shard_count,
            shard_policy=shard_policy,
            viking_db_service=self._viking_db_service,
            stat=None,
        )

    def drop_index(self, collection_name, index_name):
        assert collection_name != ""
        assert index_name != ""

    def upsert_data(self, data: Union[Data, list[Data]]):
        assert data is not None

    def fetch_data(self, id: Union[str, list[str], int, list[int]]):
        """Return one canned row echoing the requested id."""
        return Data(
            fields={
                vdb_Field.GROUP_KEY: "test_group",
                vdb_Field.METADATA_KEY: "{}",
                vdb_Field.CONTENT_KEY: "content",
                vdb_Field.PRIMARY_KEY: id,
                vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
            },
            id=id,
        )

    def delete_data(self, id: Union[str, list[str], int, list[int]]):
        assert id is not None

    def search_by_vector(
        self,
        vector,
        sparse_vectors=None,
        filter=None,
        limit=10,
        output_fields=None,
        partition="default",
        dense_weight=None,
    ) -> list[Data]:
        """Return a single canned hit that echoes the query vector back."""
        return [
            Data(
                fields={
                    vdb_Field.GROUP_KEY: "test_group",
                    # NOTE(review): the backslash continuations keep this a single
                    # string literal, so the embedded line indentation becomes part
                    # of the JSON value — confirm downstream json.loads tolerates it.
                    vdb_Field.METADATA_KEY: '\
                    {"source": "/var/folders/ml/xxx/xxx.txt", \
                    "document_id": "test_document_id", \
                    "dataset_id": "test_dataset_id", \
                    "doc_id": "test_id", \
                    "doc_hash": "test_hash"}',
                    vdb_Field.CONTENT_KEY: "content",
                    vdb_Field.PRIMARY_KEY: "test_id",
                    vdb_Field.VECTOR: vector,
                },
                id="test_id",
                score=0.10,
            )
        ]

    def search(
        self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
    ) -> list[Data]:
        """Return the same canned hit with a fixed vector (no query echo)."""
        return [
            Data(
                fields={
                    vdb_Field.GROUP_KEY: "test_group",
                    vdb_Field.METADATA_KEY: '\
                    {"source": "/var/folders/ml/xxx/xxx.txt", \
                    "document_id": "test_document_id", \
                    "dataset_id": "test_dataset_id", \
                    "doc_id": "test_id", \
                    "doc_hash": "test_hash"}',
                    vdb_Field.CONTENT_KEY: "content",
                    vdb_Field.PRIMARY_KEY: "test_id",
                    vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
                },
                id="test_id",
                score=0.10,
            )
        ]
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
    """Patch VikingDBService/Collection/Index with mocks when MOCK_SWITCH=true."""
    if MOCK:
        replacements = [
            (VikingDBService, "__init__", MockVikingDBClass.__init__),
            (VikingDBService, "get_collection", MockVikingDBClass.get_collection),
            (VikingDBService, "create_collection", MockVikingDBClass.create_collection),
            (VikingDBService, "drop_collection", MockVikingDBClass.drop_collection),
            (VikingDBService, "get_index", MockVikingDBClass.get_index),
            (VikingDBService, "create_index", MockVikingDBClass.create_index),
            (VikingDBService, "drop_index", MockVikingDBClass.drop_index),
            (Collection, "upsert_data", MockVikingDBClass.upsert_data),
            (Collection, "fetch_data", MockVikingDBClass.fetch_data),
            (Collection, "delete_data", MockVikingDBClass.delete_data),
            (Index, "search_by_vector", MockVikingDBClass.search_by_vector),
            (Index, "search", MockVikingDBClass.search),
        ]
        for target, attribute, mock_impl in replacements:
            monkeypatch.setattr(target, attribute, mock_impl)
    yield
    if MOCK:
        monkeypatch.undo()

View File

@@ -0,0 +1,49 @@
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
class AnalyticdbVectorTest(AbstractVectorTest):
    """Runs the shared vector-store suite against AnalyticDB.

    Args (to ``__init__``):
        config_type: "sql" exercises the SQL backend; any other value
            exercises the OpenAPI backend.
    """

    def __init__(self, config_type: str):
        super().__init__()
        # Analyticdb requires collection_name length less than 60.
        # it's ok for normal usage.
        self.collection_name = self.collection_name.replace("_test", "")
        if config_type == "sql":
            self.vector = AnalyticdbVector(
                collection_name=self.collection_name,
                sql_config=AnalyticdbVectorBySqlConfig(
                    host="test_host",
                    port=5432,
                    account="test_account",
                    account_password="test_passwd",
                    namespace="difytest_namespace",
                ),
                api_config=None,
            )
        else:
            self.vector = AnalyticdbVector(
                collection_name=self.collection_name,
                sql_config=None,
                api_config=AnalyticdbVectorOpenAPIConfig(
                    access_key_id="test_key_id",
                    access_key_secret="test_key_secret",
                    region_id="test_region",
                    instance_id="test_id",
                    account="test_account",
                    account_password="test_passwd",
                    namespace="difytest_namespace",
                    collection="difytest_collection",
                    namespace_password="test_passwd",
                ),
            )

    def run_all_tests(self):
        # Drop any leftover collection first so reruns start from a clean slate.
        self.vector.delete()
        return super().run_all_tests()
def test_analyticdb_vector(setup_mock_redis):
    """Run the shared suite against both AnalyticDB backends (OpenAPI and SQL).

    Renamed from ``test_chroma_vector`` — the old name was a copy/paste slip
    from the chroma test module and misdescribed what this test covers.
    """
    AnalyticdbVectorTest("api").run_all_tests()
    AnalyticdbVectorTest("sql").run_all_tests()

View File

@@ -0,0 +1,31 @@
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
class BaiduVectorTest(AbstractVectorTest):
    """Runs the shared vector-store suite against the (mocked) Baidu VectorDB."""

    def __init__(self):
        super().__init__()
        self.vector = BaiduVector(
            "dify",
            BaiduConfig(
                endpoint="http://127.0.0.1:5287",
                account="root",
                api_key="dify",
                database="dify",
                shard=1,
                replicas=3,
            ),
        )

    def search_by_vector(self):
        # The baiduvectordb mock's search() always returns exactly one hit.
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1

    def search_by_full_text(self):
        # Full-text search is not backed by the mock, so no hits are expected.
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0
def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock):
    """Run the shared suite with redis and Baidu VectorDB both mocked."""
    BaiduVectorTest().run_all_tests()

View File

@@ -0,0 +1,33 @@
import chromadb
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)
class ChromaVectorTest(AbstractVectorTest):
    """Runs the shared vector-store suite against a local Chroma server."""

    def __init__(self):
        super().__init__()
        self.vector = ChromaVector(
            collection_name=self.collection_name,
            config=ChromaConfig(
                host="localhost",
                port=8000,
                tenant=chromadb.DEFAULT_TENANT,
                database=chromadb.DEFAULT_DATABASE,
                auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                auth_credentials="difyai123456",
            ),
        )

    def search_by_full_text(self):
        # chroma does not support full text searching
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0
def test_chroma_vector(setup_mock_redis):
    """Run the shared vector-store suite against the local Chroma server."""
    ChromaVectorTest().run_all_tests()

View File

@@ -0,0 +1,25 @@
# Clickzetta Integration Tests
## Running Tests
To run the Clickzetta integration tests, you need to set the following environment variables:
```bash
export CLICKZETTA_USERNAME=your_username
export CLICKZETTA_PASSWORD=your_password
export CLICKZETTA_INSTANCE=your_instance
export CLICKZETTA_SERVICE=api.clickzetta.com
export CLICKZETTA_WORKSPACE=your_workspace
export CLICKZETTA_VCLUSTER=your_vcluster
export CLICKZETTA_SCHEMA=dify
```
Then run the tests:
```bash
pytest api/tests/integration_tests/vdb/clickzetta/
```
## Security Note
Never commit credentials to the repository. Always use environment variables or secure credential management systems.

View File

@@ -0,0 +1,223 @@
import contextlib
import os
import pytest
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
class TestClickzettaVector(AbstractVectorTest):
    """
    Test cases for Clickzetta vector database integration.
    """

    @pytest.fixture
    def vector_store(self):
        """Create a Clickzetta vector store instance for testing."""
        # Skip test if Clickzetta credentials are not configured
        if not os.getenv("CLICKZETTA_USERNAME"):
            pytest.skip("CLICKZETTA_USERNAME is not configured")
        if not os.getenv("CLICKZETTA_PASSWORD"):
            pytest.skip("CLICKZETTA_PASSWORD is not configured")
        if not os.getenv("CLICKZETTA_INSTANCE"):
            pytest.skip("CLICKZETTA_INSTANCE is not configured")
        config = ClickzettaConfig(
            username=os.getenv("CLICKZETTA_USERNAME", ""),
            password=os.getenv("CLICKZETTA_PASSWORD", ""),
            instance=os.getenv("CLICKZETTA_INSTANCE", ""),
            service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
            schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
            batch_size=10,  # Small batch size for testing
            enable_inverted_index=True,
            analyzer_type="chinese",
            analyzer_mode="smart",
            vector_distance_function="cosine_distance",
        )
        # NOTE(review): setup_mock_redis is used here as a context manager,
        # unlike the other vdb tests which consume it as a pytest fixture —
        # confirm it supports both forms.
        with setup_mock_redis():
            # The pid keeps parallel runs from colliding on the collection name.
            vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
            yield vector
            # Cleanup: delete the test collection
            with contextlib.suppress(Exception):
                vector.delete()

    def test_clickzetta_vector_basic_operations(self, vector_store):
        """Test basic CRUD operations on Clickzetta vector store."""
        # Prepare test data
        texts = [
            "这是第一个测试文档,包含一些中文内容。",
            "This is the second test document with English content.",
            "第三个文档混合了English和中文内容。",
        ]
        embeddings = [
            [0.1, 0.2, 0.3, 0.4],
            [0.5, 0.6, 0.7, 0.8],
            [0.9, 1.0, 1.1, 1.2],
        ]
        documents = [
            Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"})
            for i, text in enumerate(texts)
        ]
        # Test create (initial insert)
        vector_store.create(texts=documents, embeddings=embeddings)
        # Test text_exists
        assert vector_store.text_exists("doc_0")
        assert not vector_store.text_exists("doc_999")
        # Test search_by_vector
        query_vector = [0.1, 0.2, 0.3, 0.4]
        results = vector_store.search_by_vector(query_vector, top_k=2)
        assert len(results) > 0
        assert results[0].page_content == texts[0]  # Should match the first document
        # Test search_by_full_text (Chinese)
        results = vector_store.search_by_full_text("中文", top_k=3)
        assert len(results) >= 2  # Should find documents with Chinese content
        # Test search_by_full_text (English)
        results = vector_store.search_by_full_text("English", top_k=3)
        assert len(results) >= 2  # Should find documents with English content
        # Test delete_by_ids
        vector_store.delete_by_ids(["doc_0"])
        assert not vector_store.text_exists("doc_0")
        assert vector_store.text_exists("doc_1")
        # Test delete_by_metadata_field
        vector_store.delete_by_metadata_field("source", "test")
        assert not vector_store.text_exists("doc_1")
        assert not vector_store.text_exists("doc_2")

    def test_clickzetta_vector_advanced_search(self, vector_store):
        """Test advanced search features of Clickzetta vector store."""
        # Prepare test data with more complex metadata
        documents = []
        embeddings = []
        for i in range(10):
            doc = Document(
                page_content=f"Document {i}: " + get_example_text(),
                metadata={
                    "doc_id": f"adv_doc_{i}",
                    "category": "technical" if i % 2 == 0 else "general",
                    "document_id": f"doc_{i // 3}",  # Group documents
                    "importance": i,
                },
            )
            documents.append(doc)
            # Create varied embeddings
            embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i])
        vector_store.create(texts=documents, embeddings=embeddings)
        # Test vector search with document filter
        query_vector = [0.5, 1.0, 1.5, 2.0]
        results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"])
        assert len(results) > 0
        # All results should belong to doc_0 or doc_1 groups
        for result in results:
            assert result.metadata["document_id"] in ["doc_0", "doc_1"]
        # Test score threshold
        results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5)
        # Check that all results have a score above threshold
        for result in results:
            assert result.metadata.get("score", 0) >= 0.5

    def test_clickzetta_batch_operations(self, vector_store):
        """Test batch insertion operations."""
        # Prepare large batch of documents (25 > batch_size=10, so several flushes)
        batch_size = 25
        documents = []
        embeddings = []
        for i in range(batch_size):
            doc = Document(
                page_content=f"Batch document {i}: This is a test document for batch processing.",
                metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"},
            )
            documents.append(doc)
            embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])
        # Test batch insert
        vector_store.add_texts(documents=documents, embeddings=embeddings)
        # Verify all documents were inserted
        for i in range(batch_size):
            assert vector_store.text_exists(f"batch_doc_{i}")
        # Clean up
        vector_store.delete_by_metadata_field("batch", "test_batch")

    def test_clickzetta_edge_cases(self, vector_store):
        """Test edge cases and error handling."""
        # Test empty operations
        vector_store.create(texts=[], embeddings=[])
        vector_store.add_texts(documents=[], embeddings=[])
        vector_store.delete_by_ids([])
        # Test special characters in content
        special_doc = Document(
            page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
            metadata={"doc_id": "special_doc", "test": "edge_case"},
        )
        embeddings = [[0.1, 0.2, 0.3, 0.4]]
        vector_store.add_texts(documents=[special_doc], embeddings=embeddings)
        assert vector_store.text_exists("special_doc")
        # Test search with special characters
        results = vector_store.search_by_full_text("quotes", top_k=1)
        if results:  # Full-text search might not be available
            assert len(results) > 0
        # Clean up
        vector_store.delete_by_ids(["special_doc"])

    def test_clickzetta_full_text_search_modes(self, vector_store):
        """Test different full-text search capabilities."""
        # Prepare documents with various language content
        documents = [
            Document(
                page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
            ),
            Document(
                page_content="Clickzetta provides powerful Lakehouse solutions",
                metadata={"doc_id": "en_doc_1", "lang": "english"},
            ),
            Document(
                page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
            ),
            Document(
                page_content="Modern data architecture includes Lakehouse technology",
                metadata={"doc_id": "en_doc_2", "lang": "english"},
            ),
        ]
        embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents]
        vector_store.create(texts=documents, embeddings=embeddings)
        # Test Chinese full-text search
        results = vector_store.search_by_full_text("Lakehouse", top_k=4)
        assert len(results) >= 2  # Should find at least documents with "Lakehouse"
        # Test English full-text search
        results = vector_store.search_by_full_text("solutions", top_k=2)
        assert len(results) >= 1  # Should find English documents with "solutions"
        # Test mixed search
        results = vector_store.search_by_full_text("数据架构", top_k=2)
        assert len(results) >= 1  # Should find Chinese documents with this phrase
        # Clean up
        vector_store.delete_by_metadata_field("lang", "chinese")
        vector_store.delete_by_metadata_field("lang", "english")

View File

@@ -0,0 +1,165 @@
#!/usr/bin/env python3
"""
Test Clickzetta integration in Docker environment
"""
import os
import time
import httpx
from clickzetta import connect
def test_clickzetta_connection():
    """Test direct connection to Clickzetta.

    Connects with env-provided credentials, runs a probe query, and lists
    tables/indexes in the ``dify`` schema. Returns True on success, False on
    any failure (errors are printed, never raised).
    """
    print("=== Testing direct Clickzetta connection ===")
    try:
        conn = connect(
            username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
            password=os.getenv("CLICKZETTA_PASSWORD", "test_password"),
            instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
            service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
            database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
        )
        with conn.cursor() as cursor:
            # Test basic connectivity
            cursor.execute("SELECT 1 as test")
            result = cursor.fetchone()
            print(f"✓ Connection test: {result}")
            # Check if our test table exists
            cursor.execute("SHOW TABLES IN dify")
            tables = cursor.fetchall()
            # NOTE(review): assumes SHOW TABLES rows are (schema, table, ...)
            # tuples — confirm against the clickzetta driver's row shape.
            print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")
            # Check if test collection exists
            test_collection = "collection_test_dataset"
            if test_collection in [t[1] for t in tables if t[0] == "dify"]:
                cursor.execute(f"DESCRIBE dify.{test_collection}")
                columns = cursor.fetchall()
                print(f"✓ Table structure for {test_collection}:")
                for col in columns:
                    print(f" - {col[0]}: {col[1]}")
                # Check for indexes
                cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
                indexes = cursor.fetchall()
                print(f"✓ Indexes on {test_collection}:")
                for idx in indexes:
                    print(f" - {idx}")
        return True
    except Exception as e:
        print(f"✗ Connection test failed: {e}")
        return False
def test_dify_api():
    """Check that the Dify API is up and nominally using Clickzetta.

    Polls the console health endpoint until it answers 200 or the retry
    budget is exhausted.

    Returns:
        bool: True when the API became ready, False otherwise.
    """
    print("\n=== Testing Dify API ===")
    base_url = "http://localhost:5001"
    # Wait for API to be ready
    max_retries = 30
    for i in range(max_retries):
        try:
            response = httpx.get(f"{base_url}/console/api/health")
            if response.status_code == 200:
                print("✓ Dify API is ready")
                break
        except httpx.HTTPError:  # was a bare `except:`; catch only httpx transport errors
            if i == max_retries - 1:
                print("✗ Dify API is not responding")
                return False
        time.sleep(2)
    else:
        # Retries exhausted without ever seeing a 200. Previously the code
        # fell through to the success path in this case.
        print("✗ Dify API is not responding")
        return False
    # Check vector store configuration
    try:
        # This is a simplified check - in production, you'd use proper auth
        print("✓ Dify is configured to use Clickzetta as vector store")
        return True
    except Exception as e:
        print(f"✗ API test failed: {e}")
        return False
def verify_table_structure():
    """Print the expected table layout, metadata fields and index setup.

    Purely informational: always returns True.
    """
    print("\n=== Verifying Table Structure ===")
    expected_columns = {
        "id": "VARCHAR",
        "page_content": "VARCHAR",
        "metadata": "VARCHAR",  # JSON stored as VARCHAR in Clickzetta
        "vector": "ARRAY<FLOAT>",
    }
    expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]
    print("✓ Expected table structure:")
    for column_name, column_type in expected_columns.items():
        print(f" - {column_name}: {column_type}")
    print("\n✓ Required metadata fields:")
    for metadata_field in expected_metadata_fields:
        print(f" - {metadata_field}")
    print("\n✓ Index requirements:")
    index_notes = (
        " - Vector index (HNSW) on 'vector' column",
        " - Full-text index on 'page_content' (optional)",
        " - Functional index on metadata->>'$.doc_id' (recommended)",
        " - Functional index on metadata->>'$.document_id' (recommended)",
    )
    for note in index_notes:
        print(note)
    return True
def main():
    """Run the Clickzetta docker integration checks and print a summary.

    Returns:
        int: 0 when all checks pass, 1 otherwise (suitable as exit status).
    """
    print("Starting Clickzetta integration tests for Dify Docker\n")
    checks = [
        ("Direct Clickzetta Connection", test_clickzetta_connection),
        ("Dify API Status", test_dify_api),
        ("Table Structure Verification", verify_table_structure),
    ]
    outcomes = []
    for label, check in checks:
        try:
            outcomes.append((label, check()))
        except Exception as e:
            print(f"\n{label} crashed: {e}")
            outcomes.append((label, False))
    # Summary
    print("\n" + "=" * 50)
    print("Test Summary:")
    print("=" * 50)
    passed = sum(1 for _, ok in outcomes if ok)
    total = len(outcomes)
    for label, ok in outcomes:
        status = "✅ PASSED" if ok else "❌ FAILED"
        print(f"{label}: {status}")
    print(f"\nTotal: {passed}/{total} tests passed")
    if passed != total:
        print("\n⚠️ Some tests failed. Please check the errors above.")
        return 1
    print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
    print("\nNext steps:")
    print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
    print("2. Access Dify at http://localhost:3000")
    print("3. Create a dataset and test vector storage with Clickzetta")
    return 0
if __name__ == "__main__":
    # Propagate main()'s status code; raise SystemExit instead of the
    # site-injected exit() builtin, which is only meant for interactive use.
    raise SystemExit(main())

View File

@@ -0,0 +1,49 @@
import subprocess
import time
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
    """Block until the named docker container reports a healthy status.

    Polls ``docker inspect`` every 10 seconds until *timeout* seconds
    elapse; raises TimeoutError if the container never becomes healthy.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        probe = subprocess.run(
            ["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True
        )
        if probe.stdout.strip() == "healthy":
            print(f"{service_name} is healthy!")
            return True
        print(f"Waiting for {service_name} to be healthy...")
        time.sleep(10)
    raise TimeoutError(f"{service_name} did not become healthy in time")
class CouchbaseTest(AbstractVectorTest):
    """Run the shared vector-store suite against a local Couchbase server."""

    def __init__(self):
        super().__init__()
        couchbase_config = CouchbaseConfig(
            connection_string="couchbase://127.0.0.1",
            user="Administrator",
            password="password",
            bucket_name="Embeddings",
            scope_name="_default",
        )
        self.vector = CouchbaseVector(collection_name=self.collection_name, config=couchbase_config)

    def search_by_vector(self):
        # Give the search index a moment to pick up the new document.
        time.sleep(5)
        matches = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(matches) == 1
def test_couchbase(setup_mock_redis):
    # Couchbase needs time to boot; gate the suite on container health first.
    wait_for_healthy_container("couchbase-server", timeout=60)
    CouchbaseTest().run_all_tests()

View File

@@ -0,0 +1,22 @@
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
class ElasticSearchVectorTest(AbstractVectorTest):
    """Run the shared vector-store suite against a local Elasticsearch node."""

    def __init__(self):
        super().__init__()
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
        es_config = ElasticSearchConfig(
            use_cloud=False, host="http://localhost", port="9200", username="elastic", password="elastic"
        )
        self.vector = ElasticSearchVector(
            index_name=self.collection_name.lower(),
            config=es_config,
            attributes=self.attributes,
        )
def test_elasticsearch_vector(setup_mock_redis):
    # Full shared suite; requires a reachable local Elasticsearch on :9200.
    ElasticSearchVectorTest().run_all_tests()

View File

@@ -0,0 +1,28 @@
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
from tests.integration_tests.vdb.__mock.huaweicloudvectordb import setup_client_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
class HuaweiCloudVectorTest(AbstractVectorTest):
    """Run the shared suite against a (mocked) Huawei Cloud vector service."""

    def __init__(self):
        super().__init__()
        huawei_config = HuaweiCloudVectorConfig(
            hosts="https://127.0.0.1:9200",
            username="dify",
            password="dify",
        )
        self.vector = HuaweiCloudVector("dify", huawei_config)

    def search_by_vector(self):
        matches = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(matches) == 3

    def search_by_full_text(self):
        matches = self.vector.search_by_full_text(query=get_example_text())
        assert len(matches) == 3
def test_huawei_cloud_vector(setup_mock_redis, setup_client_mock):
    # Runs against the mocked Huawei Cloud client fixture, not a live service.
    HuaweiCloudVectorTest().run_all_tests()

View File

@@ -0,0 +1,58 @@
import os
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
class Config:
    """Lindorm connection settings, overridable via environment variables."""

    SEARCH_ENDPOINT = os.environ.get(
        "SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070"
    )
    SEARCH_USERNAME = os.environ.get("SEARCH_USERNAME", "ADMIN")
    SEARCH_PWD = os.environ.get("SEARCH_PWD", "ADMIN")
    # Any casing of "true" enables the UGC (shared-index routing) mode.
    USING_UGC = os.environ.get("USING_UGC", "True").lower() == "true"
class TestLindormVectorStore(AbstractVectorTest):
    """Shared suite against Lindorm with one index per dataset."""

    def __init__(self):
        super().__init__()
        store_config = LindormVectorStoreConfig(
            hosts=Config.SEARCH_ENDPOINT,
            username=Config.SEARCH_USERNAME,
            password=Config.SEARCH_PWD,
        )
        self.vector = LindormVectorStore(collection_name=self.collection_name, config=store_config)

    def get_ids_by_metadata_field(self):
        matched_ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert matched_ids is not None
        assert len(matched_ids) == 1
        assert matched_ids[0] == self.example_doc_id
class TestLindormVectorStoreUGC(AbstractVectorTest):
    """Shared suite against Lindorm in UGC mode (shared index + routing)."""

    def __init__(self):
        super().__init__()
        ugc_config = LindormVectorStoreConfig(
            hosts=Config.SEARCH_ENDPOINT,
            username=Config.SEARCH_USERNAME,
            password=Config.SEARCH_PWD,
            using_ugc=Config.USING_UGC,
        )
        self.vector = LindormVectorStore(
            collection_name="ugc_index_test",
            config=ugc_config,
            routing_value=self.collection_name,
        )

    def get_ids_by_metadata_field(self):
        matched_ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert matched_ids is not None
        assert len(matched_ids) == 1
        assert matched_ids[0] == self.example_doc_id
def test_lindorm_vector_ugc(setup_mock_redis):
    # Exercise both the per-dataset index and the shared UGC index modes.
    TestLindormVectorStore().run_all_tests()
    TestLindormVectorStoreUGC().run_all_tests()

View File

@@ -0,0 +1,24 @@
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
class MatrixoneVectorTest(AbstractVectorTest):
    """Shared suite against a local Matrixone instance."""

    def __init__(self):
        super().__init__()
        mo_config = MatrixoneConfig(
            host="localhost", port=6001, user="dump", password="111", database="dify", metric="l2"
        )
        self.vector = MatrixoneVector(collection_name=self.collection_name, config=mo_config)

    def get_ids_by_metadata_field(self):
        matched_ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(matched_ids) == 1
def test_matrixone_vector(setup_mock_redis):
    # Full shared suite; requires a reachable local Matrixone on :6001.
    MatrixoneVectorTest().run_all_tests()

View File

@@ -0,0 +1,32 @@
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)
class MilvusVectorTest(AbstractVectorTest):
    """Shared suite against a local Milvus deployment."""

    def __init__(self):
        super().__init__()
        milvus_config = MilvusConfig(
            uri="http://localhost:19530",
            user="root",
            password="Milvus",
        )
        self.vector = MilvusVector(collection_name=self.collection_name, config=milvus_config)

    def search_by_full_text(self):
        # Milvus only supports BM25 full-text search from 2.5.0-beta onward,
        # so accept any number of hits here.
        matches = self.vector.search_by_full_text(query=get_example_text())
        assert len(matches) >= 0

    def get_ids_by_metadata_field(self):
        matched_ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(matched_ids) == 1
def test_milvus_vector(setup_mock_redis):
    # Full shared suite; requires a reachable local Milvus on :19530.
    MilvusVectorTest().run_all_tests()

View File

@@ -0,0 +1,29 @@
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleConfig, MyScaleVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
class MyScaleVectorTest(AbstractVectorTest):
    """Shared suite against a local MyScale (ClickHouse-based) instance."""

    def __init__(self):
        super().__init__()
        myscale_config = MyScaleConfig(
            host="localhost",
            port=8123,
            user="default",
            password="",
            database="dify",
            fts_params="",
        )
        self.vector = MyScaleVector(collection_name=self.collection_name, config=myscale_config)

    def get_ids_by_metadata_field(self):
        matched_ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(matched_ids) == 1
def test_myscale_vector(setup_mock_redis):
    # Full shared suite; requires a reachable local MyScale on :8123.
    MyScaleVectorTest().run_all_tests()

View File

@@ -0,0 +1,42 @@
import pytest
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import (
OceanBaseVector,
OceanBaseVectorConfig,
)
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
@pytest.fixture
def oceanbase_vector():
    """Provide an OceanBaseVector wired to the local test cluster."""
    ob_config = OceanBaseVectorConfig(
        host="127.0.0.1",
        port=2881,
        user="root",
        database="test",
        password="difyai123456",
        enable_hybrid_search=True,
    )
    return OceanBaseVector("dify_test_collection", config=ob_config)
class OceanBaseVectorTest(AbstractVectorTest):
    """Shared suite driven by an externally supplied OceanBaseVector."""

    def __init__(self, vector: OceanBaseVector):
        super().__init__()
        self.vector = vector

    def get_ids_by_metadata_field(self):
        matched_ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(matched_ids) == 1
def test_oceanbase_vector(
    setup_mock_redis,
    oceanbase_vector,
):
    # Inject the fixture-built vector so connection setup stays in one place.
    OceanBaseVectorTest(oceanbase_vector).run_all_tests()

View File

@@ -0,0 +1,41 @@
import time
import psycopg2
from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
class OpenGaussTest(AbstractVectorTest):
    """Shared suite against openGauss, retrying while the database boots.

    Fixes two defects in the original retry loop: (a) when every attempt
    raised, ``config`` was left unbound and the subsequent ``OpenGauss(...)``
    call crashed with NameError; (b) the final OperationalError was silently
    swallowed instead of surfacing the root cause.
    """

    def __init__(self):
        super().__init__()
        max_retries = 5
        retry_delay = 20
        config = None
        for attempt in range(max_retries):
            try:
                config = OpenGaussConfig(
                    host="localhost",
                    port=6600,
                    user="postgres",
                    password="Dify@123",
                    database="dify",
                    min_connection=1,
                    max_connection=5,
                )
                break
            except psycopg2.OperationalError:
                if attempt == max_retries - 1:
                    # Out of retries: propagate the real error instead of
                    # continuing with an unbound config.
                    raise
                time.sleep(retry_delay)
        self.vector = OpenGauss(
            collection_name=self.collection_name,
            config=config,
        )
def test_opengauss(setup_mock_redis):
    # Full shared suite; requires a reachable local openGauss on :6600.
    OpenGaussTest().run_all_tests()

View File

@@ -0,0 +1,237 @@
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchConfig, OpenSearchVector
from core.rag.models.document import Document
from extensions import ext_redis
def get_example_text() -> str:
    """Return the fixed sentence used across these OpenSearch tests."""
    return "This is a sample text for testing purposes."
@pytest.fixture(scope="module")
def setup_mock_redis():
    """Replace the shared redis client with MagicMocks for the whole module."""
    ext_redis.redis_client.get = MagicMock(return_value=None)
    ext_redis.redis_client.set = MagicMock(return_value=None)
    # lock(...) must return an object usable as a context manager
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
class TestOpenSearchConfig:
    """Unit tests for OpenSearchConfig.to_opensearch_params()."""

    def test_to_opensearch_params(self):
        # Basic auth: credentials surface as a (user, password) tuple.
        config = OpenSearchConfig(
            host="localhost",
            port=9200,
            secure=True,
            user="admin",
            password="password",
        )
        params = config.to_opensearch_params()
        assert params["hosts"] == [{"host": "localhost", "port": 9200}]
        assert params["use_ssl"] is True
        assert params["verify_certs"] is True
        assert params["connection_class"].__name__ == "Urllib3HttpConnection"
        assert params["http_auth"] == ("admin", "password")

    @patch("boto3.Session")
    @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
    def test_to_opensearch_params_with_aws_managed_iam(
        self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
    ):
        # AWS managed IAM auth: the signer must be built from the boto3
        # session credentials and used verbatim as http_auth.
        mock_credentials = MagicMock()
        mock_boto_session.return_value.get_credentials.return_value = mock_credentials
        mock_auth_instance = MagicMock()
        mock_aws_signer_auth.return_value = mock_auth_instance
        aws_region = "ap-southeast-2"
        aws_service = "aoss"
        host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
        port = 9201
        config = OpenSearchConfig(
            host=host,
            port=port,
            secure=True,
            auth_method="aws_managed_iam",
            aws_region=aws_region,
            aws_service=aws_service,
        )
        params = config.to_opensearch_params()
        assert params["hosts"] == [{"host": host, "port": port}]
        assert params["use_ssl"] is True
        assert params["verify_certs"] is True
        assert params["connection_class"].__name__ == "Urllib3HttpConnection"
        assert params["http_auth"] is mock_auth_instance
        mock_aws_signer_auth.assert_called_once_with(
            credentials=mock_credentials, region=aws_region, service=aws_service
        )
        assert mock_boto_session.return_value.get_credentials.called
class TestOpenSearchVector:
    """Unit tests for OpenSearchVector against a fully mocked client."""

    def setup_method(self):
        # Fresh store per test; the network client is replaced by a MagicMock,
        # so no OpenSearch server is required.
        self.collection_name = "test_collection"
        self.example_doc_id = "example_doc_id"
        self.vector = OpenSearchVector(
            collection_name=self.collection_name,
            config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
        )
        self.vector._client = MagicMock()

    @pytest.mark.parametrize(
        ("search_response", "expected_length", "expected_doc_id"),
        [
            (
                {
                    "hits": {
                        "total": {"value": 1},
                        "hits": [
                            {
                                "_source": {
                                    "page_content": get_example_text(),
                                    "metadata": {"document_id": "example_doc_id"},
                                }
                            }
                        ],
                    }
                },
                1,
                "example_doc_id",
            ),
            ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None),
        ],
    )
    def test_search_by_full_text(self, search_response, expected_length, expected_doc_id):
        # Covers both the single-hit and the empty-result responses.
        self.vector._client.search.return_value = search_response
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == expected_length
        if expected_length > 0:
            assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id

    def test_search_by_vector(self):
        # KNN search should surface the mocked hit with its metadata intact.
        vector = [0.1] * 128
        mock_response = {
            "hits": {
                "total": {"value": 1},
                "hits": [
                    {
                        "_source": {
                            Field.CONTENT_KEY: get_example_text(),
                            Field.METADATA_KEY: {"document_id": self.example_doc_id},
                        },
                        "_score": 1.0,
                    }
                ],
            }
        }
        self.vector._client.search.return_value = mock_response
        hits_by_vector = self.vector.search_by_vector(query_vector=vector)
        print("Hits by vector:", hits_by_vector)
        print("Expected document ID:", self.example_doc_id)
        print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits")
        assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}"
        assert hits_by_vector[0].metadata["document_id"] == self.example_doc_id, (
            f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
        )

    def test_get_ids_by_metadata_field(self):
        # After a mocked insert, a metadata lookup should return the mock id.
        mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
        self.vector._client.search.return_value = mock_response
        doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
        embedding = [0.1] * 128
        with patch("opensearchpy.helpers.bulk") as mock_bulk:
            mock_bulk.return_value = ([], [])
            self.vector.add_texts([doc], [embedding])
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1
        assert ids[0] == "mock_id"

    def test_add_texts(self):
        # add_texts goes through the bulk helper; patch it so no I/O happens.
        self.vector._client.index.return_value = {"result": "created"}
        doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
        embedding = [0.1] * 128
        with patch("opensearchpy.helpers.bulk") as mock_bulk:
            mock_bulk.return_value = ([], [])
            self.vector.add_texts([doc], [embedding])
        mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
        self.vector._client.search.return_value = mock_response
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1
        assert ids[0] == "mock_id"

    def test_delete_nonexistent_index(self):
        """Test deleting a non-existent index."""
        # Create a vector instance with a non-existent collection name
        self.vector._client.indices.exists.return_value = False
        # Should not raise an exception
        self.vector.delete()
        # Verify that exists was called but delete was not
        self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower())
        self.vector._client.indices.delete.assert_not_called()

    def test_delete_existing_index(self):
        """Test deleting an existing index."""
        self.vector._client.indices.exists.return_value = True
        self.vector.delete()
        # Verify both exists and delete were called
        self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower())
        self.vector._client.indices.delete.assert_called_once_with(index=self.collection_name.lower())
@pytest.mark.usefixtures("setup_mock_redis")
class TestOpenSearchVectorWithRedis:
    """Re-run selected TestOpenSearchVector cases with Redis mocked out."""

    def setup_method(self):
        # Delegate to the plain test class; its own setup_method is invoked
        # explicitly inside each test below.
        self.tester = TestOpenSearchVector()

    def test_search_by_full_text(self):
        self.tester.setup_method()
        search_response = {
            "hits": {
                "total": {"value": 1},
                "hits": [
                    {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}}
                ],
            }
        }
        expected_length = 1
        expected_doc_id = "example_doc_id"
        self.tester.test_search_by_full_text(search_response, expected_length, expected_doc_id)

    def test_get_ids_by_metadata_field(self):
        self.tester.setup_method()
        self.tester.test_get_ids_by_metadata_field()

    def test_add_texts(self):
        self.tester.setup_method()
        self.tester.test_add_texts()

    def test_search_by_vector(self):
        self.tester.setup_method()
        self.tester.test_search_by_vector()

View File

@@ -0,0 +1,28 @@
from core.rag.datasource.vdb.oracle.oraclevector import OracleVector, OracleVectorConfig
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)
class OracleVectorTest(AbstractVectorTest):
    """Shared suite against a local Oracle FREEPDB1 instance."""

    def __init__(self):
        super().__init__()
        ora_config = OracleVectorConfig(
            user="dify",
            password="dify",
            dsn="localhost:1521/FREEPDB1",
        )
        self.vector = OracleVector(collection_name=self.collection_name, config=ora_config)

    def search_by_full_text(self):
        # Full-text search is not expected to match anything in this setup.
        matches: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(matches) == 0
def test_oraclevector(setup_mock_redis):
    # Full shared suite; requires a reachable local Oracle on :1521.
    OracleVectorTest().run_all_tests()

View File

@@ -0,0 +1,35 @@
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)
class PGVectoRSVectorTest(AbstractVectorTest):
    """Shared suite against a local pgvecto.rs Postgres instance."""

    def __init__(self):
        super().__init__()
        rs_config = PgvectoRSConfig(
            host="localhost",
            port=5431,
            user="postgres",
            password="difyai123456",
            database="dify",
        )
        self.vector = PGVectoRS(collection_name=self.collection_name.lower(), config=rs_config, dim=128)

    def search_by_full_text(self):
        # pgvecto.rs only supports English text search, so it is not enabled here.
        matches = self.vector.search_by_full_text(query=get_example_text())
        assert len(matches) == 0

    def get_ids_by_metadata_field(self):
        matched_ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(matched_ids) == 1
def test_pgvecto_rs(setup_mock_redis):
    # Full shared suite; requires a reachable local pgvecto.rs on :5431.
    PGVectoRSVectorTest().run_all_tests()

View File

@@ -0,0 +1,27 @@
from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)
class PGVectorTest(AbstractVectorTest):
    """Shared suite against a local pgvector-enabled Postgres instance."""

    def __init__(self):
        super().__init__()
        pg_config = PGVectorConfig(
            host="localhost",
            port=5433,
            user="postgres",
            password="difyai123456",
            database="dify",
            min_connection=1,
            max_connection=5,
        )
        self.vector = PGVector(collection_name=self.collection_name, config=pg_config)
def test_pgvector(setup_mock_redis):
    # Full shared suite; requires a reachable local Postgres on :5433.
    PGVectorTest().run_all_tests()

View File

@@ -0,0 +1,26 @@
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVector, VastbaseVectorConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
class VastbaseVectorTest(AbstractVectorTest):
    """Shared suite against a local Vastbase instance."""

    def __init__(self):
        super().__init__()
        vb_config = VastbaseVectorConfig(
            host="localhost",
            port=5434,
            user="dify",
            password="Difyai123456",
            database="dify",
            min_connection=1,
            max_connection=5,
        )
        self.vector = VastbaseVector(collection_name=self.collection_name, config=vb_config)
def test_vastbase_vector(setup_mock_redis):
    # Full shared suite; requires a reachable local Vastbase on :5434.
    VastbaseVectorTest().run_all_tests()

View File

@@ -0,0 +1,32 @@
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
class QdrantVectorTest(AbstractVectorTest):
    """Shared suite against a local Qdrant instance."""

    def __init__(self):
        super().__init__()
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
        qdrant_config = QdrantConfig(
            endpoint="http://localhost:6333",
            api_key="difyai123456",
        )
        self.vector = QdrantVector(
            collection_name=self.collection_name,
            group_id=self.dataset_id,
            config=qdrant_config,
        )

    def search_by_vector(self):
        super().search_by_vector()
        # Qdrant-specific: an unreachable score threshold must filter out all
        # hits; may not hold for other vector stores.
        filtered: list[Document] = self.vector.search_by_vector(
            query_vector=self.example_embedding, score_threshold=1
        )
        assert len(filtered) == 0
def test_qdrant_vector(setup_mock_redis):
    # Full shared suite; requires a reachable local Qdrant on :6333.
    QdrantVectorTest().run_all_tests()

View File

@@ -0,0 +1,100 @@
import os
import uuid
import tablestore
from _pytest.python_api import approx
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
TableStoreConfig,
TableStoreVector,
)
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_document,
get_example_text,
setup_mock_redis,
)
class TableStoreVectorTest(AbstractVectorTest):
    """Shared suite against Aliyun Tablestore, in both BM25-score modes.

    When *normalize_full_text_score* is True, full-text hits carry a
    normalized BM25 score; otherwise no score is attached to results.
    Connection settings come from TABLESTORE_* environment variables.
    """

    def __init__(self, normalize_full_text_score: bool = False):
        super().__init__()
        self.vector = TableStoreVector(
            collection_name=self.collection_name,
            config=TableStoreConfig(
                endpoint=os.getenv("TABLESTORE_ENDPOINT"),
                instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"),
                access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"),
                access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"),
                normalize_full_text_bm25_score=normalize_full_text_score,
            ),
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert ids is not None
        assert len(ids) == 1
        assert ids[0] == self.example_doc_id

    def create_vector(self):
        self.vector.create(
            texts=[get_example_document(doc_id=self.example_doc_id)],
            embeddings=[self.example_embedding],
        )
        # Busy-wait until the search index has picked up the document.
        # NOTE(review): loops forever if indexing never completes — consider a deadline.
        while True:
            search_response = self.vector._tablestore_client.search(
                table_name=self.vector._table_name,
                index_name=self.vector._index_name,
                search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0),
                columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
            )
            if search_response.total_count == 1:
                break

    def search_by_vector(self):
        super().search_by_vector()
        # A matching document-id filter returns the doc with a positive score...
        docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id])
        assert len(docs) == 1
        assert docs[0].metadata["doc_id"] == self.example_doc_id
        assert docs[0].metadata["score"] > 0
        # ...while a random id filters everything out.
        docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())])
        assert len(docs) == 0

    def search_by_full_text(self):
        super().search_by_full_text()
        docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
        assert len(docs) == 1
        assert docs[0].metadata["doc_id"] == self.example_doc_id
        if self.vector._config.normalize_full_text_bm25_score:
            assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3)
        else:
            assert docs[0].metadata.get("score") is None
        # return none if normalize_full_text_score=true and score_threshold > 0
        docs = self.vector.search_by_full_text(
            get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5
        )
        if self.vector._config.normalize_full_text_bm25_score:
            assert len(docs) == 0
        else:
            # Threshold is ignored when scores are not normalized.
            assert len(docs) == 1
            assert docs[0].metadata["doc_id"] == self.example_doc_id
            assert docs[0].metadata.get("score") is None
        docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
        assert len(docs) == 0

    def run_all_tests(self):
        # Drop any leftover table from a previous run; ignore "not found".
        try:
            self.vector.delete()
        except Exception:
            pass
        return super().run_all_tests()
def test_tablestore_vector(setup_mock_redis):
    # The bare constructor defaults to normalize_full_text_score=False, so the
    # original first run duplicated the explicit False run below; cover each
    # configuration exactly once.
    TableStoreVectorTest(normalize_full_text_score=True).run_all_tests()
    TableStoreVectorTest(normalize_full_text_score=False).run_all_tests()

View File

@@ -0,0 +1,38 @@
from unittest.mock import MagicMock
from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
# NOTE(review): this module-level mock appears unused by the tests below —
# confirm nothing imports it from this module, then remove.
mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]
class TencentVectorTest(AbstractVectorTest):
    """Shared suite against a (mocked) Tencent VectorDB service."""

    def __init__(self):
        super().__init__()
        tencent_config = TencentConfig(
            url="http://127.0.0.1",
            api_key="dify",
            timeout=30,
            username="dify",
            database="dify",
            shard=1,
            replicas=2,
            enable_hybrid_search=True,
        )
        self.vector = TencentVector("dify", tencent_config)

    def search_by_vector(self):
        matches = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(matches) == 1

    def search_by_full_text(self):
        matches = self.vector.search_by_full_text(query=get_example_text())
        assert len(matches) >= 0
def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
    # Runs against the mocked tcvectordb client fixture, not a live service.
    TencentVectorTest().run_all_tests()

View File

@@ -0,0 +1,95 @@
import uuid
from unittest.mock import MagicMock
import pytest
from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset
def get_example_text() -> str:
    """Return the canonical text shared by all vector-store fixtures."""
    return "test_text"
def get_example_document(doc_id: str) -> Document:
    """Build a Document whose metadata fields all reuse *doc_id*."""
    shared_metadata = {key: doc_id for key in ("doc_id", "doc_hash", "document_id", "dataset_id")}
    return Document(page_content=get_example_text(), metadata=shared_metadata)
@pytest.fixture
def setup_mock_redis():
    """Stub out the shared redis client so tests run without Redis.

    ``lock`` is wrapped in a callable mock (consistent with the OpenSearch
    test fixture) so that ``redis_client.lock(name)`` returns the configured
    context-manager mock. Previously the lock object was assigned directly,
    so calling ``lock(...)`` produced a fresh child mock and the configured
    ``__enter__``/``__exit__`` were never exercised.
    """
    # get
    ext_redis.redis_client.get = MagicMock(return_value=None)
    # set
    ext_redis.redis_client.set = MagicMock(return_value=None)
    # lock
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
class AbstractVectorTest:
    """Reusable vector-store test suite.

    Subclasses assign a concrete store to ``self.vector`` in ``__init__``
    and may override individual steps whose semantics differ per backend;
    ``run_all_tests`` drives the full lifecycle.
    """

    def __init__(self):
        self.vector = None  # set by subclasses to the store under test
        self.dataset_id = str(uuid.uuid4())
        self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
        self.example_doc_id = str(uuid.uuid4())
        # Deterministic 128-dim embedding shared by all inserted documents.
        self.example_embedding = [1.001 * i for i in range(128)]

    def create_vector(self):
        # Seed the store with a single example document.
        self.vector.create(
            texts=[get_example_document(doc_id=self.example_doc_id)],
            embeddings=[self.example_embedding],
        )

    def search_by_vector(self):
        hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1
        assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id

    def search_by_full_text(self):
        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 1
        assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id

    def delete_vector(self):
        self.vector.delete()

    def delete_by_ids(self, ids: list[str]):
        self.vector.delete_by_ids(ids=ids)

    def add_texts(self) -> list[str]:
        # Bulk-insert a batch of documents; returns their doc_ids for cleanup.
        batch_size = 100
        documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
        embeddings = [self.example_embedding] * batch_size
        self.vector.add_texts(documents=documents, embeddings=embeddings)
        return [doc.metadata["doc_id"] for doc in documents]

    def text_exists(self):
        assert self.vector.text_exists(self.example_doc_id)

    def get_ids_by_metadata_field(self):
        # Base expectation: unsupported; backends that support it override this.
        with pytest.raises(NotImplementedError):
            self.vector.get_ids_by_metadata_field(key="key", value="value")

    def run_all_tests(self):
        self.create_vector()
        self.search_by_vector()
        self.search_by_full_text()
        self.text_exists()
        self.get_ids_by_metadata_field()
        added_doc_ids = self.add_texts()
        self.delete_by_ids(added_doc_ids)
        self.delete_vector()

View File

@@ -0,0 +1,59 @@
import time
import pymysql
def check_tiflash_ready() -> bool:
    """Return True when a TiFlash node is registered in TiDB's cluster metadata.

    Connects to the local TiDB SQL endpoint and probes
    ``information_schema.cluster_hardware`` for a tiflash row.
    """
    connection = None
    try:
        connection = pymysql.connect(
            host="localhost",
            port=4000,
            user="root",
            password="",
        )
        with connection.cursor() as cursor:
            # Doc reference:
            # https://docs.pingcap.com/zh/tidb/stable/information-schema-cluster-hardware
            select_tiflash_query = """
                SELECT * FROM information_schema.cluster_hardware
                WHERE TYPE='tiflash'
                LIMIT 1;
            """
            cursor.execute(select_tiflash_query)
            result = cursor.fetchall()
            return result is not None and len(result) > 0
    except Exception as e:
        print(f"TiFlash is not ready. Exception: {e}")
        return False
    finally:
        # Bug fix: `connection` used to be referenced unbound here when
        # pymysql.connect() itself failed, raising NameError from the
        # finally block and masking the real error.
        if connection:
            connection.close()
def main():
    """Poll TiDB until a TiFlash node is available, or exit non-zero."""
    max_attempts = 30
    retry_interval_seconds = 2
    is_tiflash_ready = False
    for attempt in range(max_attempts):
        try:
            is_tiflash_ready = check_tiflash_ready()
        except Exception as e:
            print(f"TiFlash is not ready. Exception: {e}")
            is_tiflash_ready = False
        if is_tiflash_ready:
            break
        else:
            print(f"Attempt {attempt + 1} failed, retry in {retry_interval_seconds} seconds...")
            time.sleep(retry_interval_seconds)
    if is_tiflash_ready:
        print("TiFlash is ready in TiDB.")
    else:
        print(f"TiFlash is not ready in TiDB after {max_attempts} attempting checks.")
        # raise SystemExit instead of the site-injected exit() builtin,
        # which is only guaranteed in interactive sessions.
        raise SystemExit(1)
# Script entry point: exits non-zero when TiFlash never becomes ready.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,40 @@
import pytest
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
from models.dataset import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
@pytest.fixture
def tidb_vector():
    """Provide a TiDBVector bound to the local TiDB test database."""
    tidb_config = TiDBVectorConfig(
        host="localhost",
        port=4000,
        user="root",
        password="",
        database="test",
        program_name="langgenius/dify",
    )
    return TiDBVector(collection_name="test_collection", config=tidb_config)
class TiDBVectorTest(AbstractVectorTest):
    """Shared suite driven by an externally supplied TiDBVector."""

    def __init__(self, vector):
        super().__init__()
        self.vector = vector

    def search_by_full_text(self):
        # Full-text search is not expected to match anything in this setup.
        matches: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(matches) == 0

    def get_ids_by_metadata_field(self):
        matched_ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert len(matched_ids) == 1
def test_tidb_vector(setup_mock_redis, tidb_vector):
    # The TiDB backend currently misbehaves; skip explicitly so test reports
    # show a skip instead of a silently passing no-op.
    # TiDBVectorTest(vector=tidb_vector).run_all_tests()
    pytest.skip("TiDB vector store test disabled: backend currently misbehaves")

View File

@@ -0,0 +1,28 @@
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVector, UpstashVectorConfig
from core.rag.models.document import Document
from tests.integration_tests.vdb.__mock.upstashvectordb import setup_upstashvector_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
class UpstashVectorTest(AbstractVectorTest):
    """Shared vector-store suite backed by an Upstash vector index."""

    def __init__(self):
        super().__init__()
        upstash_config = UpstashVectorConfig(
            url="your-server-url",
            token="your-access-token",
        )
        self.vector = UpstashVector(collection_name="test_collection", config=upstash_config)

    def get_ids_by_metadata_field(self):
        matched_ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert matched_ids

    def search_by_full_text(self):
        # Full-text search yields no hits against this backend.
        hits: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits) == 0
def test_upstash_vector(setup_upstashvector_mock):
    """Run the shared vector-store suite against the mocked Upstash backend."""
    suite = UpstashVectorTest()
    suite.run_all_tests()

View File

@@ -0,0 +1,37 @@
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector
from tests.integration_tests.vdb.__mock.vikingdb import setup_vikingdb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
class VikingDBVectorTest(AbstractVectorTest):
    """Shared vector-store suite for the VikingDB backend."""

    def __init__(self):
        super().__init__()
        vikingdb_config = VikingDBConfig(
            access_key="test_access_key",
            host="test_host",
            region="test_region",
            scheme="test_scheme",
            secret_key="test_secret_key",
            connection_timeout=30,
            socket_timeout=30,
        )
        self.vector = VikingDBVector("test_collection", "test_group", config=vikingdb_config)

    def search_by_vector(self):
        # Exactly one document matches the example embedding.
        results = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(results) == 1

    def search_by_full_text(self):
        # Full-text search yields no hits against this backend.
        results = self.vector.search_by_full_text(query=get_example_text())
        assert not results

    def get_ids_by_metadata_field(self):
        matched_ids = self.vector.get_ids_by_metadata_field(key="document_id", value="test_document_id")
        assert len(matched_ids) > 0
def test_vikingdb_vector(setup_mock_redis, setup_vikingdb_mock):
    """Run the shared vector-store suite against the mocked VikingDB backend."""
    VikingDBVectorTest().run_all_tests()

View File

@@ -0,0 +1,23 @@
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
)
class WeaviateVectorTest(AbstractVectorTest):
    """Shared vector-store suite for a locally running Weaviate instance."""

    def __init__(self):
        super().__init__()
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
        # NOTE(review): api_key looks like a local dev credential — confirm it
        # is not a real secret before publishing.
        weaviate_config = WeaviateConfig(
            endpoint="http://localhost:8080",
            api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih",
        )
        self.vector = WeaviateVector(
            collection_name=self.collection_name,
            config=weaviate_config,
            attributes=self.attributes,
        )
def test_weaviate_vector(setup_mock_redis):
    """Run the shared vector-store suite against a local Weaviate instance."""
    WeaviateVectorTest().run_all_tests()