dify
166
dify/api/tests/integration_tests/vdb/__mock/baiduvectordb.py
Normal file
@@ -0,0 +1,166 @@
import os
from collections import UserDict
from typing import Any
from unittest.mock import MagicMock

import pytest
from _pytest.monkeypatch import MonkeyPatch
from pymochow import MochowClient
from pymochow.model.database import Database
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
from pymochow.model.schema import HNSWParams, VectorIndex
from pymochow.model.table import Table


class AttrDict(UserDict):
    def __getattr__(self, item):
        return self.get(item)


class MockBaiduVectorDBClass:
    def mock_vector_db_client(
        self,
        config=None,
        adapter: Any | None = None,
    ):
        self.conn = MagicMock()
        self._config = MagicMock()

    def list_databases(self, config=None) -> list[Database]:
        return [
            Database(
                conn=self.conn,
                database_name="dify",
                config=self._config,
            )
        ]

    def create_database(self, database_name: str, config=None) -> Database:
        return Database(conn=self.conn, database_name=database_name, config=config)

    def list_table(self, config=None) -> list[Table]:
        return []

    def drop_table(self, table_name: str, config=None):
        return {"code": 0, "msg": "Success"}

    def create_table(
        self,
        table_name: str,
        replication: int,
        partition: int,
        schema,
        enable_dynamic_field=False,
        description: str = "",
        config=None,
    ) -> Table:
        return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)

    def describe_table(self, table_name: str, config=None) -> Table:
        return Table(
            self,
            table_name,
            3,
            1,
            None,
            enable_dynamic_field=False,
            description="table for dify",
            config=config,
            state=TableState.NORMAL,
        )

    def upsert(self, rows, config=None):
        return {"code": 0, "msg": "operation success", "affectedCount": 1}

    def rebuild_index(self, index_name: str, config=None):
        return {"code": 0, "msg": "Success"}

    def describe_index(self, index_name: str, config=None):
        return VectorIndex(
            index_name=index_name,
            index_type=IndexType.HNSW,
            field="vector",
            metric_type=MetricType.L2,
            params=HNSWParams(m=16, efconstruction=200),
            auto_build=False,
            state=IndexState.NORMAL,
        )

    def query(
        self,
        primary_key,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        return AttrDict(
            {
                "row": {
                    "id": primary_key.get("id"),
                    "vector": [0.23432432, 0.8923744, 0.89238432],
                    "page_content": "text",
                    "metadata": {"doc_id": "doc_id_001"},
                },
                "code": 0,
                "msg": "Success",
            }
        )

    def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
        return {"code": 0, "msg": "Success"}

    def search(
        self,
        anns,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        return AttrDict(
            {
                "rows": [
                    {
                        "row": {
                            "id": "doc_id_001",
                            "vector": [0.23432432, 0.8923744, 0.89238432],
                            "page_content": "text",
                            "metadata": {"doc_id": "doc_id_001"},
                        },
                        "distance": 0.1,
                        "score": 0.5,
                    }
                ],
                "code": 0,
                "msg": "Success",
            }
        )


MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client)
        monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases)
        monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database)
        monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table)
        monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table)
        monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table)
        monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table)
        monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table)
        monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
        monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
        monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
        monkeypatch.setattr(Table, "query", MockBaiduVectorDBClass.query)
        monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)

    yield

    if MOCK:
        monkeypatch.undo()
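The fixtures in this commit all follow the same pattern: patching happens only when MOCK_SWITCH=true, and monkeypatch.undo() restores the real client afterwards. A minimal usage sketch (the test name and assertion are illustrative, not part of this commit):

from pymochow import MochowClient


def test_list_databases_is_mocked(setup_baiduvectordb_mock):
    # Requires MOCK_SWITCH=true; the patched __init__ accepts config=None
    # and never opens a real connection.
    client = MochowClient(config=None)
    assert len(client.list_databases()) == 1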
@@ -0,0 +1,89 @@
import os

import pytest
from _pytest.monkeypatch import MonkeyPatch
from elasticsearch import Elasticsearch

from core.rag.datasource.vdb.field import Field


class MockIndicesClient:
    def __init__(self):
        pass

    def create(self, index, mappings, settings):
        return {"acknowledge": True}

    def refresh(self, index):
        return {"acknowledge": True}

    def delete(self, index):
        return {"acknowledge": True}

    def exists(self, index):
        return True


class MockClient:
    def __init__(self, **kwargs):
        self.indices = MockIndicesClient()

    def index(self, **kwargs):
        return {"acknowledge": True}

    def exists(self, **kwargs):
        return True

    def delete(self, **kwargs):
        return {"acknowledge": True}

    def search(self, **kwargs):
        return {
            "took": 1,
            "hits": {
                "hits": [
                    {
                        "_source": {
                            Field.CONTENT_KEY: "abcdef",
                            Field.VECTOR: [1, 2],
                            Field.METADATA_KEY: {},
                        },
                        "_score": 1.0,
                    },
                    {
                        "_source": {
                            Field.CONTENT_KEY: "123456",
                            Field.VECTOR: [2, 2],
                            Field.METADATA_KEY: {},
                        },
                        "_score": 0.9,
                    },
                    {
                        "_source": {
                            Field.CONTENT_KEY: "a1b2c3",
                            Field.VECTOR: [3, 2],
                            Field.METADATA_KEY: {},
                        },
                        "_score": 0.8,
                    },
                ]
            },
        }


MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_client_mock(request, monkeypatch: MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(Elasticsearch, "__init__", MockClient.__init__)
        monkeypatch.setattr(Elasticsearch, "index", MockClient.index)
        monkeypatch.setattr(Elasticsearch, "exists", MockClient.exists)
        monkeypatch.setattr(Elasticsearch, "delete", MockClient.delete)
        monkeypatch.setattr(Elasticsearch, "search", MockClient.search)

    yield

    if MOCK:
        monkeypatch.undo()
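The same opt-in applies here; with MOCK_SWITCH=true a test receives the canned three-hit response instead of hitting a live cluster. A sketch (the index name is arbitrary, since the mocked search ignores its kwargs):

from elasticsearch import Elasticsearch


def test_search_returns_canned_hits(setup_client_mock):
    es = Elasticsearch()  # the mocked __init__ ignores connection kwargs
    hits = es.search(index="dify")["hits"]["hits"]
    assert [hit["_score"] for hit in hits] == [1.0, 0.9, 0.8]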
192
dify/api/tests/integration_tests/vdb/__mock/tcvectordb.py
Normal file
@@ -0,0 +1,192 @@
import os
from typing import Any, Union

import pytest
from _pytest.monkeypatch import MonkeyPatch
from tcvectordb import RPCVectorDBClient
from tcvectordb.model import enum
from tcvectordb.model.collection import FilterIndexConfig
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank
from tcvectordb.model.enum import ReadConsistency
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex
from tcvectordb.rpc.model.collection import RPCCollection
from tcvectordb.rpc.model.database import RPCDatabase
from xinference_client.types import Embedding


class MockTcvectordbClass:
    def mock_vector_db_client(
        self,
        url: str,
        username="",
        key="",
        read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
        timeout=10,
        adapter: Any | None = None,
        pool_size: int = 2,
        proxies: dict | None = None,
        password: str | None = None,
        **kwargs,
    ):
        self._conn = None
        self._read_consistency = read_consistency

    def create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase:
        return RPCDatabase(
            name="dify",
            read_consistency=self._read_consistency,
        )

    def exists_collection(self, database_name: str, collection_name: str) -> bool:
        return True

    def describe_collection(
        self, database_name: str, collection_name: str, timeout: float | None = None
    ) -> RPCCollection:
        index = Index(
            FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
            VectorIndex(
                "vector",
                128,
                enum.IndexType.HNSW,
                enum.MetricType.IP,
                HNSWParams(m=16, efconstruction=200),
            ),
            FilterIndex("text", enum.FieldType.String, enum.IndexType.FILTER),
            FilterIndex("metadata", enum.FieldType.String, enum.IndexType.FILTER),
        )
        return RPCCollection(
            RPCDatabase(
                name=database_name,
                read_consistency=self._read_consistency,
            ),
            collection_name,
            index=index,
        )

    def create_collection(
        self,
        database_name: str,
        collection_name: str,
        shard: int,
        replicas: int,
        description: str | None = None,
        index: Index | None = None,
        embedding: Embedding | None = None,
        timeout: float | None = None,
        ttl_config: dict | None = None,
        filter_index_config: FilterIndexConfig | None = None,
        indexes: list[IndexField] | None = None,
    ) -> RPCCollection:
        return RPCCollection(
            RPCDatabase(
                name="dify",
                read_consistency=self._read_consistency,
            ),
            collection_name,
            shard,
            replicas,
            description,
            index,
            embedding=embedding,
            read_consistency=self._read_consistency,
            timeout=timeout,
            ttl_config=ttl_config,
            filter_index_config=filter_index_config,
            indexes=indexes,
        )

    def collection_upsert(
        self,
        database_name: str,
        collection_name: str,
        documents: list[Union[Document, dict]],
        timeout: float | None = None,
        build_index: bool = True,
        **kwargs,
    ):
        return {"code": 0, "msg": "operation success"}

    def collection_search(
        self,
        database_name: str,
        collection_name: str,
        vectors: list[list[float]],
        filter: Filter | None = None,
        params=None,
        retrieve_vector: bool = False,
        limit: int = 10,
        output_fields: list[str] | None = None,
        timeout: float | None = None,
    ) -> list[list[dict]]:
        return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_hybrid_search(
        self,
        database_name: str,
        collection_name: str,
        ann: Union[list[AnnSearch], AnnSearch] | None = None,
        match: Union[list[KeywordSearch], KeywordSearch] | None = None,
        filter: Union[Filter, str] | None = None,
        rerank: Rerank | None = None,
        retrieve_vector: bool | None = None,
        output_fields: list[str] | None = None,
        limit: int | None = None,
        timeout: float | None = None,
        return_pd_object=False,
        **kwargs,
    ) -> list[list[dict]]:
        return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_query(
        self,
        database_name: str,
        collection_name: str,
        document_ids: list | None = None,
        retrieve_vector: bool = False,
        limit: int | None = None,
        offset: int | None = None,
        filter: Filter | None = None,
        output_fields: list[str] | None = None,
        timeout: float | None = None,
    ):
        return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]

    def collection_delete(
        self,
        database_name: str,
        collection_name: str,
        document_ids: list[str] | None = None,
        filter: Filter | None = None,
        timeout: float | None = None,
    ):
        return {"code": 0, "msg": "operation success"}

    def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None):
        return {"code": 0, "msg": "operation success"}


MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(RPCVectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client)
        monkeypatch.setattr(
            RPCVectorDBClient, "create_database_if_not_exists", MockTcvectordbClass.create_database_if_not_exists
        )
        monkeypatch.setattr(RPCVectorDBClient, "exists_collection", MockTcvectordbClass.exists_collection)
        monkeypatch.setattr(RPCVectorDBClient, "create_collection", MockTcvectordbClass.create_collection)
        monkeypatch.setattr(RPCVectorDBClient, "describe_collection", MockTcvectordbClass.describe_collection)
        monkeypatch.setattr(RPCVectorDBClient, "upsert", MockTcvectordbClass.collection_upsert)
        monkeypatch.setattr(RPCVectorDBClient, "search", MockTcvectordbClass.collection_search)
        monkeypatch.setattr(RPCVectorDBClient, "hybrid_search", MockTcvectordbClass.collection_hybrid_search)
        monkeypatch.setattr(RPCVectorDBClient, "query", MockTcvectordbClass.collection_query)
        monkeypatch.setattr(RPCVectorDBClient, "delete", MockTcvectordbClass.collection_delete)
        monkeypatch.setattr(RPCVectorDBClient, "drop_collection", MockTcvectordbClass.drop_collection)

    yield

    if MOCK:
        monkeypatch.undo()
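Note that the fixture maps the collection_* helpers onto the flat RPCVectorDBClient methods (upsert, search, query, and so on). A usage sketch under MOCK_SWITCH=true (the database and collection names are placeholders):

from tcvectordb import RPCVectorDBClient


def test_search_returns_stub_rows(setup_tcvectordb_mock):
    client = RPCVectorDBClient(url="http://127.0.0.1")  # credentials unused by the mock
    rows = client.search("dify", "dify_collection", vectors=[[0.1, 0.2, 0.3]])
    assert rows[0][0]["doc_id"] == "foo1"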
@@ -0,0 +1,75 @@
import os
from collections import UserDict

import pytest
from _pytest.monkeypatch import MonkeyPatch
from upstash_vector import Index


# Mocking the Index class from upstash_vector
class MockIndex:
    def __init__(self, url="", token=""):
        self.url = url
        self.token = token
        self.vectors = []

    def upsert(self, vectors):
        for vector in vectors:
            vector.score = 0.5
            self.vectors.append(vector)
        return {"code": 0, "msg": "operation success", "affectedCount": len(vectors)}

    def fetch(self, ids):
        return [vector for vector in self.vectors if vector.id in ids]

    def delete(self, ids):
        self.vectors = [vector for vector in self.vectors if vector.id not in ids]
        return {"code": 0, "msg": "Success"}

    def query(
        self,
        vector: list[float] | None = None,
        top_k: int = 10,
        include_vectors: bool = False,
        include_metadata: bool = False,
        filter: str = "",
        data: str | None = None,
        namespace: str = "",
        include_data: bool = False,
    ):
        # Simple mock query; a real implementation would rank by similarity.
        mock_result = []
        for vector_data in self.vectors:
            mock_result.append(vector_data)
        return mock_result[:top_k]

    def reset(self):
        self.vectors = []

    def info(self):
        return AttrDict({"dimension": 1024})


class AttrDict(UserDict):
    def __getattr__(self, item):
        return self.get(item)


MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_upstashvector_mock(request, monkeypatch: MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(Index, "__init__", MockIndex.__init__)
        monkeypatch.setattr(Index, "upsert", MockIndex.upsert)
        monkeypatch.setattr(Index, "fetch", MockIndex.fetch)
        monkeypatch.setattr(Index, "delete", MockIndex.delete)
        monkeypatch.setattr(Index, "query", MockIndex.query)
        monkeypatch.setattr(Index, "reset", MockIndex.reset)
        monkeypatch.setattr(Index, "info", MockIndex.info)

    yield

    if MOCK:
        monkeypatch.undo()
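Unlike the other mocks, MockIndex is stateful: upsert stores vectors on the instance, so fetch, delete, and query operate on what the test previously wrote. A sketch under MOCK_SWITCH=true (assumes the upstash_vector SDK's mutable Vector type, which is not part of this commit):

from upstash_vector import Index, Vector


def test_upsert_then_fetch(setup_upstashvector_mock):
    index = Index(url="http://127.0.0.1", token="dummy")  # stored by the mocked __init__
    index.upsert([Vector(id="v1", vector=[0.1, 0.2])])
    assert index.fetch(["v1"])[0].id == "v1"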
215
dify/api/tests/integration_tests/vdb/__mock/vikingdb.py
Normal file
@@ -0,0 +1,215 @@
import os
from typing import Union
from unittest.mock import MagicMock

import pytest
from _pytest.monkeypatch import MonkeyPatch
from volcengine.viking_db import (
    Collection,
    Data,
    DistanceType,
    Field,
    FieldType,
    Index,
    IndexType,
    QuantType,
    VectorIndexParams,
    VikingDBService,
)

from core.rag.datasource.vdb.field import Field as vdb_Field


class MockVikingDBClass:
    def __init__(
        self,
        host="api-vikingdb.volces.com",
        region="cn-north-1",
        ak="",
        sk="",
        scheme="http",
        connection_timeout=30,
        socket_timeout=30,
        proxy=None,
    ):
        self._viking_db_service = MagicMock()
        self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')

    def get_collection(self, collection_name) -> Collection:
        return Collection(
            collection_name=collection_name,
            description="Collection For Dify",
            viking_db_service=self._viking_db_service,
            primary_key=vdb_Field.PRIMARY_KEY,
            fields=[
                Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
                Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
                Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
                Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
                Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=768),
            ],
            indexes=[
                Index(
                    collection_name=collection_name,
                    index_name=f"{collection_name}_idx",
                    vector_index=VectorIndexParams(
                        distance=DistanceType.L2,
                        index_type=IndexType.HNSW,
                        quant=QuantType.Float,
                    ),
                    scalar_index=None,
                    stat=None,
                    viking_db_service=self._viking_db_service,
                )
            ],
        )

    def drop_collection(self, collection_name):
        assert collection_name != ""

    def create_collection(self, collection_name, fields, description="") -> Collection:
        return Collection(
            collection_name=collection_name,
            description=description,
            primary_key=vdb_Field.PRIMARY_KEY,
            viking_db_service=self._viking_db_service,
            fields=fields,
        )

    def get_index(self, collection_name, index_name) -> Index:
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            viking_db_service=self._viking_db_service,
            stat=None,
            scalar_index=None,
            vector_index=VectorIndexParams(
                distance=DistanceType.L2,
                index_type=IndexType.HNSW,
                quant=QuantType.Float,
            ),
        )

    def create_index(
        self,
        collection_name,
        index_name,
        vector_index=None,
        cpu_quota=2,
        description="",
        partition_by="",
        scalar_index=None,
        shard_count=None,
        shard_policy=None,
    ):
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            vector_index=vector_index,
            cpu_quota=cpu_quota,
            description=description,
            partition_by=partition_by,
            scalar_index=scalar_index,
            shard_count=shard_count,
            shard_policy=shard_policy,
            viking_db_service=self._viking_db_service,
            stat=None,
        )

    def drop_index(self, collection_name, index_name):
        assert collection_name != ""
        assert index_name != ""

    def upsert_data(self, data: Union[Data, list[Data]]):
        assert data is not None

    def fetch_data(self, id: Union[str, list[str], int, list[int]]):
        return Data(
            fields={
                vdb_Field.GROUP_KEY: "test_group",
                vdb_Field.METADATA_KEY: "{}",
                vdb_Field.CONTENT_KEY: "content",
                vdb_Field.PRIMARY_KEY: id,
                vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
            },
            id=id,
        )

    def delete_data(self, id: Union[str, list[str], int, list[int]]):
        assert id is not None

    def search_by_vector(
        self,
        vector,
        sparse_vectors=None,
        filter=None,
        limit=10,
        output_fields=None,
        partition="default",
        dense_weight=None,
    ) -> list[Data]:
        return [
            Data(
                fields={
                    vdb_Field.GROUP_KEY: "test_group",
                    vdb_Field.METADATA_KEY: '\
                        {"source": "/var/folders/ml/xxx/xxx.txt", \
                        "document_id": "test_document_id", \
                        "dataset_id": "test_dataset_id", \
                        "doc_id": "test_id", \
                        "doc_hash": "test_hash"}',
                    vdb_Field.CONTENT_KEY: "content",
                    vdb_Field.PRIMARY_KEY: "test_id",
                    vdb_Field.VECTOR: vector,
                },
                id="test_id",
                score=0.10,
            )
        ]

    def search(
        self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
    ) -> list[Data]:
        return [
            Data(
                fields={
                    vdb_Field.GROUP_KEY: "test_group",
                    vdb_Field.METADATA_KEY: '\
                        {"source": "/var/folders/ml/xxx/xxx.txt", \
                        "document_id": "test_document_id", \
                        "dataset_id": "test_dataset_id", \
                        "doc_id": "test_id", \
                        "doc_hash": "test_hash"}',
                    vdb_Field.CONTENT_KEY: "content",
                    vdb_Field.PRIMARY_KEY: "test_id",
                    vdb_Field.VECTOR: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
                },
                id="test_id",
                score=0.10,
            )
        ]


MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(VikingDBService, "__init__", MockVikingDBClass.__init__)
        monkeypatch.setattr(VikingDBService, "get_collection", MockVikingDBClass.get_collection)
        monkeypatch.setattr(VikingDBService, "create_collection", MockVikingDBClass.create_collection)
        monkeypatch.setattr(VikingDBService, "drop_collection", MockVikingDBClass.drop_collection)
        monkeypatch.setattr(VikingDBService, "get_index", MockVikingDBClass.get_index)
        monkeypatch.setattr(VikingDBService, "create_index", MockVikingDBClass.create_index)
        monkeypatch.setattr(VikingDBService, "drop_index", MockVikingDBClass.drop_index)
        monkeypatch.setattr(Collection, "upsert_data", MockVikingDBClass.upsert_data)
        monkeypatch.setattr(Collection, "fetch_data", MockVikingDBClass.fetch_data)
        monkeypatch.setattr(Collection, "delete_data", MockVikingDBClass.delete_data)
        monkeypatch.setattr(Index, "search_by_vector", MockVikingDBClass.search_by_vector)
        monkeypatch.setattr(Index, "search", MockVikingDBClass.search)

    yield

    if MOCK:
        monkeypatch.undo()
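A closing sketch for the VikingDB fixture: with MOCK_SWITCH=true, VikingDBService is constructed without credentials and get_collection returns the canned schema above. The asserted attribute follows the volcengine SDK's Collection constructor arguments and is an assumption, not something this commit defines:

from volcengine.viking_db import VikingDBService


def test_get_collection_is_mocked(setup_vikingdb_mock):
    service = VikingDBService()  # mocked __init__ supplies default host/region
    collection = service.get_collection("dify_collection")
    assert collection.collection_name == "dify_collection"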