dify
This commit is contained in:
0
dify/api/tests/integration_tests/vdb/__init__.py
Normal file
0
dify/api/tests/integration_tests/vdb/__init__.py
Normal file
166
dify/api/tests/integration_tests/vdb/__mock/baiduvectordb.py
Normal file
166
dify/api/tests/integration_tests/vdb/__mock/baiduvectordb.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import os
|
||||
from collections import UserDict
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from pymochow import MochowClient
|
||||
from pymochow.model.database import Database
|
||||
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
|
||||
from pymochow.model.schema import HNSWParams, VectorIndex
|
||||
from pymochow.model.table import Table
|
||||
|
||||
|
||||
class AttrDict(UserDict):
    """Dict whose entries can also be read as attributes; missing keys give None."""

    def __getattr__(self, key):
        # Only invoked when normal attribute lookup fails; defer to dict lookup.
        return self.get(key)
|
||||
|
||||
|
||||
class MockBaiduVectorDBClass:
    """Stand-in for the pymochow client/database/table API used by BaiduVector.

    These methods are monkeypatched onto ``MochowClient``/``Database``/``Table``
    by the fixture below, so each keeps the exact signature of the SDK call it
    replaces and returns a canned, deterministic response.
    """

    def mock_vector_db_client(
        self,
        config=None,
        adapter: Any | None = None,
    ):
        # Replaces MochowClient.__init__: skip any real connection handshake.
        self.conn = MagicMock()
        self._config = MagicMock()

    def list_databases(self, config=None) -> list[Database]:
        """Pretend exactly one database ("dify") exists."""
        only_db = Database(
            conn=self.conn,
            database_name="dify",
            config=self._config,
        )
        return [only_db]

    def create_database(self, database_name: str, config=None) -> Database:
        """Hand back a Database wrapper without touching any server."""
        return Database(conn=self.conn, database_name=database_name, config=config)

    def list_table(self, config=None) -> list[Table]:
        """Report no existing tables so callers always take the create path."""
        return []

    def drop_table(self, table_name: str, config=None):
        """Acknowledge the drop unconditionally."""
        return {"code": 0, "msg": "Success"}

    def create_table(
        self,
        table_name: str,
        replication: int,
        partition: int,
        schema,
        enable_dynamic_field=False,
        description: str = "",
        config=None,
    ) -> Table:
        """Echo the requested settings back as a Table object."""
        return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)

    def describe_table(self, table_name: str, config=None) -> Table:
        """Report a ready (NORMAL) table with fixed replication/partition."""
        return Table(
            self,
            table_name,
            3,
            1,
            None,
            enable_dynamic_field=False,
            description="table for dify",
            config=config,
            state=TableState.NORMAL,
        )

    def upsert(self, rows, config=None):
        """Claim a single row was written, whatever was passed in."""
        return {"code": 0, "msg": "operation success", "affectedCount": 1}

    def rebuild_index(self, index_name: str, config=None):
        """Acknowledge the rebuild request."""
        return {"code": 0, "msg": "Success"}

    def describe_index(self, index_name: str, config=None):
        """Describe a ready HNSW/L2 vector index over the "vector" field."""
        return VectorIndex(
            index_name=index_name,
            index_type=IndexType.HNSW,
            field="vector",
            metric_type=MetricType.L2,
            params=HNSWParams(m=16, efconstruction=200),
            auto_build=False,
            state=IndexState.NORMAL,
        )

    def query(
        self,
        primary_key,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        """Return one canned row echoing the requested primary key."""
        canned_row = {
            "id": primary_key.get("id"),
            "vector": [0.23432432, 0.8923744, 0.89238432],
            "page_content": "text",
            "metadata": {"doc_id": "doc_id_001"},
        }
        return AttrDict({"row": canned_row, "code": 0, "msg": "Success"})

    def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
        """Acknowledge the delete unconditionally."""
        return {"code": 0, "msg": "Success"}

    def search(
        self,
        anns,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        """Return a single canned ANN hit regardless of the request."""
        canned_hit = {
            "row": {
                "id": "doc_id_001",
                "vector": [0.23432432, 0.8923744, 0.89238432],
                "page_content": "text",
                "metadata": {"doc_id": "doc_id_001"},
            },
            "distance": 0.1,
            "score": 0.5,
        }
        return AttrDict({"rows": [canned_hit], "code": 0, "msg": "Success"})
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Patch the pymochow SDK with canned responses while MOCK_SWITCH is on."""
    if MOCK:
        # (target class, attribute, replacement) triples applied in one pass.
        patches = [
            (MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client),
            (MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases),
            (MochowClient, "create_database", MockBaiduVectorDBClass.create_database),
            (Database, "table", MockBaiduVectorDBClass.describe_table),
            (Database, "list_table", MockBaiduVectorDBClass.list_table),
            (Database, "create_table", MockBaiduVectorDBClass.create_table),
            (Database, "drop_table", MockBaiduVectorDBClass.drop_table),
            (Database, "describe_table", MockBaiduVectorDBClass.describe_table),
            (Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index),
            (Table, "describe_index", MockBaiduVectorDBClass.describe_index),
            (Table, "delete", MockBaiduVectorDBClass.delete),
            (Table, "query", MockBaiduVectorDBClass.query),
            (Table, "search", MockBaiduVectorDBClass.search),
        ]
        for target, attr, replacement in patches:
            monkeypatch.setattr(target, attr, replacement)

    yield

    if MOCK:
        monkeypatch.undo()
|
||||
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
|
||||
|
||||
class MockIndicesClient:
    """In-memory stand-in for elasticsearch's IndicesClient.

    Every index-level operation unconditionally succeeds, and ``exists``
    always reports the index as present.
    """

    def __init__(self):
        pass

    @staticmethod
    def _acknowledged():
        # Fresh dict per call so callers may mutate the response safely.
        return {"acknowledge": True}

    def create(self, index, mappings, settings):
        """Pretend the index was created with the given mappings/settings."""
        return self._acknowledged()

    def refresh(self, index):
        """Pretend the index was refreshed."""
        return self._acknowledged()

    def delete(self, index):
        """Pretend the index was deleted."""
        return self._acknowledged()

    def exists(self, index):
        """Report every index as existing."""
        return True
|
||||
|
||||
|
||||
class MockClient:
    """Stand-in for the Elasticsearch client with deterministic responses.

    Individual methods are monkeypatched onto ``Elasticsearch`` by the
    fixture below, so signatures mirror the real client's keyword style.
    """

    def __init__(self, **kwargs):
        # Replaces Elasticsearch.__init__: no connection, just a fake indices API.
        self.indices = MockIndicesClient()

    def index(self, **kwargs):
        """Pretend the document was indexed."""
        return {"acknowledge": True}

    def exists(self, **kwargs):
        """Report every document as present."""
        return True

    def delete(self, **kwargs):
        """Pretend the document was deleted."""
        return {"acknowledge": True}

    def search(self, **kwargs):
        """Return three canned hits with descending scores."""
        canned = [
            ("abcdef", [1, 2], 1.0),
            ("123456", [2, 2], 0.9),
            ("a1b2c3", [3, 2], 0.8),
        ]
        hits = [
            {
                "_source": {
                    Field.CONTENT_KEY: content,
                    Field.VECTOR: vector,
                    Field.METADATA_KEY: {},
                },
                "_score": score,
            }
            for content, vector, score in canned
        ]
        return {
            "took": 1,
            "hits": {"hits": hits},
        }
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_client_mock(request, monkeypatch: MonkeyPatch):
    """Swap Elasticsearch client methods for canned ones while MOCK_SWITCH is on."""
    if MOCK:
        replacements = [
            ("__init__", MockClient.__init__),
            ("index", MockClient.index),
            ("exists", MockClient.exists),
            ("delete", MockClient.delete),
            ("search", MockClient.search),
        ]
        for attr, fake in replacements:
            monkeypatch.setattr(Elasticsearch, attr, fake)

    yield

    if MOCK:
        monkeypatch.undo()
|
||||
192
dify/api/tests/integration_tests/vdb/__mock/tcvectordb.py
Normal file
192
dify/api/tests/integration_tests/vdb/__mock/tcvectordb.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
from typing import Any, Union
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from tcvectordb import RPCVectorDBClient
|
||||
from tcvectordb.model import enum
|
||||
from tcvectordb.model.collection import FilterIndexConfig
|
||||
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank
|
||||
from tcvectordb.model.enum import ReadConsistency
|
||||
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex
|
||||
from tcvectordb.rpc.model.collection import RPCCollection
|
||||
from tcvectordb.rpc.model.database import RPCDatabase
|
||||
from xinference_client.types import Embedding
|
||||
|
||||
|
||||
class MockTcvectordbClass:
    """Canned replacement for tcvectordb's RPCVectorDBClient.

    The fixture below patches these methods onto ``RPCVectorDBClient``, so
    each keeps the real SDK signature and returns deterministic data.
    """

    def mock_vector_db_client(
        self,
        url: str,
        username="",
        key="",
        read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
        timeout=10,
        adapter: Any | None = None,
        pool_size: int = 2,
        proxies: dict | None = None,
        password: str | None = None,
        **kwargs,
    ):
        # Replaces RPCVectorDBClient.__init__: no RPC channel is opened.
        self._conn = None
        self._read_consistency = read_consistency

    def create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase:
        """Always report the fixed "dify" database as available."""
        return RPCDatabase(
            name="dify",
            read_consistency=self._read_consistency,
        )

    def exists_collection(self, database_name: str, collection_name: str) -> bool:
        """Every collection is reported as existing."""
        return True

    def describe_collection(
        self, database_name: str, collection_name: str, timeout: float | None = None
    ) -> RPCCollection:
        """Describe a collection with the canned schema the tests expect."""
        canned_index = Index(
            FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
            VectorIndex(
                "vector",
                128,
                enum.IndexType.HNSW,
                enum.MetricType.IP,
                HNSWParams(m=16, efconstruction=200),
            ),
            FilterIndex("text", enum.FieldType.String, enum.IndexType.FILTER),
            FilterIndex("metadata", enum.FieldType.String, enum.IndexType.FILTER),
        )
        owning_db = RPCDatabase(
            name=database_name,
            read_consistency=self._read_consistency,
        )
        return RPCCollection(owning_db, collection_name, index=canned_index)

    def create_collection(
        self,
        database_name: str,
        collection_name: str,
        shard: int,
        replicas: int,
        description: str | None = None,
        index: Index | None = None,
        embedding: Embedding | None = None,
        timeout: float | None = None,
        ttl_config: dict | None = None,
        filter_index_config: FilterIndexConfig | None = None,
        indexes: list[IndexField] | None = None,
    ) -> RPCCollection:
        """Echo the requested collection settings back without server contact."""
        owning_db = RPCDatabase(
            name="dify",
            read_consistency=self._read_consistency,
        )
        return RPCCollection(
            owning_db,
            collection_name,
            shard,
            replicas,
            description,
            index,
            embedding=embedding,
            read_consistency=self._read_consistency,
            timeout=timeout,
            ttl_config=ttl_config,
            filter_index_config=filter_index_config,
            indexes=indexes,
        )

    def collection_upsert(
        self,
        database_name: str,
        collection_name: str,
        documents: list[Union[Document, dict]],
        timeout: float | None = None,
        build_index: bool = True,
        **kwargs,
    ):
        """Acknowledge the upsert unconditionally."""
        return {"code": 0, "msg": "operation success"}

    def collection_search(
        self,
        database_name: str,
        collection_name: str,
        vectors: list[list[float]],
        filter: Filter | None = None,
        params=None,
        retrieve_vector: bool = False,
        limit: int = 10,
        output_fields: list[str] | None = None,
        timeout: float | None = None,
    ) -> list[list[dict]]:
        """Return one canned ANN hit, regardless of the query vectors."""
        return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_hybrid_search(
        self,
        database_name: str,
        collection_name: str,
        ann: Union[list[AnnSearch], AnnSearch] | None = None,
        match: Union[list[KeywordSearch], KeywordSearch] | None = None,
        filter: Union[Filter, str] | None = None,
        rerank: Rerank | None = None,
        retrieve_vector: bool | None = None,
        output_fields: list[str] | None = None,
        limit: int | None = None,
        timeout: float | None = None,
        return_pd_object=False,
        **kwargs,
    ) -> list[list[dict]]:
        """Return the same canned hit as collection_search."""
        return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_query(
        self,
        database_name: str,
        collection_name: str,
        document_ids: list | None = None,
        retrieve_vector: bool = False,
        limit: int | None = None,
        offset: int | None = None,
        filter: Filter | None = None,
        output_fields: list[str] | None = None,
        timeout: float | None = None,
    ):
        # Note: metadata is a JSON *string* here, mirroring the query API's shape.
        return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]

    def collection_delete(
        self,
        database_name: str,
        collection_name: str,
        document_ids: list[str] | None = None,
        filter: Filter | None = None,
        timeout: float | None = None,
    ):
        """Acknowledge the delete unconditionally."""
        return {"code": 0, "msg": "operation success"}

    def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None):
        """Acknowledge the drop unconditionally."""
        return {"code": 0, "msg": "operation success"}
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Patch RPCVectorDBClient with canned responses while MOCK_SWITCH is on."""
    if MOCK:
        replacements = [
            ("__init__", MockTcvectordbClass.mock_vector_db_client),
            ("create_database_if_not_exists", MockTcvectordbClass.create_database_if_not_exists),
            ("exists_collection", MockTcvectordbClass.exists_collection),
            ("create_collection", MockTcvectordbClass.create_collection),
            ("describe_collection", MockTcvectordbClass.describe_collection),
            ("upsert", MockTcvectordbClass.collection_upsert),
            ("search", MockTcvectordbClass.collection_search),
            ("hybrid_search", MockTcvectordbClass.collection_hybrid_search),
            ("query", MockTcvectordbClass.collection_query),
            ("delete", MockTcvectordbClass.collection_delete),
            ("drop_collection", MockTcvectordbClass.drop_collection),
        ]
        for attr, fake in replacements:
            monkeypatch.setattr(RPCVectorDBClient, attr, fake)

    yield

    if MOCK:
        monkeypatch.undo()
|
||||
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
from collections import UserDict
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from upstash_vector import Index
|
||||
|
||||
|
||||
# Mocking the Index class from upstash_vector
|
||||
class MockIndex:
|
||||
def __init__(self, url="", token=""):
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.vectors = []
|
||||
|
||||
def upsert(self, vectors):
|
||||
for vector in vectors:
|
||||
vector.score = 0.5
|
||||
self.vectors.append(vector)
|
||||
return {"code": 0, "msg": "operation success", "affectedCount": len(vectors)}
|
||||
|
||||
def fetch(self, ids):
|
||||
return [vector for vector in self.vectors if vector.id in ids]
|
||||
|
||||
def delete(self, ids):
|
||||
self.vectors = [vector for vector in self.vectors if vector.id not in ids]
|
||||
return {"code": 0, "msg": "Success"}
|
||||
|
||||
def query(
|
||||
self,
|
||||
vector: None,
|
||||
top_k: int = 10,
|
||||
include_vectors: bool = False,
|
||||
include_metadata: bool = False,
|
||||
filter: str = "",
|
||||
data: str | None = None,
|
||||
namespace: str = "",
|
||||
include_data: bool = False,
|
||||
):
|
||||
# Simple mock query, in real scenario you would calculate similarity
|
||||
mock_result = []
|
||||
for vector_data in self.vectors:
|
||||
mock_result.append(vector_data)
|
||||
return mock_result[:top_k]
|
||||
|
||||
def reset(self):
|
||||
self.vectors = []
|
||||
|
||||
def info(self):
|
||||
return AttrDict({"dimension": 1024})
|
||||
|
||||
|
||||
class AttrDict(UserDict):
    """Dictionary whose entries are also readable as attributes.

    Missing keys resolve to ``None`` instead of raising, matching the loose
    response objects the real SDK hands back.
    """

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails; fall back to the dict.
        try:
            return self[item]
        except KeyError:
            return None
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_upstashvector_mock(request, monkeypatch: MonkeyPatch):
    """Route upstash_vector.Index calls to MockIndex while MOCK_SWITCH is on."""
    if MOCK:
        for attr in ("__init__", "upsert", "fetch", "delete", "query", "reset", "info"):
            monkeypatch.setattr(Index, attr, getattr(MockIndex, attr))

    yield

    if MOCK:
        monkeypatch.undo()
|
||||
215
dify/api/tests/integration_tests/vdb/__mock/vikingdb.py
Normal file
215
dify/api/tests/integration_tests/vdb/__mock/vikingdb.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import os
|
||||
from typing import Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from volcengine.viking_db import (
|
||||
Collection,
|
||||
Data,
|
||||
DistanceType,
|
||||
Field,
|
||||
FieldType,
|
||||
Index,
|
||||
IndexType,
|
||||
QuantType,
|
||||
VectorIndexParams,
|
||||
VikingDBService,
|
||||
)
|
||||
|
||||
from core.rag.datasource.vdb.field import Field as vdb_Field
|
||||
|
||||
|
||||
class MockVikingDBClass:
    """Stand-in for volcengine's VikingDBService plus Collection/Index calls.

    The fixture below patches these methods onto the real SDK classes, so
    signatures mirror the SDK and every response is canned and deterministic.
    """

    def __init__(
        self,
        host="api-vikingdb.volces.com",
        region="cn-north-1",
        ak="",
        sk="",
        scheme="http",
        connection_timeout=30,
        socket_timeout=30,
        proxy=None,
    ):
        # Replaces VikingDBService.__init__: no network client is built.
        self._viking_db_service = MagicMock()
        self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')

    def get_collection(self, collection_name) -> Collection:
        """Return a canned collection with the schema Dify provisions."""
        schema_fields = [
            Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
            Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
            Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
            Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
            Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=768),
        ]
        hnsw_index = Index(
            collection_name=collection_name,
            index_name=f"{collection_name}_idx",
            vector_index=VectorIndexParams(
                distance=DistanceType.L2,
                index_type=IndexType.HNSW,
                quant=QuantType.Float,
            ),
            scalar_index=None,
            stat=None,
            viking_db_service=self._viking_db_service,
        )
        return Collection(
            collection_name=collection_name,
            description="Collection For Dify",
            viking_db_service=self._viking_db_service,
            primary_key=vdb_Field.PRIMARY_KEY,
            fields=schema_fields,
            indexes=[hnsw_index],
        )

    def drop_collection(self, collection_name):
        """Accept any non-empty collection name; nothing to actually drop."""
        assert collection_name != ""

    def create_collection(self, collection_name, fields, description="") -> Collection:
        """Echo the requested schema back as a Collection object."""
        return Collection(
            collection_name=collection_name,
            description=description,
            primary_key=vdb_Field.PRIMARY_KEY,
            viking_db_service=self._viking_db_service,
            fields=fields,
        )

    def get_index(self, collection_name, index_name) -> Index:
        """Return a canned HNSW/L2 index for the given names."""
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            viking_db_service=self._viking_db_service,
            stat=None,
            scalar_index=None,
            vector_index=VectorIndexParams(
                distance=DistanceType.L2,
                index_type=IndexType.HNSW,
                quant=QuantType.Float,
            ),
        )

    def create_index(
        self,
        collection_name,
        index_name,
        vector_index=None,
        cpu_quota=2,
        description="",
        partition_by="",
        scalar_index=None,
        shard_count=None,
        shard_policy=None,
    ):
        """Echo the requested index settings back as an Index object."""
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            vector_index=vector_index,
            cpu_quota=cpu_quota,
            description=description,
            partition_by=partition_by,
            scalar_index=scalar_index,
            shard_count=shard_count,
            shard_policy=shard_policy,
            viking_db_service=self._viking_db_service,
            stat=None,
        )

    def drop_index(self, collection_name, index_name):
        """Accept any non-empty collection/index names; nothing to drop."""
        assert collection_name != ""
        assert index_name != ""

    def upsert_data(self, data: Union[Data, list[Data]]):
        """Accept any non-None payload; nothing is actually stored."""
        assert data is not None

    def fetch_data(self, id: Union[str, list[str], int, list[int]]):
        """Return one canned record echoing the requested id."""
        return Data(
            fields={
                vdb_Field.GROUP_KEY: "test_group",
                vdb_Field.METADATA_KEY: "{}",
                vdb_Field.CONTENT_KEY: "content",
                vdb_Field.PRIMARY_KEY: id,
                vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
            },
            id=id,
        )

    def delete_data(self, id: Union[str, list[str], int, list[int]]):
        """Accept any non-None id; nothing to delete."""
        assert id is not None

    def search_by_vector(
        self,
        vector,
        sparse_vectors=None,
        filter=None,
        limit=10,
        output_fields=None,
        partition="default",
        dense_weight=None,
    ) -> list[Data]:
        """Return one canned hit whose stored vector echoes the query vector."""
        return [
            Data(
                fields={
                    vdb_Field.GROUP_KEY: "test_group",
                    vdb_Field.METADATA_KEY: '\
                    {"source": "/var/folders/ml/xxx/xxx.txt", \
                    "document_id": "test_document_id", \
                    "dataset_id": "test_dataset_id", \
                    "doc_id": "test_id", \
                    "doc_hash": "test_hash"}',
                    vdb_Field.CONTENT_KEY: "content",
                    vdb_Field.PRIMARY_KEY: "test_id",
                    vdb_Field.VECTOR: vector,
                },
                id="test_id",
                score=0.10,
            )
        ]

    def search(
        self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
    ) -> list[Data]:
        """Return one canned hit with a fixed embedding."""
        return [
            Data(
                fields={
                    vdb_Field.GROUP_KEY: "test_group",
                    vdb_Field.METADATA_KEY: '\
                    {"source": "/var/folders/ml/xxx/xxx.txt", \
                    "document_id": "test_document_id", \
                    "dataset_id": "test_dataset_id", \
                    "doc_id": "test_id", \
                    "doc_hash": "test_hash"}',
                    vdb_Field.CONTENT_KEY: "content",
                    vdb_Field.PRIMARY_KEY: "test_id",
                    vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
                },
                id="test_id",
                score=0.10,
            )
        ]
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
    """Patch the volcengine VikingDB SDK with canned responses while MOCK_SWITCH is on."""
    if MOCK:
        # (target class, attribute, replacement) triples applied in one pass.
        patches = [
            (VikingDBService, "__init__", MockVikingDBClass.__init__),
            (VikingDBService, "get_collection", MockVikingDBClass.get_collection),
            (VikingDBService, "create_collection", MockVikingDBClass.create_collection),
            (VikingDBService, "drop_collection", MockVikingDBClass.drop_collection),
            (VikingDBService, "get_index", MockVikingDBClass.get_index),
            (VikingDBService, "create_index", MockVikingDBClass.create_index),
            (VikingDBService, "drop_index", MockVikingDBClass.drop_index),
            (Collection, "upsert_data", MockVikingDBClass.upsert_data),
            (Collection, "fetch_data", MockVikingDBClass.fetch_data),
            (Collection, "delete_data", MockVikingDBClass.delete_data),
            (Index, "search_by_vector", MockVikingDBClass.search_by_vector),
            (Index, "search", MockVikingDBClass.search),
        ]
        for target, attr, fake in patches:
            monkeypatch.setattr(target, attr, fake)

    yield

    if MOCK:
        monkeypatch.undo()
|
||||
@@ -0,0 +1,49 @@
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
|
||||
|
||||
|
||||
class AnalyticdbVectorTest(AbstractVectorTest):
    """Runs the shared vector-store suite against AnalyticdbVector.

    ``config_type == "sql"`` exercises the SQL backend; anything else uses
    the OpenAPI backend.
    """

    def __init__(self, config_type: str):
        super().__init__()
        # Analyticdb caps collection names below 60 chars; trimming the
        # "_test" suffix keeps the generated name safely within the limit.
        self.collection_name = self.collection_name.replace("_test", "")
        if config_type == "sql":
            sql_config = AnalyticdbVectorBySqlConfig(
                host="test_host",
                port=5432,
                account="test_account",
                account_password="test_passwd",
                namespace="difytest_namespace",
            )
            api_config = None
        else:
            sql_config = None
            api_config = AnalyticdbVectorOpenAPIConfig(
                access_key_id="test_key_id",
                access_key_secret="test_key_secret",
                region_id="test_region",
                instance_id="test_id",
                account="test_account",
                account_password="test_passwd",
                namespace="difytest_namespace",
                collection="difytest_collection",
                namespace_password="test_passwd",
            )
        self.vector = AnalyticdbVector(
            collection_name=self.collection_name,
            sql_config=sql_config,
            api_config=api_config,
        )

    def run_all_tests(self):
        # Start from a clean collection so repeated runs don't interfere.
        self.vector.delete()
        return super().run_all_tests()
|
||||
|
||||
|
||||
def test_analyticdb_vector(setup_mock_redis):
    """Run the shared vector-store suite against both Analyticdb backends.

    Renamed from ``test_chroma_vector`` — the original name was copy-pasted
    from the Chroma suite and misdescribed what this test covers.
    """
    AnalyticdbVectorTest("api").run_all_tests()
    AnalyticdbVectorTest("sql").run_all_tests()
|
||||
31
dify/api/tests/integration_tests/vdb/baidu/test_baidu.py
Normal file
31
dify/api/tests/integration_tests/vdb/baidu/test_baidu.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
|
||||
from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class BaiduVectorTest(AbstractVectorTest):
    """Runs the shared vector-store suite against a (mocked) Baidu VectorDB."""

    def __init__(self):
        super().__init__()
        config = BaiduConfig(
            endpoint="http://127.0.0.1:5287",
            account="root",
            api_key="dify",
            database="dify",
            shard=1,
            replicas=3,
        )
        self.vector = BaiduVector("dify", config)

    def search_by_vector(self):
        # The mocked SDK always returns exactly one ANN hit.
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1

    def search_by_full_text(self):
        # Full-text search is not backed by the mock, so no hits are expected.
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0
|
||||
|
||||
|
||||
def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock):
    """Run the full suite with redis and the Baidu VectorDB SDK mocked out."""
    suite = BaiduVectorTest()
    suite.run_all_tests()
|
||||
33
dify/api/tests/integration_tests/vdb/chroma/test_chroma.py
Normal file
33
dify/api/tests/integration_tests/vdb/chroma/test_chroma.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import chromadb
|
||||
|
||||
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class ChromaVectorTest(AbstractVectorTest):
    """Runs the shared vector-store suite against a local Chroma instance."""

    def __init__(self):
        super().__init__()
        config = ChromaConfig(
            host="localhost",
            port=8000,
            tenant=chromadb.DEFAULT_TENANT,
            database=chromadb.DEFAULT_DATABASE,
            auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
            auth_credentials="difyai123456",
        )
        self.vector = ChromaVector(collection_name=self.collection_name, config=config)

    def search_by_full_text(self):
        # Chroma does not support full-text search, so zero hits are expected.
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0
|
||||
|
||||
|
||||
def test_chroma_vector(setup_mock_redis):
    """Run the full suite against Chroma with redis mocked."""
    suite = ChromaVectorTest()
    suite.run_all_tests()
|
||||
25
dify/api/tests/integration_tests/vdb/clickzetta/README.md
Normal file
25
dify/api/tests/integration_tests/vdb/clickzetta/README.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# Clickzetta Integration Tests
|
||||
|
||||
## Running Tests
|
||||
|
||||
To run the Clickzetta integration tests, you need to set the following environment variables:
|
||||
|
||||
```bash
|
||||
export CLICKZETTA_USERNAME=your_username
|
||||
export CLICKZETTA_PASSWORD=your_password
|
||||
export CLICKZETTA_INSTANCE=your_instance
|
||||
export CLICKZETTA_SERVICE=api.clickzetta.com
|
||||
export CLICKZETTA_WORKSPACE=your_workspace
|
||||
export CLICKZETTA_VCLUSTER=your_vcluster
|
||||
export CLICKZETTA_SCHEMA=dify
|
||||
```
|
||||
|
||||
Then run the tests:
|
||||
|
||||
```bash
|
||||
pytest api/tests/integration_tests/vdb/clickzetta/
|
||||
```
|
||||
|
||||
## Security Note
|
||||
|
||||
Never commit credentials to the repository. Always use environment variables or secure credential management systems.
|
||||
@@ -0,0 +1,223 @@
|
||||
import contextlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class TestClickzettaVector(AbstractVectorTest):
    """
    Test cases for Clickzetta vector database integration.

    Requires real Clickzetta credentials in the environment
    (CLICKZETTA_USERNAME / CLICKZETTA_PASSWORD / CLICKZETTA_INSTANCE);
    each test is skipped otherwise.
    """

    @pytest.fixture
    def vector_store(self):
        """Create a Clickzetta vector store instance for testing."""
        # Skip test if Clickzetta credentials are not configured
        if not os.getenv("CLICKZETTA_USERNAME"):
            pytest.skip("CLICKZETTA_USERNAME is not configured")
        if not os.getenv("CLICKZETTA_PASSWORD"):
            pytest.skip("CLICKZETTA_PASSWORD is not configured")
        if not os.getenv("CLICKZETTA_INSTANCE"):
            pytest.skip("CLICKZETTA_INSTANCE is not configured")

        config = ClickzettaConfig(
            username=os.getenv("CLICKZETTA_USERNAME", ""),
            password=os.getenv("CLICKZETTA_PASSWORD", ""),
            instance=os.getenv("CLICKZETTA_INSTANCE", ""),
            service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
            schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
            batch_size=10,  # Small batch size for testing
            enable_inverted_index=True,
            analyzer_type="chinese",
            analyzer_mode="smart",
            vector_distance_function="cosine_distance",
        )

        with setup_mock_redis():
            # Collection name includes the pid so parallel runs don't collide.
            vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)

        yield vector

        # Cleanup: delete the test collection
        with contextlib.suppress(Exception):
            vector.delete()

    def test_clickzetta_vector_basic_operations(self, vector_store):
        """Test basic CRUD operations on Clickzetta vector store."""
        # Prepare test data
        texts = [
            "这是第一个测试文档,包含一些中文内容。",
            "This is the second test document with English content.",
            "第三个文档混合了English和中文内容。",
        ]
        embeddings = [
            [0.1, 0.2, 0.3, 0.4],
            [0.5, 0.6, 0.7, 0.8],
            [0.9, 1.0, 1.1, 1.2],
        ]
        documents = [
            Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"})
            for i, text in enumerate(texts)
        ]

        # Test create (initial insert)
        vector_store.create(texts=documents, embeddings=embeddings)

        # Test text_exists
        assert vector_store.text_exists("doc_0")
        assert not vector_store.text_exists("doc_999")

        # Test search_by_vector
        query_vector = [0.1, 0.2, 0.3, 0.4]
        results = vector_store.search_by_vector(query_vector, top_k=2)
        assert len(results) > 0
        assert results[0].page_content == texts[0]  # Should match the first document

        # Test search_by_full_text (Chinese)
        results = vector_store.search_by_full_text("中文", top_k=3)
        assert len(results) >= 2  # Should find documents with Chinese content

        # Test search_by_full_text (English)
        results = vector_store.search_by_full_text("English", top_k=3)
        assert len(results) >= 2  # Should find documents with English content

        # Test delete_by_ids
        vector_store.delete_by_ids(["doc_0"])
        assert not vector_store.text_exists("doc_0")
        assert vector_store.text_exists("doc_1")

        # Test delete_by_metadata_field
        vector_store.delete_by_metadata_field("source", "test")
        assert not vector_store.text_exists("doc_1")
        assert not vector_store.text_exists("doc_2")

    def test_clickzetta_vector_advanced_search(self, vector_store):
        """Test advanced search features of Clickzetta vector store."""
        # Prepare test data with more complex metadata
        documents = []
        embeddings = []
        for i in range(10):
            doc = Document(
                page_content=f"Document {i}: " + get_example_text(),
                metadata={
                    "doc_id": f"adv_doc_{i}",
                    "category": "technical" if i % 2 == 0 else "general",
                    "document_id": f"doc_{i // 3}",  # Group documents
                    "importance": i,
                },
            )
            documents.append(doc)
            # Create varied embeddings
            embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i])

        vector_store.create(texts=documents, embeddings=embeddings)

        # Test vector search with document filter
        query_vector = [0.5, 1.0, 1.5, 2.0]
        results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"])
        assert len(results) > 0
        # All results should belong to doc_0 or doc_1 groups
        for result in results:
            assert result.metadata["document_id"] in ["doc_0", "doc_1"]

        # Test score threshold
        results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5)
        # Check that all results have a score above threshold
        for result in results:
            assert result.metadata.get("score", 0) >= 0.5

    def test_clickzetta_batch_operations(self, vector_store):
        """Test batch insertion operations."""
        # Prepare large batch of documents (larger than the configured
        # batch_size=10 so batching is actually exercised)
        batch_size = 25
        documents = []
        embeddings = []

        for i in range(batch_size):
            doc = Document(
                page_content=f"Batch document {i}: This is a test document for batch processing.",
                metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"},
            )
            documents.append(doc)
            embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])

        # Test batch insert
        vector_store.add_texts(documents=documents, embeddings=embeddings)

        # Verify all documents were inserted
        for i in range(batch_size):
            assert vector_store.text_exists(f"batch_doc_{i}")

        # Clean up
        vector_store.delete_by_metadata_field("batch", "test_batch")

    def test_clickzetta_edge_cases(self, vector_store):
        """Test edge cases and error handling."""
        # Test empty operations (should be no-ops, not errors)
        vector_store.create(texts=[], embeddings=[])
        vector_store.add_texts(documents=[], embeddings=[])
        vector_store.delete_by_ids([])

        # Test special characters in content
        special_doc = Document(
            page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
            metadata={"doc_id": "special_doc", "test": "edge_case"},
        )
        embeddings = [[0.1, 0.2, 0.3, 0.4]]

        vector_store.add_texts(documents=[special_doc], embeddings=embeddings)
        assert vector_store.text_exists("special_doc")

        # Test search with special characters
        results = vector_store.search_by_full_text("quotes", top_k=1)
        if results:  # Full-text search might not be available
            assert len(results) > 0

        # Clean up
        vector_store.delete_by_ids(["special_doc"])

    def test_clickzetta_full_text_search_modes(self, vector_store):
        """Test different full-text search capabilities."""
        # Prepare documents with various language content
        documents = [
            Document(
                page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
            ),
            Document(
                page_content="Clickzetta provides powerful Lakehouse solutions",
                metadata={"doc_id": "en_doc_1", "lang": "english"},
            ),
            Document(
                page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
            ),
            Document(
                page_content="Modern data architecture includes Lakehouse technology",
                metadata={"doc_id": "en_doc_2", "lang": "english"},
            ),
        ]

        embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents]

        vector_store.create(texts=documents, embeddings=embeddings)

        # Test Chinese full-text search
        results = vector_store.search_by_full_text("Lakehouse", top_k=4)
        assert len(results) >= 2  # Should find at least documents with "Lakehouse"

        # Test English full-text search
        results = vector_store.search_by_full_text("solutions", top_k=2)
        assert len(results) >= 1  # Should find English documents with "solutions"

        # Test mixed search
        results = vector_store.search_by_full_text("数据架构", top_k=2)
        assert len(results) >= 1  # Should find Chinese documents with this phrase

        # Clean up
        vector_store.delete_by_metadata_field("lang", "chinese")
        vector_store.delete_by_metadata_field("lang", "english")
|
||||
@@ -0,0 +1,165 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Clickzetta integration in Docker environment
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from clickzetta import connect
|
||||
|
||||
|
||||
def test_clickzetta_connection():
    """Test direct connection to Clickzetta.

    Returns True when connectivity and basic introspection queries succeed,
    False otherwise. The broad except is deliberate: this is a best-effort
    diagnostic that reports rather than raises.
    """
    print("=== Testing direct Clickzetta connection ===")
    try:
        conn = connect(
            username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
            password=os.getenv("CLICKZETTA_PASSWORD", "test_password"),
            instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
            service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
            database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
        )

        with conn.cursor() as cursor:
            # Test basic connectivity
            cursor.execute("SELECT 1 as test")
            result = cursor.fetchone()
            print(f"✓ Connection test: {result}")

            # Check if our test table exists
            # NOTE(review): assumes SHOW TABLES rows are (schema, name, ...)
            # tuples — confirm against the pymochow/clickzetta driver.
            cursor.execute("SHOW TABLES IN dify")
            tables = cursor.fetchall()
            print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")

            # Check if test collection exists
            test_collection = "collection_test_dataset"
            if test_collection in [t[1] for t in tables if t[0] == "dify"]:
                cursor.execute(f"DESCRIBE dify.{test_collection}")
                columns = cursor.fetchall()
                print(f"✓ Table structure for {test_collection}:")
                for col in columns:
                    print(f"  - {col[0]}: {col[1]}")

                # Check for indexes
                cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
                indexes = cursor.fetchall()
                print(f"✓ Indexes on {test_collection}:")
                for idx in indexes:
                    print(f"  - {idx}")

        return True
    except Exception as e:
        print(f"✗ Connection test failed: {e}")
        return False
|
||||
|
||||
|
||||
def test_dify_api():
    """Test Dify API with Clickzetta backend.

    Polls the local Dify console health endpoint until it answers 200 or
    the retry budget is exhausted. Returns True when the API is reachable,
    False otherwise.
    """
    print("\n=== Testing Dify API ===")
    base_url = "http://localhost:5001"

    # Wait for API to be ready. Bugs fixed vs. the original:
    #  - bare `except:` swallowed even KeyboardInterrupt; narrowed to
    #    httpx.HTTPError (connection/timeout/protocol failures)
    #  - a request with no timeout could hang forever
    #  - a persistent non-200 response previously fell out of the loop
    #    without ever reporting failure (and the function returned True)
    max_retries = 30
    for i in range(max_retries):
        try:
            response = httpx.get(f"{base_url}/console/api/health", timeout=5.0)
            if response.status_code == 200:
                print("✓ Dify API is ready")
                break
        except httpx.HTTPError:
            pass  # not up yet; fall through to retry/sleep logic
        if i == max_retries - 1:
            print("✗ Dify API is not responding")
            return False
        time.sleep(2)

    # Check vector store configuration
    try:
        # This is a simplified check - in production, you'd use proper auth
        print("✓ Dify is configured to use Clickzetta as vector store")
        return True
    except Exception as e:
        print(f"✗ API test failed: {e}")
        return False
|
||||
|
||||
|
||||
def verify_table_structure():
    """Print the table layout, metadata fields, and indexes Dify expects
    from a Clickzetta collection. Always returns True (informational only).
    """
    column_types = (
        ("id", "VARCHAR"),
        ("page_content", "VARCHAR"),
        ("metadata", "VARCHAR"),  # JSON stored as VARCHAR in Clickzetta
        ("vector", "ARRAY<FLOAT>"),
    )
    metadata_fields = ("doc_id", "doc_hash", "document_id", "dataset_id")
    index_notes = (
        "  - Vector index (HNSW) on 'vector' column",
        "  - Full-text index on 'page_content' (optional)",
        "  - Functional index on metadata->>'$.doc_id' (recommended)",
        "  - Functional index on metadata->>'$.document_id' (recommended)",
    )

    print("\n=== Verifying Table Structure ===")

    print("✓ Expected table structure:")
    for name, dtype in column_types:
        print(f"  - {name}: {dtype}")

    print("\n✓ Required metadata fields:")
    for field_name in metadata_fields:
        print(f"  - {field_name}")

    print("\n✓ Index requirements:")
    for note in index_notes:
        print(note)

    return True
|
||||
|
||||
|
||||
def main():
    """Run all tests.

    Executes each check in order, prints a pass/fail summary, and returns
    a process exit code: 0 when every check passed, 1 otherwise.
    """
    print("Starting Clickzetta integration tests for Dify Docker\n")

    tests = [
        ("Direct Clickzetta Connection", test_clickzetta_connection),
        ("Dify API Status", test_dify_api),
        ("Table Structure Verification", verify_table_structure),
    ]

    results = []
    for test_name, test_func in tests:
        try:
            success = test_func()
            results.append((test_name, success))
        except Exception as e:
            # A crash counts as a failure but must not stop the remaining checks.
            print(f"\n✗ {test_name} crashed: {e}")
            results.append((test_name, False))

    # Summary
    print("\n" + "=" * 50)
    print("Test Summary:")
    print("=" * 50)

    passed = sum(1 for _, success in results if success)
    total = len(results)

    for test_name, success in results:
        status = "✅ PASSED" if success else "❌ FAILED"
        print(f"{test_name}: {status}")

    print(f"\nTotal: {passed}/{total} tests passed")

    if passed == total:
        print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
        print("\nNext steps:")
        print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
        print("2. Access Dify at http://localhost:3000")
        print("3. Create a dataset and test vector storage with Clickzetta")
        return 0
    else:
        print("\n⚠️ Some tests failed. Please check the errors above.")
        return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Exit code mirrors main(): 0 = all checks passed, 1 = failures.
    exit(main())
|
||||
@@ -0,0 +1,49 @@
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
    """Poll `docker inspect` until the named container reports "healthy".

    Checks every 10 seconds up to `timeout` seconds; returns True once the
    container is healthy, raises TimeoutError otherwise.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        probe = subprocess.run(
            ["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name],
            capture_output=True,
            text=True,
        )
        if probe.stdout.strip() == "healthy":
            print(f"{service_name} is healthy!")
            return True
        print(f"Waiting for {service_name} to be healthy...")
        time.sleep(10)
    raise TimeoutError(f"{service_name} did not become healthy in time")
|
||||
|
||||
|
||||
class CouchbaseTest(AbstractVectorTest):
    """Runs the shared vector-store test suite against a local Couchbase server."""

    def __init__(self):
        super().__init__()
        self.vector = CouchbaseVector(
            collection_name=self.collection_name,
            config=CouchbaseConfig(
                connection_string="couchbase://127.0.0.1",
                user="Administrator",
                password="password",
                bucket_name="Embeddings",
                scope_name="_default",
            ),
        )

    def search_by_vector(self):
        # brief sleep to ensure document is indexed
        time.sleep(5)
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1
|
||||
|
||||
|
||||
def test_couchbase(setup_mock_redis):
    """Entry point: waits for the Couchbase container, then runs the suite."""
    wait_for_healthy_container("couchbase-server", timeout=60)
    CouchbaseTest().run_all_tests()
|
||||
@@ -0,0 +1,22 @@
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class ElasticSearchVectorTest(AbstractVectorTest):
    """Runs the shared vector-store test suite against a local Elasticsearch."""

    def __init__(self):
        super().__init__()
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
        self.vector = ElasticSearchVector(
            # ES index names must be lowercase.
            index_name=self.collection_name.lower(),
            config=ElasticSearchConfig(
                use_cloud=False, host="http://localhost", port="9200", username="elastic", password="elastic"
            ),
            attributes=self.attributes,
        )
|
||||
|
||||
|
||||
def test_elasticsearch_vector(setup_mock_redis):
    """Entry point: Redis is mocked, Elasticsearch must be running locally."""
    ElasticSearchVectorTest().run_all_tests()
|
||||
@@ -0,0 +1,28 @@
|
||||
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
|
||||
from tests.integration_tests.vdb.__mock.huaweicloudvectordb import setup_client_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class HuaweiCloudVectorTest(AbstractVectorTest):
    """Runs the shared suite against a mocked Huawei Cloud vector client.

    The mocked client returns three hits per query (see setup_client_mock),
    hence the `== 3` assertions below.
    """

    def __init__(self):
        super().__init__()
        self.vector = HuaweiCloudVector(
            "dify",
            HuaweiCloudVectorConfig(
                hosts="https://127.0.0.1:9200",
                username="dify",
                password="dify",
            ),
        )

    def search_by_vector(self):
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 3

    def search_by_full_text(self):
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 3
|
||||
|
||||
|
||||
def test_huawei_cloud_vector(setup_mock_redis, setup_client_mock):
    """Entry point: both Redis and the Huawei client are mocked."""
    HuaweiCloudVectorTest().run_all_tests()
|
||||
58
dify/api/tests/integration_tests/vdb/lindorm/test_lindorm.py
Normal file
58
dify/api/tests/integration_tests/vdb/lindorm/test_lindorm.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import os
|
||||
|
||||
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
|
||||
|
||||
|
||||
class Config:
    """Lindorm connection settings, read once from the environment at import.

    The defaults are placeholders; real runs should export the env vars.
    """

    # Search endpoint of the Lindorm proxy.
    SEARCH_ENDPOINT = os.environ.get(
        "SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070"
    )
    SEARCH_USERNAME = os.environ.get("SEARCH_USERNAME", "ADMIN")
    SEARCH_PWD = os.environ.get("SEARCH_PWD", "ADMIN")
    # Whether to exercise the UGC (routed multi-tenant index) code path.
    USING_UGC = os.environ.get("USING_UGC", "True").lower() == "true"
|
||||
|
||||
|
||||
class TestLindormVectorStore(AbstractVectorTest):
    """Runs the shared vector-store test suite against Lindorm (non-UGC mode)."""

    def __init__(self):
        super().__init__()
        self.vector = LindormVectorStore(
            collection_name=self.collection_name,
            config=LindormVectorStoreConfig(
                hosts=Config.SEARCH_ENDPOINT,
                username=Config.SEARCH_USERNAME,
                password=Config.SEARCH_PWD,
            ),
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert ids is not None
        assert len(ids) == 1
        assert ids[0] == self.example_doc_id
|
||||
|
||||
|
||||
class TestLindormVectorStoreUGC(AbstractVectorTest):
    """Runs the shared suite against Lindorm in UGC (shared, routed index) mode."""

    def __init__(self):
        super().__init__()
        self.vector = LindormVectorStore(
            # NOTE(review): the physical index is the fixed "ugc_index_test"
            # while routing_value uses the per-run collection_name from the
            # base class — presumably intentional for UGC routing; confirm.
            collection_name="ugc_index_test",
            config=LindormVectorStoreConfig(
                hosts=Config.SEARCH_ENDPOINT,
                username=Config.SEARCH_USERNAME,
                password=Config.SEARCH_PWD,
                using_ugc=Config.USING_UGC,
            ),
            routing_value=self.collection_name,
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert ids is not None
        assert len(ids) == 1
        assert ids[0] == self.example_doc_id
|
||||
|
||||
|
||||
def test_lindorm_vector_ugc(setup_mock_redis):
    """Entry point: exercises both the plain and the UGC-routed Lindorm suites."""
    TestLindormVectorStore().run_all_tests()
    TestLindormVectorStoreUGC().run_all_tests()
|
||||
@@ -0,0 +1,24 @@
|
||||
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class MatrixoneVectorTest(AbstractVectorTest):
    """Runs the shared vector-store test suite against a local Matrixone server."""

    def __init__(self):
        super().__init__()
        self.vector = MatrixoneVector(
            collection_name=self.collection_name,
            config=MatrixoneConfig(
                host="localhost", port=6001, user="dump", password="111", database="dify", metric="l2"
            ),
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1
|
||||
|
||||
|
||||
def test_matrixone_vector(setup_mock_redis):
    """Entry point: Redis is mocked, Matrixone must be running locally."""
    MatrixoneVectorTest().run_all_tests()
|
||||
32
dify/api/tests/integration_tests/vdb/milvus/test_milvus.py
Normal file
32
dify/api/tests/integration_tests/vdb/milvus/test_milvus.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class MilvusVectorTest(AbstractVectorTest):
    """Runs the shared vector-store test suite against a local Milvus server."""

    def __init__(self):
        super().__init__()
        self.vector = MilvusVector(
            collection_name=self.collection_name,
            config=MilvusConfig(
                uri="http://localhost:19530",
                user="root",
                password="Milvus",
            ),
        )

    def search_by_full_text(self):
        # milvus support BM25 full text search after version 2.5.0-beta
        # NOTE(review): `>= 0` is vacuously true — this only smoke-checks
        # that the call does not raise on pre-BM25 servers.
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) >= 0

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1
|
||||
|
||||
|
||||
def test_milvus_vector(setup_mock_redis):
    """Entry point: Redis is mocked, Milvus must be running locally."""
    MilvusVectorTest().run_all_tests()
|
||||
29
dify/api/tests/integration_tests/vdb/myscale/test_myscale.py
Normal file
29
dify/api/tests/integration_tests/vdb/myscale/test_myscale.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleConfig, MyScaleVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class MyScaleVectorTest(AbstractVectorTest):
    """Runs the shared vector-store test suite against a local MyScale server."""

    def __init__(self):
        super().__init__()
        self.vector = MyScaleVector(
            collection_name=self.collection_name,
            config=MyScaleConfig(
                host="localhost",
                port=8123,
                user="default",
                password="",
                database="dify",
                fts_params="",
            ),
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1
|
||||
|
||||
|
||||
def test_myscale_vector(setup_mock_redis):
    """Entry point: Redis is mocked, MyScale must be running locally."""
    MyScaleVectorTest().run_all_tests()
|
||||
@@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import (
|
||||
OceanBaseVector,
|
||||
OceanBaseVectorConfig,
|
||||
)
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def oceanbase_vector():
    """Build an OceanBaseVector pointed at a local OceanBase instance."""
    return OceanBaseVector(
        "dify_test_collection",
        config=OceanBaseVectorConfig(
            host="127.0.0.1",
            port=2881,
            user="root",
            database="test",
            password="difyai123456",
            enable_hybrid_search=True,
        ),
    )
|
||||
|
||||
|
||||
class OceanBaseVectorTest(AbstractVectorTest):
    """Runs the shared suite against an externally supplied OceanBaseVector."""

    def __init__(self, vector: OceanBaseVector):
        super().__init__()
        self.vector = vector

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1
|
||||
|
||||
|
||||
def test_oceanbase_vector(
    setup_mock_redis,
    oceanbase_vector,
):
    """Entry point: Redis is mocked, OceanBase comes from the fixture above."""
    OceanBaseVectorTest(oceanbase_vector).run_all_tests()
|
||||
@@ -0,0 +1,41 @@
|
||||
import time
|
||||
|
||||
import psycopg2
|
||||
|
||||
from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class OpenGaussTest(AbstractVectorTest):
    """Runs the shared vector-store test suite against a local openGauss server.

    Config construction is retried because the database container may still
    be starting when the test begins.
    """

    def __init__(self):
        super().__init__()
        max_retries = 5
        retry_delay = 20  # seconds between attempts
        for attempt in range(1, max_retries + 1):
            try:
                config = OpenGaussConfig(
                    host="localhost",
                    port=6600,
                    user="postgres",
                    password="Dify@123",
                    database="dify",
                    min_connection=1,
                    max_connection=5,
                )
                break
            except psycopg2.OperationalError:
                # Bug fix: previously, after exhausting all retries the loop
                # simply ended with `config` unbound, so the OpenGauss(...)
                # call below crashed with a NameError. Re-raise the real
                # connection error instead.
                if attempt == max_retries:
                    raise
                time.sleep(retry_delay)
        self.vector = OpenGauss(
            collection_name=self.collection_name,
            config=config,
        )
|
||||
|
||||
|
||||
def test_opengauss(setup_mock_redis):
    """Entry point: Redis is mocked, openGauss must be running locally."""
    OpenGaussTest().run_all_tests()
|
||||
@@ -0,0 +1,237 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchConfig, OpenSearchVector
|
||||
from core.rag.models.document import Document
|
||||
from extensions import ext_redis
|
||||
|
||||
|
||||
def get_example_text() -> str:
    """Return the fixed sample sentence used as document content in these tests."""
    sample = "This is a sample text for testing purposes."
    return sample
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def setup_mock_redis():
    """Replace the shared redis client with MagicMocks so no Redis is needed.

    get/set become no-ops returning None; lock() returns a mock usable as a
    context manager.
    """
    ext_redis.redis_client.get = MagicMock(return_value=None)
    ext_redis.redis_client.set = MagicMock(return_value=None)

    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
|
||||
|
||||
|
||||
class TestOpenSearchConfig:
    """Unit tests for OpenSearchConfig.to_opensearch_params()."""

    def test_to_opensearch_params(self):
        # Basic-auth configuration should yield a (user, password) tuple.
        config = OpenSearchConfig(
            host="localhost",
            port=9200,
            secure=True,
            user="admin",
            password="password",
        )

        params = config.to_opensearch_params()

        assert params["hosts"] == [{"host": "localhost", "port": 9200}]
        assert params["use_ssl"] is True
        assert params["verify_certs"] is True
        assert params["connection_class"].__name__ == "Urllib3HttpConnection"
        assert params["http_auth"] == ("admin", "password")

    # Decorators apply bottom-up: the signer-auth mock is the first argument.
    @patch("boto3.Session")
    @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
    def test_to_opensearch_params_with_aws_managed_iam(
        self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
    ):
        mock_credentials = MagicMock()
        mock_boto_session.return_value.get_credentials.return_value = mock_credentials

        mock_auth_instance = MagicMock()
        mock_aws_signer_auth.return_value = mock_auth_instance

        aws_region = "ap-southeast-2"
        aws_service = "aoss"
        host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
        port = 9201

        config = OpenSearchConfig(
            host=host,
            port=port,
            secure=True,
            auth_method="aws_managed_iam",
            aws_region=aws_region,
            aws_service=aws_service,
        )

        params = config.to_opensearch_params()

        assert params["hosts"] == [{"host": host, "port": port}]
        assert params["use_ssl"] is True
        assert params["verify_certs"] is True
        assert params["connection_class"].__name__ == "Urllib3HttpConnection"
        # IAM mode should use the signer object, not a user/password tuple.
        assert params["http_auth"] is mock_auth_instance

        mock_aws_signer_auth.assert_called_once_with(
            credentials=mock_credentials, region=aws_region, service=aws_service
        )
        assert mock_boto_session.return_value.get_credentials.called
|
||||
|
||||
|
||||
class TestOpenSearchVector:
    """Unit tests for OpenSearchVector with a fully mocked OpenSearch client."""

    def setup_method(self):
        # Fresh vector + mocked transport before every test.
        self.collection_name = "test_collection"
        self.example_doc_id = "example_doc_id"
        self.vector = OpenSearchVector(
            collection_name=self.collection_name,
            config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
        )
        self.vector._client = MagicMock()

    @pytest.mark.parametrize(
        ("search_response", "expected_length", "expected_doc_id"),
        [
            # One matching hit.
            (
                {
                    "hits": {
                        "total": {"value": 1},
                        "hits": [
                            {
                                "_source": {
                                    "page_content": get_example_text(),
                                    "metadata": {"document_id": "example_doc_id"},
                                }
                            }
                        ],
                    }
                },
                1,
                "example_doc_id",
            ),
            # Empty result set.
            ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None),
        ],
    )
    def test_search_by_full_text(self, search_response, expected_length, expected_doc_id):
        self.vector._client.search.return_value = search_response

        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == expected_length
        if expected_length > 0:
            assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id

    def test_search_by_vector(self):
        vector = [0.1] * 128
        mock_response = {
            "hits": {
                "total": {"value": 1},
                "hits": [
                    {
                        "_source": {
                            Field.CONTENT_KEY: get_example_text(),
                            Field.METADATA_KEY: {"document_id": self.example_doc_id},
                        },
                        "_score": 1.0,
                    }
                ],
            }
        }
        self.vector._client.search.return_value = mock_response

        hits_by_vector = self.vector.search_by_vector(query_vector=vector)

        # Debug output retained for easier diagnosis on failure.
        print("Hits by vector:", hits_by_vector)
        print("Expected document ID:", self.example_doc_id)
        print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits")

        assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}"
        assert hits_by_vector[0].metadata["document_id"] == self.example_doc_id, (
            f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
        )

    def test_get_ids_by_metadata_field(self):
        mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
        self.vector._client.search.return_value = mock_response

        doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
        embedding = [0.1] * 128

        # bulk is patched so add_texts performs no real indexing.
        with patch("opensearchpy.helpers.bulk") as mock_bulk:
            mock_bulk.return_value = ([], [])
            self.vector.add_texts([doc], [embedding])

        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1
        assert ids[0] == "mock_id"

    def test_add_texts(self):
        self.vector._client.index.return_value = {"result": "created"}

        doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
        embedding = [0.1] * 128

        with patch("opensearchpy.helpers.bulk") as mock_bulk:
            mock_bulk.return_value = ([], [])
            self.vector.add_texts([doc], [embedding])

        mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
        self.vector._client.search.return_value = mock_response

        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(ids) == 1
        assert ids[0] == "mock_id"

    def test_delete_nonexistent_index(self):
        """Test deleting a non-existent index."""
        # Create a vector instance with a non-existent collection name
        self.vector._client.indices.exists.return_value = False

        # Should not raise an exception
        self.vector.delete()

        # Verify that exists was called but delete was not
        self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower())
        self.vector._client.indices.delete.assert_not_called()

    def test_delete_existing_index(self):
        """Test deleting an existing index."""
        self.vector._client.indices.exists.return_value = True

        self.vector.delete()

        # Verify both exists and delete were called
        self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower())
        self.vector._client.indices.delete.assert_called_once_with(index=self.collection_name.lower())
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_mock_redis")
class TestOpenSearchVectorWithRedis:
    """Re-runs a subset of TestOpenSearchVector with the Redis mock active.

    Delegates to a manually constructed TestOpenSearchVector instance;
    setup_method is invoked explicitly because pytest does not run it on a
    hand-built instance.
    """

    def setup_method(self):
        self.tester = TestOpenSearchVector()

    def test_search_by_full_text(self):
        self.tester.setup_method()
        search_response = {
            "hits": {
                "total": {"value": 1},
                "hits": [
                    {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}}
                ],
            }
        }
        expected_length = 1
        expected_doc_id = "example_doc_id"
        self.tester.test_search_by_full_text(search_response, expected_length, expected_doc_id)

    def test_get_ids_by_metadata_field(self):
        self.tester.setup_method()
        self.tester.test_get_ids_by_metadata_field()

    def test_add_texts(self):
        self.tester.setup_method()
        self.tester.test_add_texts()

    def test_search_by_vector(self):
        self.tester.setup_method()
        self.tester.test_search_by_vector()
|
||||
@@ -0,0 +1,28 @@
|
||||
from core.rag.datasource.vdb.oracle.oraclevector import OracleVector, OracleVectorConfig
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class OracleVectorTest(AbstractVectorTest):
    """Integration tests for the Oracle-backed vector store."""

    def __init__(self):
        super().__init__()
        oracle_config = OracleVectorConfig(
            user="dify",
            password="dify",
            dsn="localhost:1521/FREEPDB1",
        )
        self.vector = OracleVector(collection_name=self.collection_name, config=oracle_config)

    def search_by_full_text(self):
        # Full-text search is expected to find nothing for this fixture on Oracle.
        results: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert not results
||||
def test_oraclevector(setup_mock_redis):
    """Run the full shared vector-store scenario against Oracle."""
    OracleVectorTest().run_all_tests()
||||
@@ -0,0 +1,35 @@
|
||||
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class PGVectoRSVectorTest(AbstractVectorTest):
    """Integration tests for the pgvecto.rs-backed vector store."""

    def __init__(self):
        super().__init__()
        rs_config = PgvectoRSConfig(
            host="localhost",
            port=5431,
            user="postgres",
            password="difyai123456",
            database="dify",
        )
        self.vector = PGVectoRS(
            collection_name=self.collection_name.lower(),
            config=rs_config,
            dim=128,
        )

    def search_by_full_text(self):
        # pgvecto.rs only supports English text search, so it is not enabled for now.
        results = self.vector.search_by_full_text(query=get_example_text())
        assert not results

    def get_ids_by_metadata_field(self):
        # Overrides the base expectation of NotImplementedError: this store
        # supports metadata lookups.
        matched_ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(matched_ids) == 1
||||
def test_pgvecto_rs(setup_mock_redis):
    """Run the full shared vector-store scenario against pgvecto.rs."""
    PGVectoRSVectorTest().run_all_tests()
||||
@@ -0,0 +1,27 @@
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class PGVectorTest(AbstractVectorTest):
    """Integration tests for the PostgreSQL pgvector-backed vector store."""

    def __init__(self):
        super().__init__()
        pg_config = PGVectorConfig(
            host="localhost",
            port=5433,
            user="postgres",
            password="difyai123456",
            database="dify",
            min_connection=1,
            max_connection=5,
        )
        self.vector = PGVector(collection_name=self.collection_name, config=pg_config)
||||
def test_pgvector(setup_mock_redis):
    """Run the full shared vector-store scenario against pgvector."""
    PGVectorTest().run_all_tests()
||||
@@ -0,0 +1,26 @@
|
||||
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVector, VastbaseVectorConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class VastbaseVectorTest(AbstractVectorTest):
    """Integration tests for the Vastbase-backed vector store."""

    def __init__(self):
        super().__init__()
        vastbase_config = VastbaseVectorConfig(
            host="localhost",
            port=5434,
            user="dify",
            password="Difyai123456",
            database="dify",
            min_connection=1,
            max_connection=5,
        )
        self.vector = VastbaseVector(collection_name=self.collection_name, config=vastbase_config)
||||
def test_vastbase_vector(setup_mock_redis):
    """Run the full shared vector-store scenario against Vastbase."""
    VastbaseVectorTest().run_all_tests()
||||
32
dify/api/tests/integration_tests/vdb/qdrant/test_qdrant.py
Normal file
32
dify/api/tests/integration_tests/vdb/qdrant/test_qdrant.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class QdrantVectorTest(AbstractVectorTest):
    """Integration tests for the Qdrant-backed vector store."""

    def __init__(self):
        super().__init__()
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
        qdrant_config = QdrantConfig(
            endpoint="http://localhost:6333",
            api_key="difyai123456",
        )
        self.vector = QdrantVector(
            collection_name=self.collection_name,
            group_id=self.dataset_id,
            config=qdrant_config,
        )

    def search_by_vector(self):
        super().search_by_vector()
        # Qdrant-specific check (may not work on other vector stores): a
        # score threshold of 1 should filter out every candidate.
        filtered: list[Document] = self.vector.search_by_vector(
            query_vector=self.example_embedding, score_threshold=1
        )
        assert not filtered
||||
def test_qdrant_vector(setup_mock_redis):
    """Run the full shared vector-store scenario against Qdrant."""
    QdrantVectorTest().run_all_tests()
||||
@@ -0,0 +1,100 @@
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import tablestore
|
||||
from _pytest.python_api import approx
|
||||
|
||||
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
|
||||
TableStoreConfig,
|
||||
TableStoreVector,
|
||||
)
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_document,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class TableStoreVectorTest(AbstractVectorTest):
    """Integration tests for the Alibaba Cloud TableStore vector store.

    Connection settings are read from TABLESTORE_* environment variables.
    The ``normalize_full_text_score`` flag toggles BM25 score normalization
    for full-text search and changes which assertions apply.
    """

    def __init__(self, normalize_full_text_score: bool = False):
        super().__init__()
        self.vector = TableStoreVector(
            collection_name=self.collection_name,
            config=TableStoreConfig(
                endpoint=os.getenv("TABLESTORE_ENDPOINT"),
                instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"),
                access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"),
                access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"),
                normalize_full_text_bm25_score=normalize_full_text_score,
            ),
        )

    def get_ids_by_metadata_field(self):
        # Overrides the base expectation of NotImplementedError: TableStore
        # supports metadata lookups, and the id equals the example doc id.
        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert ids is not None
        assert len(ids) == 1
        assert ids[0] == self.example_doc_id

    def create_vector(self):
        self.vector.create(
            texts=[get_example_document(doc_id=self.example_doc_id)],
            embeddings=[self.example_embedding],
        )
        # Poll until the search index has caught up with the write.
        # NOTE(review): no sleep between polls and no timeout — this spins hot
        # and can hang forever if indexing never completes; consider bounding it.
        while True:
            search_response = self.vector._tablestore_client.search(
                table_name=self.vector._table_name,
                index_name=self.vector._index_name,
                search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0),
                columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
            )
            if search_response.total_count == 1:
                break

    def search_by_vector(self):
        super().search_by_vector()
        # Filtering by the known document id must return exactly that document.
        docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id])
        assert len(docs) == 1
        assert docs[0].metadata["doc_id"] == self.example_doc_id
        assert docs[0].metadata["score"] > 0

        # A random (non-existent) document id filter must return nothing.
        docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())])
        assert len(docs) == 0

    def search_by_full_text(self):
        super().search_by_full_text()
        docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
        assert len(docs) == 1
        assert docs[0].metadata["doc_id"] == self.example_doc_id
        if self.vector._config.normalize_full_text_bm25_score:
            # Normalized BM25 score for this fixture is a known constant.
            assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3)
        else:
            # Without normalization, full-text hits carry no score at all.
            assert docs[0].metadata.get("score") is None

        # return none if normalize_full_text_score=true and score_threshold > 0
        docs = self.vector.search_by_full_text(
            get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5
        )
        if self.vector._config.normalize_full_text_bm25_score:
            assert len(docs) == 0
        else:
            # The threshold is ignored when scores are not normalized.
            assert len(docs) == 1
            assert docs[0].metadata["doc_id"] == self.example_doc_id
            assert docs[0].metadata.get("score") is None

        docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
        assert len(docs) == 0

    def run_all_tests(self):
        # Start from a clean slate; deletion may fail if the table does not
        # exist yet, which is fine (best-effort cleanup).
        try:
            self.vector.delete()
        except Exception:
            pass

        return super().run_all_tests()
||||
def test_tablestore_vector(setup_mock_redis):
    """Exercise TableStoreVector with BM25 normalization disabled and enabled.

    The constructor defaults to ``normalize_full_text_score=False``, so the
    previous third run with an explicit ``False`` duplicated the first run;
    one run per configuration is sufficient.
    """
    TableStoreVectorTest(normalize_full_text_score=False).run_all_tests()
    TableStoreVectorTest(normalize_full_text_score=True).run_all_tests()
||||
@@ -0,0 +1,38 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector
|
||||
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
# Shared mock standing in for a Tencent VectorDB client; listing databases
# yields one pre-existing database named "test".
# NOTE(review): mock_client is not referenced in this module's visible code —
# presumably consumed by the tcvectordb mock fixture; verify before removing.
mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]
|
||||
class TencentVectorTest(AbstractVectorTest):
    """Integration tests for the Tencent Cloud VectorDB-backed vector store."""

    def __init__(self):
        super().__init__()
        tencent_config = TencentConfig(
            url="http://127.0.0.1",
            api_key="dify",
            timeout=30,
            username="dify",
            database="dify",
            shard=1,
            replicas=2,
            enable_hybrid_search=True,
        )
        self.vector = TencentVector("dify", tencent_config)

    def search_by_vector(self):
        matches = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(matches) == 1

    def search_by_full_text(self):
        # Smoke check only: any number of hits (including zero) is acceptable.
        matches = self.vector.search_by_full_text(query=get_example_text())
        assert len(matches) >= 0
|
||||
def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
    """Run the full shared vector-store scenario against mocked Tencent VectorDB."""
    TencentVectorTest().run_all_tests()
||||
95
dify/api/tests/integration_tests/vdb/test_vector_store.py
Normal file
95
dify/api/tests/integration_tests/vdb/test_vector_store.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from extensions import ext_redis
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
def get_example_text() -> str:
    """Return the canonical sample text shared by all vector-store tests."""
    example = "test_text"
    return example
||||
|
||||
def get_example_document(doc_id: str) -> Document:
    """Build a sample Document whose metadata fields all carry *doc_id*."""
    metadata = {key: doc_id for key in ("doc_id", "doc_hash", "document_id", "dataset_id")}
    return Document(page_content=get_example_text(), metadata=metadata)
|
||||
@pytest.fixture
def setup_mock_redis():
    """Stub out the shared Redis client so vector tests need no real Redis.

    Reads always miss, writes are accepted and discarded, and locks become
    no-op context managers. Mutates the module-level ext_redis client in
    place for the duration of the test session.
    """
    # get: every cache lookup misses
    ext_redis.redis_client.get = MagicMock(return_value=None)

    # set: writes succeed silently
    ext_redis.redis_client.set = MagicMock(return_value=None)

    # lock: usable with `with`, acquires and releases without side effects
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = mock_redis_lock
|
||||
class AbstractVectorTest:
    """Shared scenario driver for vector-store integration tests.

    Subclasses assign a concrete vector store to ``self.vector`` in their
    constructor and may override individual steps; ``run_all_tests`` then
    executes the complete create/search/delete scenario.
    """

    def __init__(self):
        self.vector = None
        self.dataset_id = str(uuid.uuid4())
        self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
        self.example_doc_id = str(uuid.uuid4())
        self.example_embedding = [i * 1.001 for i in range(128)]

    def create_vector(self):
        """Insert the single example document with its embedding."""
        self.vector.create(
            texts=[get_example_document(doc_id=self.example_doc_id)],
            embeddings=[self.example_embedding],
        )

    def search_by_vector(self):
        matches: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(matches) == 1
        assert matches[0].metadata["doc_id"] == self.example_doc_id

    def search_by_full_text(self):
        matches: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(matches) == 1
        assert matches[0].metadata["doc_id"] == self.example_doc_id

    def delete_vector(self):
        self.vector.delete()

    def delete_by_ids(self, ids: list[str]):
        self.vector.delete_by_ids(ids=ids)

    def add_texts(self) -> list[str]:
        """Insert a batch of fresh documents and return their doc ids."""
        batch_size = 100
        docs = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
        self.vector.add_texts(documents=docs, embeddings=[self.example_embedding] * batch_size)
        return [doc.metadata["doc_id"] for doc in docs]

    def text_exists(self):
        assert self.vector.text_exists(self.example_doc_id)

    def get_ids_by_metadata_field(self):
        # Stores without metadata-lookup support are expected to raise;
        # subclasses that do support it override this method.
        with pytest.raises(NotImplementedError):
            self.vector.get_ids_by_metadata_field(key="key", value="value")

    def run_all_tests(self):
        """Execute the full scenario in order: create, query, batch add, clean up."""
        self.create_vector()
        self.search_by_vector()
        self.search_by_full_text()
        self.text_exists()
        self.get_ids_by_metadata_field()
        added_doc_ids = self.add_texts()
        self.delete_by_ids(added_doc_ids)
        self.delete_vector()
||||
@@ -0,0 +1,59 @@
|
||||
import time
|
||||
|
||||
import pymysql
|
||||
|
||||
|
||||
def check_tiflash_ready() -> bool:
    """Return True if the local TiDB cluster reports at least one TiFlash node.

    Connects to TiDB on localhost:4000 and queries
    information_schema.cluster_hardware for a row with TYPE='tiflash'.
    Any failure (connection refused, query error) yields False.
    """
    # Initialize before the try so the finally block is safe even when
    # pymysql.connect() itself raises; previously `connection` could be
    # unbound there, turning the real error into a NameError.
    connection = None
    try:
        connection = pymysql.connect(
            host="localhost",
            port=4000,
            user="root",
            password="",
        )

        with connection.cursor() as cursor:
            # Doc reference:
            # https://docs.pingcap.com/zh/tidb/stable/information-schema-cluster-hardware
            select_tiflash_query = """
                SELECT * FROM information_schema.cluster_hardware
                WHERE TYPE='tiflash'
                LIMIT 1;
            """
            cursor.execute(select_tiflash_query)
            result = cursor.fetchall()
            return result is not None and len(result) > 0
    except Exception as e:
        print(f"TiFlash is not ready. Exception: {e}")
        return False
    finally:
        if connection is not None:
            connection.close()
||||
|
||||
def main():
    """Poll TiDB until TiFlash reports ready, or exit(1) after max_attempts."""
    max_attempts = 30
    retry_interval_seconds = 2

    for attempt in range(1, max_attempts + 1):
        try:
            ready = check_tiflash_ready()
        except Exception as e:
            print(f"TiFlash is not ready. Exception: {e}")
            ready = False

        if ready:
            print("TiFlash is ready in TiDB.")
            return

        # Back off before the next probe.
        print(f"Attempt {attempt} failed, retry in {retry_interval_seconds} seconds...")
        time.sleep(retry_interval_seconds)

    print(f"TiFlash is not ready in TiDB after {max_attempts} attempting checks.")
    exit(1)
||||
|
||||
# Script entry point: exits 0 when TiFlash is ready, 1 otherwise.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,40 @@
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
|
||||
from models.dataset import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
@pytest.fixture
def tidb_vector():
    """Provide a TiDBVector wired to the local test TiDB instance."""
    config = TiDBVectorConfig(
        host="localhost",
        port=4000,
        user="root",
        password="",
        database="test",
        program_name="langgenius/dify",
    )
    return TiDBVector(collection_name="test_collection", config=config)
|
||||
class TiDBVectorTest(AbstractVectorTest):
    """Integration tests for the TiDB-backed vector store.

    Unlike the other suites, the store is injected so the pytest fixture can
    own its construction.
    """

    def __init__(self, vector):
        super().__init__()
        self.vector = vector

    def search_by_full_text(self):
        # Full-text search is expected to find nothing for this fixture on TiDB.
        results: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert not results

    def get_ids_by_metadata_field(self):
        # Overrides the base expectation of NotImplementedError: this store
        # supports metadata lookups.
        matched = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert len(matched) == 1
|
||||
def test_tidb_vector(setup_mock_redis, tidb_vector):
    """TiDB vector-store integration test (currently disabled).

    Skipping explicitly (instead of silently returning) makes the disabled
    state visible in the pytest report rather than counting as a pass.
    """
    pytest.skip("TiDB vector store test disabled: backend currently misbehaving")
    TiDBVectorTest(vector=tidb_vector).run_all_tests()
||||
@@ -0,0 +1,28 @@
|
||||
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVector, UpstashVectorConfig
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.vdb.__mock.upstashvectordb import setup_upstashvector_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
|
||||
|
||||
|
||||
class UpstashVectorTest(AbstractVectorTest):
    """Integration tests for the Upstash-backed vector store (mocked client)."""

    def __init__(self):
        super().__init__()
        upstash_config = UpstashVectorConfig(
            url="your-server-url",
            token="your-access-token",
        )
        self.vector = UpstashVector(collection_name="test_collection", config=upstash_config)

    def get_ids_by_metadata_field(self):
        # Overrides the base expectation of NotImplementedError: this store
        # supports metadata lookups.
        matched = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(matched) != 0

    def search_by_full_text(self):
        # Full-text search is expected to find nothing for this fixture on Upstash.
        results: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert not results
|
||||
def test_upstash_vector(setup_upstashvector_mock):
    """Run the full shared vector-store scenario against mocked Upstash."""
    UpstashVectorTest().run_all_tests()
||||
@@ -0,0 +1,37 @@
|
||||
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector
|
||||
from tests.integration_tests.vdb.__mock.vikingdb import setup_vikingdb_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class VikingDBVectorTest(AbstractVectorTest):
    """Integration tests for the VikingDB-backed vector store (mocked client)."""

    def __init__(self):
        super().__init__()
        vikingdb_config = VikingDBConfig(
            access_key="test_access_key",
            host="test_host",
            region="test_region",
            scheme="test_scheme",
            secret_key="test_secret_key",
            connection_timeout=30,
            socket_timeout=30,
        )
        self.vector = VikingDBVector("test_collection", "test_group", config=vikingdb_config)

    def search_by_vector(self):
        matches = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(matches) == 1

    def search_by_full_text(self):
        # Full-text search is expected to find nothing on VikingDB.
        results = self.vector.search_by_full_text(query=get_example_text())
        assert not results

    def get_ids_by_metadata_field(self):
        # Overrides the base expectation of NotImplementedError: this store
        # supports metadata lookups.
        matched = self.vector.get_ids_by_metadata_field(key="document_id", value="test_document_id")
        assert len(matched) > 0
|
||||
def test_vikingdb_vector(setup_mock_redis, setup_vikingdb_mock):
    """Run the full shared vector-store scenario against mocked VikingDB."""
    VikingDBVectorTest().run_all_tests()
||||
@@ -0,0 +1,23 @@
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class WeaviateVectorTest(AbstractVectorTest):
    """Integration tests for the Weaviate-backed vector store."""

    def __init__(self):
        super().__init__()
        self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
        weaviate_config = WeaviateConfig(
            endpoint="http://localhost:8080",
            api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih",
        )
        self.vector = WeaviateVector(
            collection_name=self.collection_name,
            config=weaviate_config,
            attributes=self.attributes,
        )
|
||||
def test_weaviate_vector(setup_mock_redis):
    """Run the full shared vector-store scenario against Weaviate."""
    WeaviateVectorTest().run_all_tests()
|
||||
Reference in New Issue
Block a user