commit fab8c13cb3
parent 32fee2b8ab
2025-12-01 17:21:38 +08:00
7511 changed files with 996300 additions and 0 deletions

View File

@@ -0,0 +1,262 @@
from collections import defaultdict
from typing import Any
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
class Jieba(BaseKeyword):
def __init__(self, dataset: Dataset):
super().__init__(dataset)
self._config = KeywordTableConfig()
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, text.metadata["doc_id"], list(keywords)
)
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
keywords_list = kwargs.get("keywords_list")
keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
for i in range(len(texts)):
text = texts[i]
if keywords_list:
keywords = keywords_list[i]
if not keywords:
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
else:
keywords = keyword_table_handler.extract_keywords(text.page_content, keyword_number)
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, text.metadata["doc_id"], list(keywords)
)
self._save_dataset_keyword_table(keyword_table)
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
        # An empty table would make set.union(*...) below fail, so treat it like a missing one.
        if not keyword_table:
            return False
return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
if keyword_table is not None:
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def search(self, query: str, **kwargs: Any) -> list[Document]:
keyword_table = self._get_dataset_keyword_table()
k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment_query = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()
if segment:
documents.append(
Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
)
return documents
def delete(self):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
            dataset_keyword_table = self.dataset.dataset_keyword_table
            if dataset_keyword_table:
                # Read the storage type before deleting; the ORM object may be expired after commit.
                data_source_type = dataset_keyword_table.data_source_type
                db.session.delete(dataset_keyword_table)
                db.session.commit()
                if data_source_type != "database":
                    file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
                    storage.delete(file_key)
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
}
dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit()
else:
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
def _get_dataset_keyword_table(self) -> dict | None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return dict(keyword_table_dict["__data__"]["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table="",
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = dumps_with_sets(
{
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
}
)
db.session.add(dataset_keyword_table)
db.session.commit()
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]):
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]):
        # collect the node ids to delete
        node_idxs_to_delete = set(ids)
        # remove those node ids from the keyword -> node ids mapping
keywords_to_delete = set()
for keyword, node_idxs in keyword_table.items():
if node_idxs_to_delete.intersection(node_idxs):
keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete)
if not keyword_table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del keyword_table[keyword]
return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
keyword_table_handler = JiebaKeywordTableHandler()
keywords = keyword_table_handler.extract_keywords(query)
# go through text chunks in order of most matching keywords
chunk_indices_count: dict[str, int] = defaultdict(int)
keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords_list:
for node_id in keyword_table[keyword]:
chunk_indices_count[node_id] += 1
sorted_chunk_indices = sorted(
chunk_indices_count.keys(),
key=lambda x: chunk_indices_count[x],
reverse=True,
)
return sorted_chunk_indices[:k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
)
document_segment = db.session.scalar(stmt)
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)
db.session.commit()
def create_segment_keywords(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(self.dataset.id, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
def multi_create_segment_keywords(self, pre_segment_data_list: list):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for pre_segment_data in pre_segment_data_list:
segment = pre_segment_data["segment"]
if pre_segment_data["keywords"]:
segment.keywords = pre_segment_data["keywords"]
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keyword_number = self.dataset.keyword_number or self._config.max_keywords_per_chunk
keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, segment.index_node_id, list(keywords)
)
self._save_dataset_keyword_table(keyword_table)
def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
def set_orjson_default(obj: Any):
"""Default function for orjson serialization of set types"""
if isinstance(obj, set):
return list(obj)
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def dumps_with_sets(obj: Any) -> str:
"""JSON dumps with set support using orjson"""
return orjson.dumps(obj, default=set_orjson_default).decode("utf-8")
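
As a quick illustration of what this module persists, here is a hedged, minimal sketch of the inverted index it maintains (keyword -> set of segment doc_ids) round-tripped through the set-aware orjson helpers defined above; load_keyword_table and the sample ids are illustrative assumptions, not part of the module.

import orjson


def load_keyword_table(raw: str) -> dict[str, set[str]]:
    # Hypothetical inverse of dumps_with_sets: JSON turned each set into a list,
    # so re-wrap every posting list as a set.
    data = orjson.loads(raw)
    return {keyword: set(doc_ids) for keyword, doc_ids in data["__data__"]["table"].items()}


if __name__ == "__main__":
    # keyword -> set of index_node_id values (the "doc_id" in segment metadata)
    table: dict[str, set[str]] = {}
    for keyword, doc_id in [("向量", "seg-1"), ("向量", "seg-2"), ("检索", "seg-2")]:
        table.setdefault(keyword, set()).add(doc_id)

    payload = {
        "__type__": "keyword_table",
        "__data__": {"index_id": "dataset-1", "summary": None, "table": table},
    }
    raw = dumps_with_sets(payload)  # helper defined above
    assert load_keyword_table(raw)["向量"] == {"seg-1", "seg-2"}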

View File

@@ -0,0 +1,127 @@
import re
from operator import itemgetter
from typing import cast
class JiebaKeywordTableHandler:
def __init__(self):
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
tfidf = self._load_tfidf_extractor()
tfidf.stop_words = STOPWORDS # type: ignore[attr-defined]
self._tfidf = tfidf
def _load_tfidf_extractor(self):
"""
Load jieba TFIDF extractor with fallback strategy.
        Loading flow:
            1. Return jieba.analyse.default_tfidf if it already exists.
            2. Otherwise, if jieba.analyse.TFIDF exists, instantiate it, cache it as
               jieba.analyse.default_tfidf, and return it.
            3. Otherwise, try importing TFIDF from jieba.analyse.tfidf and do the same.
            4. If all of the above fail, build the lightweight _SimpleTFIDF fallback.
        """
import jieba.analyse # type: ignore
tfidf = getattr(jieba.analyse, "default_tfidf", None)
if tfidf is not None:
return tfidf
tfidf_class = getattr(jieba.analyse, "TFIDF", None)
if tfidf_class is None:
try:
from jieba.analyse.tfidf import TFIDF # type: ignore
tfidf_class = TFIDF
except Exception:
tfidf_class = None
if tfidf_class is not None:
tfidf = tfidf_class()
jieba.analyse.default_tfidf = tfidf # type: ignore[attr-defined]
return tfidf
return self._build_fallback_tfidf()
@staticmethod
def _build_fallback_tfidf():
"""Fallback lightweight TFIDF for environments missing jieba's TFIDF."""
import jieba # type: ignore
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
class _SimpleTFIDF:
def __init__(self):
self.stop_words = STOPWORDS
self._lcut = getattr(jieba, "lcut", None)
def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
top_k = kwargs.pop("topK", top_k)
cut = getattr(jieba, "cut", None)
if self._lcut:
tokens = self._lcut(sentence)
elif callable(cut):
tokens = list(cut(sentence))
else:
tokens = re.findall(r"\w+", sentence)
words = [w for w in tokens if w and w not in self.stop_words]
freq: dict[str, int] = {}
for w in words:
freq[w] = freq.get(w, 0) + 1
sorted_words = sorted(freq.items(), key=itemgetter(1), reverse=True)
if top_k is not None:
sorted_words = sorted_words[:top_k]
return [item[0] for item in sorted_words]
return _SimpleTFIDF()
def extract_keywords(self, text: str, max_keywords_per_chunk: int | None = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = self._tfidf.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
)
# jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
keywords = cast(list[str], keywords)
return set(self._expand_tokens_with_subtokens(set(keywords)))
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
                results.update({w for w in sub_tokens if w not in STOPWORDS})
return results
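
To make the sub-token expansion above concrete, a small standalone sketch follows; it mirrors _expand_tokens_with_subtokens but takes an explicit stopword set instead of the project STOPWORDS, and the example tokens are made up.

import re


def expand_tokens_with_subtokens(tokens: set[str], stopwords: set[str]) -> set[str]:
    # Keep each original token and add its word-level sub-tokens, dropping stopwords.
    results: set[str] = set()
    for token in tokens:
        results.add(token)
        sub_tokens = re.findall(r"\w+", token)
        if len(sub_tokens) > 1:
            results.update(w for w in sub_tokens if w not in stopwords)
    return results


if __name__ == "__main__":
    stopwords = {"the", "of"}
    expanded = expand_tokens_with_subtokens({"machine learning", "the internet of things"}, stopwords)
    # Multi-word tags stay intact and also contribute their non-stopword sub-tokens.
    assert "machine learning" in expanded
    assert "machine" in expanded and "learning" in expanded
    assert "the" not in expanded and "of" not in expanded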

File diff suppressed because it is too large

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from core.rag.models.document import Document
from models.dataset import Dataset
class BaseKeyword(ABC):
def __init__(self, dataset: Dataset):
self.dataset = dataset
@abstractmethod
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]):
raise NotImplementedError
@abstractmethod
def delete(self):
raise NotImplementedError
@abstractmethod
def search(self, query: str, **kwargs: Any) -> list[Document]:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
if text.metadata is None:
continue
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata["doc_id"] for text in texts if text.metadata]
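
For orientation, a hedged sketch of what a BaseKeyword subclass has to provide, using an in-memory dict instead of the database, Redis, and storage plumbing the jieba implementation relies on; InMemoryKeyword, its whitespace tokenisation, and its storage layout are illustrative assumptions only.

from collections import defaultdict
from typing import Any

from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
from models.dataset import Dataset


class InMemoryKeyword(BaseKeyword):
    """Illustrative only: keeps the keyword -> doc_id index in process memory."""

    def __init__(self, dataset: Dataset):
        super().__init__(dataset)
        self._index: dict[str, set[str]] = defaultdict(set)
        self._docs: dict[str, Document] = {}

    def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
        self.add_texts(texts, **kwargs)
        return self

    def add_texts(self, texts: list[Document], **kwargs):
        for text in texts:
            if text.metadata is None:
                continue
            doc_id = text.metadata["doc_id"]
            self._docs[doc_id] = text
            # Naive whitespace tokenisation stands in for jieba keyword extraction here.
            for word in set(text.page_content.split()):
                self._index[word].add(doc_id)

    def text_exists(self, id: str) -> bool:
        return id in self._docs

    def delete_by_ids(self, ids: list[str]):
        for doc_id in ids:
            self._docs.pop(doc_id, None)
        for word in list(self._index):
            self._index[word] -= set(ids)
            if not self._index[word]:
                del self._index[word]

    def delete(self):
        self._index.clear()
        self._docs.clear()

    def search(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        hits: dict[str, int] = defaultdict(int)
        for word in query.split():
            for doc_id in self._index.get(word, set()):
                hits[doc_id] += 1
        ranked = sorted(hits, key=hits.__getitem__, reverse=True)[:top_k]
        return [self._docs[doc_id] for doc_id in ranked]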

View File

@@ -0,0 +1,54 @@
from typing import Any
from configs import dify_config
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.datasource.keyword.keyword_type import KeyWordType
from core.rag.models.document import Document
from models.dataset import Dataset
class Keyword:
def __init__(self, dataset: Dataset):
self._dataset = dataset
self._keyword_processor = self._init_keyword()
def _init_keyword(self) -> BaseKeyword:
keyword_type = dify_config.KEYWORD_STORE
keyword_factory = self.get_keyword_factory(keyword_type)
return keyword_factory(self._dataset)
@staticmethod
def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]:
match keyword_type:
case KeyWordType.JIEBA:
from core.rag.datasource.keyword.jieba.jieba import Jieba
return Jieba
case _:
raise ValueError(f"Keyword store {keyword_type} is not supported.")
def create(self, texts: list[Document], **kwargs):
self._keyword_processor.create(texts, **kwargs)
def add_texts(self, texts: list[Document], **kwargs):
self._keyword_processor.add_texts(texts, **kwargs)
def text_exists(self, id: str) -> bool:
return self._keyword_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]):
self._keyword_processor.delete_by_ids(ids)
def delete(self):
self._keyword_processor.delete()
def search(self, query: str, **kwargs: Any) -> list[Document]:
return self._keyword_processor.search(query, **kwargs)
def __getattr__(self, name):
if self._keyword_processor is not None:
method = getattr(self._keyword_processor, name)
if callable(method):
return method
raise AttributeError(f"'Keyword' object has no attribute '{name}'")
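
A hedged usage sketch of the facade above; the dataset lookup and the sample arguments are placeholders, and the underlying processor is whatever dify_config.KEYWORD_STORE resolves to (currently only jieba).

from core.rag.datasource.keyword.keyword_factory import Keyword
from extensions.ext_database import db
from models.dataset import Dataset


def keyword_search_example(dataset_id: str, query: str):
    # Illustrative only: look up a dataset and run a keyword search through the facade.
    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
    if dataset is None:
        return []
    keyword = Keyword(dataset)
    # The backing processor is selected by dify_config.KEYWORD_STORE (KeyWordType.JIEBA today).
    return keyword.search(query, top_k=4)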

View File

@@ -0,0 +1,5 @@
from enum import StrEnum
class KeyWordType(StrEnum):
JIEBA = "jieba"

View File

@@ -0,0 +1,465 @@
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
"score_threshold_enabled": False,
}
class RetrievalService:
@classmethod
def retrieve(
cls,
retrieval_method: RetrievalMethod,
dataset_id: str,
query: str,
top_k: int,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
):
if not query:
return []
dataset = cls._get_dataset(dataset_id)
if not dataset:
return []
all_documents: list[Document] = []
exceptions: list[str] = []
# Optimize multithreading with thread pools
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
futures.append(
executor.submit(
cls.keyword_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
all_documents=all_documents,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_semantic_search(retrieval_method):
futures.append(
executor.submit(
cls.embedding_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
futures.append(
executor.submit(
cls.full_text_index_search,
flask_app=current_app._get_current_object(), # type: ignore
dataset_id=dataset_id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
all_documents=all_documents,
retrieval_method=retrieval_method,
exceptions=exceptions,
document_ids_filter=document_ids_filter,
)
)
concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k,
)
return all_documents
@classmethod
def external_retrieve(
cls,
dataset_id: str,
query: str,
external_retrieval_model: dict | None = None,
metadata_filtering_conditions: dict | None = None,
):
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []
metadata_condition = (
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
)
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id,
dataset_id,
query,
external_retrieval_model or {},
metadata_condition=metadata_condition,
)
return all_documents
@classmethod
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
if not documents:
return documents
unique_documents = []
seen_doc_ids = set()
for document in documents:
# For dify provider documents, use doc_id for deduplication
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
doc_id = document.metadata["doc_id"]
if doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
unique_documents.append(document)
# If duplicate, keep the one with higher score
elif "score" in document.metadata:
# Find existing document with same doc_id and compare scores
for i, existing_doc in enumerate(unique_documents):
if (
existing_doc.metadata
and existing_doc.metadata.get("doc_id") == doc_id
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
):
unique_documents[i] = document
break
else:
# For non-dify documents, use content-based deduplication
if document not in unique_documents:
unique_documents.append(document)
return unique_documents
@classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
with Session(db.engine) as session:
return session.query(Dataset).where(Dataset.id == dataset_id).first()
@classmethod
def keyword_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
exceptions: list,
document_ids_filter: list[str] | None = None,
):
with flask_app.app_context():
try:
dataset = cls._get_dataset(dataset_id)
if not dataset:
raise ValueError("dataset not found")
keyword = Keyword(dataset=dataset)
documents = keyword.search(
cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
)
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
@classmethod
def embedding_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
all_documents: list,
retrieval_method: RetrievalMethod,
exceptions: list,
document_ids_filter: list[str] | None = None,
):
with flask_app.app_context():
try:
dataset = cls._get_dataset(dataset_id)
if not dataset:
raise ValueError("dataset not found")
vector = Vector(dataset=dataset)
documents = vector.search_by_vector(
query,
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
document_ids_filter=document_ids_filter,
)
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
)
)
else:
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
@classmethod
def full_text_index_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
all_documents: list,
retrieval_method: str,
exceptions: list,
document_ids_filter: list[str] | None = None,
):
with flask_app.app_context():
try:
dataset = cls._get_dataset(dataset_id)
if not dataset:
raise ValueError("dataset not found")
vector_processor = Vector(dataset=dataset)
documents = vector_processor.search_by_full_text(
cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
)
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query,
documents=documents,
score_threshold=score_threshold,
top_n=len(documents),
)
)
else:
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
@staticmethod
def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"')
@classmethod
def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
"""Format retrieval documents with optimized batch processing"""
if not documents:
return []
try:
# Collect document IDs
            document_ids = {
                doc.metadata.get("document_id")
                for doc in documents
                if doc.metadata and "document_id" in doc.metadata
            }
if not document_ids:
return []
# Batch query dataset documents
dataset_documents = {
doc.id: doc
for doc in db.session.query(DatasetDocument)
.where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
.all()
}
records = []
include_segment_ids = set()
segment_child_map = {}
# Process documents
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
continue
segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == child_chunk.segment_id,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
)
)
.first()
)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
# Handle normal documents
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
if not segment:
continue
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
}
records.append(record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"]
result = []
for record in records:
# Extract segment
segment = record["segment"]
# Extract child_chunks, ensuring it's a list or None
child_chunks = record.get("child_chunks")
if not isinstance(child_chunks, list):
child_chunks = None
# Extract score, ensuring it's a float or None
score_value = record.get("score")
score = (
float(score_value)
if score_value is not None and isinstance(score_value, int | float | str)
else None
)
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
result.append(retrieval_segment)
return result
except Exception as e:
db.session.rollback()
raise e
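
The hybrid-search path above deduplicates by doc_id and keeps the higher-scoring duplicate (_deduplicate_documents). A minimal standalone restatement of that idea, with plain dicts standing in for Document objects (the shape is an assumption for illustration):

def deduplicate_by_doc_id(documents: list[dict]) -> list[dict]:
    # Keep the first occurrence of each doc_id, replacing it when a later
    # duplicate carries a higher score (mirrors _deduplicate_documents above).
    best: dict[str, dict] = {}
    order: list[str] = []
    for doc in documents:
        doc_id = doc["doc_id"]
        if doc_id not in best:
            best[doc_id] = doc
            order.append(doc_id)
        elif doc.get("score", 0) > best[doc_id].get("score", 0):
            best[doc_id] = doc
    return [best[doc_id] for doc_id in order]


if __name__ == "__main__":
    docs = [
        {"doc_id": "a", "score": 0.4},
        {"doc_id": "b", "score": 0.9},
        {"doc_id": "a", "score": 0.7},  # duplicate from the other retriever, higher score
    ]
    assert [d["score"] for d in deduplicate_by_doc_id(docs)] == [0.7, 0.9]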

View File

@@ -0,0 +1,388 @@
import hashlib
import json
import logging
import uuid
from contextlib import contextmanager
from typing import Any, Literal, cast
import mysql.connector
from mysql.connector import Error as MySQLError
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class AlibabaCloudMySQLVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
max_connection: int
charset: str = "utf8mb4"
distance_function: Literal["cosine", "euclidean"] = "cosine"
hnsw_m: int = 6
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values.get("host"):
raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required")
if not values.get("port"):
raise ValueError("config ALIBABACLOUD_MYSQL_PORT is required")
if not values.get("user"):
raise ValueError("config ALIBABACLOUD_MYSQL_USER is required")
if values.get("password") is None:
raise ValueError("config ALIBABACLOUD_MYSQL_PASSWORD is required")
if not values.get("database"):
raise ValueError("config ALIBABACLOUD_MYSQL_DATABASE is required")
if not values.get("max_connection"):
raise ValueError("config ALIBABACLOUD_MYSQL_MAX_CONNECTION is required")
return values
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id VARCHAR(36) PRIMARY KEY,
text LONGTEXT NOT NULL,
meta JSON NOT NULL,
embedding VECTOR({dimension}) NOT NULL,
VECTOR INDEX (embedding) M={hnsw_m} DISTANCE={distance_function}
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
"""
SQL_CREATE_META_INDEX = """
CREATE INDEX idx_{index_hash}_meta ON {table_name}
((CAST(JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) AS CHAR(36))));
"""
SQL_CREATE_FULLTEXT_INDEX = """
CREATE FULLTEXT INDEX idx_{index_hash}_text ON {table_name} (text) WITH PARSER ngram;
"""
class AlibabaCloudMySQLVector(BaseVector):
def __init__(self, collection_name: str, config: AlibabaCloudMySQLVectorConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = collection_name.lower()
self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8]
self.distance_function = config.distance_function.lower()
self.hnsw_m = config.hnsw_m
self._check_vector_support()
def get_type(self) -> str:
return VectorType.ALIBABACLOUD_MYSQL
def _create_connection_pool(self, config: AlibabaCloudMySQLVectorConfig):
# Create connection pool using mysql-connector-python pooling
pool_config: dict[str, Any] = {
"host": config.host,
"port": config.port,
"user": config.user,
"password": config.password,
"database": config.database,
"charset": config.charset,
"autocommit": True,
"pool_name": f"pool_{self.collection_name}",
"pool_size": config.max_connection,
"pool_reset_session": True,
}
return mysql.connector.pooling.MySQLConnectionPool(**pool_config)
def _check_vector_support(self):
"""Check if the MySQL server supports vector operations."""
try:
with self._get_cursor() as cur:
# Check MySQL version and vector support
cur.execute("SELECT VERSION()")
version = cur.fetchone()["VERSION()"]
logger.debug("Connected to MySQL version: %s", version)
# Try to execute a simple vector function to verify support
cur.execute("SELECT VEC_FromText('[1,2,3]') IS NOT NULL as vector_support")
result = cur.fetchone()
if not result or not result.get("vector_support"):
raise ValueError(
"RDS MySQL Vector functions are not available."
" Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
)
except MySQLError as e:
if "FUNCTION" in str(e) and "VEC_FromText" in str(e):
raise ValueError(
"RDS MySQL Vector functions are not available."
" Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
) from e
raise e
@contextmanager
def _get_cursor(self):
conn = self.pool.get_connection()
cur = conn.cursor(dictionary=True)
try:
yield cur
finally:
cur.close()
conn.close()
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
# Convert embedding list to Aliyun MySQL vector format
vector_str = "[" + ",".join(map(str, embeddings[i])) + "]"
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
vector_str,
)
)
with self._get_cursor() as cur:
insert_sql = (
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (%s, %s, %s, VEC_FromText(%s))"
)
cur.executemany(insert_sql, values)
return pks
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
return cur.fetchone() is not None
def get_by_ids(self, ids: list[str]) -> list[Document]:
if not ids:
return []
with self._get_cursor() as cur:
placeholders = ",".join(["%s"] * len(ids))
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = []
for record in cur:
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
def delete_by_ids(self, ids: list[str]):
        # Avoid issuing a DELETE with an empty id list.
if not ids:
return
with self._get_cursor() as cur:
try:
placeholders = ",".join(["%s"] * len(ids))
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
except MySQLError as e:
if e.errno == 1146: # Table doesn't exist
logger.warning("Table %s not found, skipping delete operation.", self.table_name)
return
else:
raise e
def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
cur.execute(
f"DELETE FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", (f"$.{key}", value)
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector using RDS MySQL vector distance functions.
:param query_vector: The input vector to search for similar items.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params = []
if document_ids_filter:
placeholders = ",".join(["%s"] * len(document_ids_filter))
where_clause = f" WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
params.extend(document_ids_filter)
# Convert query vector to RDS MySQL vector format
query_vector_str = "[" + ",".join(map(str, query_vector)) + "]"
        # Use RDS MySQL's native vector distance functions
with self._get_cursor() as cur:
# Choose distance function based on configuration
distance_func = "VEC_DISTANCE_COSINE" if self.distance_function == "cosine" else "VEC_DISTANCE_EUCLIDEAN"
# Note: RDS MySQL optimizer will use vector index when ORDER BY + LIMIT are present
# Use column alias in ORDER BY to avoid calculating distance twice
sql = f"""
SELECT meta, text,
{distance_func}(embedding, VEC_FromText(%s)) AS distance
FROM {self.table_name}
{where_clause}
ORDER BY distance
LIMIT %s
"""
query_params = [query_vector_str] + params + [top_k]
cur.execute(sql, query_params)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
try:
distance = float(record["distance"])
# Convert distance to similarity score
if self.distance_function == "cosine":
# For cosine distance: similarity = 1 - distance
similarity = 1.0 - distance
else:
# For euclidean distance: use inverse relationship
# similarity = 1 / (1 + distance)
similarity = 1.0 / (1.0 + distance)
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata["score"] = similarity
metadata["distance"] = distance
if similarity >= score_threshold:
docs.append(Document(page_content=record["text"], metadata=metadata))
except (ValueError, json.JSONDecodeError) as e:
logger.warning("Error processing search result: %s", e)
continue
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params = []
if document_ids_filter:
placeholders = ",".join(["%s"] * len(document_ids_filter))
where_clause = f" AND JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
params.extend(document_ids_filter)
with self._get_cursor() as cur:
# Build query parameters: query (twice for MATCH clauses), document_ids_filter (if any), top_k
query_params = [query, query] + params + [top_k]
cur.execute(
f"""SELECT meta, text,
MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE) AS score
FROM {self.table_name}
WHERE MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE)
{where_clause}
ORDER BY score DESC
LIMIT %s""",
query_params,
)
docs = []
for record in cur:
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata["score"] = float(record["score"])
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
def _create_collection(self, dimension: int):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{collection_exist_cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
# Create table with vector column and vector index
cur.execute(
SQL_CREATE_TABLE.format(
table_name=self.table_name,
dimension=dimension,
distance_function=self.distance_function,
hnsw_m=self.hnsw_m,
)
)
# Create metadata index (check if exists first)
try:
cur.execute(SQL_CREATE_META_INDEX.format(table_name=self.table_name, index_hash=self.index_hash))
except MySQLError as e:
if e.errno != 1061: # Duplicate key name
logger.warning("Could not create meta index: %s", e)
# Create full-text index for text search
try:
cur.execute(
SQL_CREATE_FULLTEXT_INDEX.format(table_name=self.table_name, index_hash=self.index_hash)
)
except MySQLError as e:
if e.errno != 1061: # Duplicate key name
logger.warning("Could not create fulltext index: %s", e)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class AlibabaCloudMySQLVectorFactory(AbstractVectorFactory):
def _validate_distance_function(self, distance_function: str) -> Literal["cosine", "euclidean"]:
"""Validate and return the distance function as a proper Literal type."""
if distance_function not in ["cosine", "euclidean"]:
raise ValueError(f"Invalid distance function: {distance_function}. Must be 'cosine' or 'euclidean'")
return cast(Literal["cosine", "euclidean"], distance_function)
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AlibabaCloudMySQLVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ALIBABACLOUD_MYSQL, collection_name)
)
return AlibabaCloudMySQLVector(
collection_name=collection_name,
config=AlibabaCloudMySQLVectorConfig(
host=dify_config.ALIBABACLOUD_MYSQL_HOST or "localhost",
port=dify_config.ALIBABACLOUD_MYSQL_PORT,
user=dify_config.ALIBABACLOUD_MYSQL_USER or "root",
password=dify_config.ALIBABACLOUD_MYSQL_PASSWORD or "",
database=dify_config.ALIBABACLOUD_MYSQL_DATABASE or "dify",
max_connection=dify_config.ALIBABACLOUD_MYSQL_MAX_CONNECTION,
charset=dify_config.ALIBABACLOUD_MYSQL_CHARSET or "utf8mb4",
distance_function=self._validate_distance_function(
dify_config.ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION or "cosine"
),
hnsw_m=dify_config.ALIBABACLOUD_MYSQL_HNSW_M or 6,
),
)
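
Two details of search_by_vector above, shown in isolation: embeddings are passed to VEC_FromText() as a bracketed string, and the raw distance is mapped to a similarity before score_threshold is applied. A hedged restatement (the function names are illustrative):

def to_vector_literal(vec: list[float]) -> str:
    # Embeddings are handed to VEC_FromText() as a bracketed, comma-separated string.
    return "[" + ",".join(map(str, vec)) + "]"


def distance_to_similarity(distance: float, distance_function: str = "cosine") -> float:
    # Cosine distance maps to 1 - distance; Euclidean distance to 1 / (1 + distance).
    if distance_function == "cosine":
        return 1.0 - distance
    return 1.0 / (1.0 + distance)


assert to_vector_literal([1.0, 2.0]) == "[1.0,2.0]"
assert distance_to_similarity(0.25) == 0.75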

View File

@@ -0,0 +1,104 @@
import json
from typing import Any
from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset
class AnalyticdbVector(BaseVector):
def __init__(
self,
collection_name: str,
api_config: AnalyticdbVectorOpenAPIConfig | None,
sql_config: AnalyticdbVectorBySqlConfig | None,
):
super().__init__(collection_name)
if api_config is not None:
self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
collection_name, api_config
)
else:
if sql_config is None:
raise ValueError("Either api_config or sql_config must be provided")
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
def get_type(self) -> str:
return VectorType.ANALYTICDB
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(documents, embeddings)
def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id)
def delete_by_ids(self, ids: list[str]):
self.analyticdb_vector.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str):
self.analyticdb_vector.delete_by_metadata_field(key, value)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
def delete(self):
self.analyticdb_vector.delete()
class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
if dify_config.ANALYTICDB_HOST is None:
# implemented through OpenAPI
apiConfig = AnalyticdbVectorOpenAPIConfig(
access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
region_id=dify_config.ANALYTICDB_REGION_ID or "",
instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
account=dify_config.ANALYTICDB_ACCOUNT or "",
account_password=dify_config.ANALYTICDB_PASSWORD or "",
namespace=dify_config.ANALYTICDB_NAMESPACE or "",
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
)
sqlConfig = None
else:
# implemented through sql
sqlConfig = AnalyticdbVectorBySqlConfig(
host=dify_config.ANALYTICDB_HOST,
port=dify_config.ANALYTICDB_PORT,
account=dify_config.ANALYTICDB_ACCOUNT or "",
account_password=dify_config.ANALYTICDB_PASSWORD or "",
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
namespace=dify_config.ANALYTICDB_NAMESPACE or "",
)
apiConfig = None
return AnalyticdbVector(
collection_name,
apiConfig,
sqlConfig,
)
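
AnalyticdbVectorFactory above chooses the OpenAPI client when ANALYTICDB_HOST is unset and the direct SQL client otherwise; a tiny hedged restatement of that branch (the helper and its return labels are illustrative):

def choose_analyticdb_backend(analyticdb_host: str | None) -> str:
    # Mirrors AnalyticdbVectorFactory.init_vector: no host configured -> OpenAPI, host -> SQL.
    return "openapi" if analyticdb_host is None else "sql"


assert choose_analyticdb_backend(None) == "openapi"
assert choose_analyticdb_backend("gp-xxxx.aliyuncs.com") == "sql"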

View File

@@ -0,0 +1,321 @@
import json
from typing import Any
from pydantic import BaseModel, model_validator
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorOpenAPIConfig(BaseModel):
access_key_id: str
access_key_secret: str
region_id: str
instance_id: str
account: str
account_password: str
namespace: str = "dify"
namespace_password: str | None = None
metrics: str = "cosine"
read_timeout: int = 60000
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["access_key_id"]:
raise ValueError("config ANALYTICDB_KEY_ID is required")
if not values["access_key_secret"]:
raise ValueError("config ANALYTICDB_KEY_SECRET is required")
if not values["region_id"]:
raise ValueError("config ANALYTICDB_REGION_ID is required")
if not values["instance_id"]:
raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
if not values["account"]:
raise ValueError("config ANALYTICDB_ACCOUNT is required")
if not values["account_password"]:
raise ValueError("config ANALYTICDB_PASSWORD is required")
if not values["namespace_password"]:
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
return values
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVectorOpenAPI:
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
try:
from alibabacloud_gpdb20160503.client import Client # type: ignore
from alibabacloud_tea_openapi import models as open_api_models # type: ignore
        except ImportError:
raise ImportError(_import_err_msg)
self._collection_name = collection_name.lower()
self.config = config
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
def _initialize(self):
cache_key = f"vector_initialize_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
if redis_client.get(database_exist_cache_key):
return
self._initialize_vector_database()
self._create_namespace_if_not_exists()
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException # type: ignore
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
if doc.metadata is not None:
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None, # ty: ignore [invalid-argument-type]
content=None, # ty: ignore [invalid-argument-type]
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
def delete_by_metadata_field(self, key: str, value: str):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"metadata_->>'document_id' IN ({document_ids})"
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None, # ty: ignore [invalid-argument-type]
top_k=kwargs.get("top_k", 4),
filter=where_clause,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"metadata_->>'document_id' IN ({document_ids})"
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None, # ty: ignore [invalid-argument-type]
content=query,
top_k=kwargs.get("top_k", 4),
filter=where_clause,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def delete(self):
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e
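
delete_by_ids and the two search methods above express their constraints as AnalyticDB filter strings (for example ref_doc_id IN ('a','b')); a minimal sketch of that string building, with a hypothetical helper name:

def build_ref_doc_filter(ids: list[str]) -> str:
    # e.g. ["a", "b"] -> "ref_doc_id IN ('a','b')"
    quoted = ",".join(f"'{doc_id}'" for doc_id in ids)
    return f"ref_doc_id IN ({quoted})"


assert build_ref_doc_filter(["a", "b"]) == "ref_doc_id IN ('a','b')"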

View File

@@ -0,0 +1,275 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorBySqlConfig(BaseModel):
host: str
port: int
account: str
account_password: str
min_connection: int
max_connection: int
namespace: str = "dify"
metrics: str = "cosine"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config ANALYTICDB_HOST is required")
if not values["port"]:
raise ValueError("config ANALYTICDB_PORT is required")
if not values["account"]:
raise ValueError("config ANALYTICDB_ACCOUNT is required")
if not values["account_password"]:
raise ValueError("config ANALYTICDB_PASSWORD is required")
if not values["min_connection"]:
raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
if not values["max_connection"]:
raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
if values["min_connection"] > values["max_connection"]:
raise ValueError("config ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION")
return values
class AnalyticdbVectorBySql:
def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
self._collection_name = collection_name.lower()
self.databaseName = "knowledgebase"
self.config = config
self.table_name = f"{self.config.namespace}.{self._collection_name}"
self.pool = None
self._initialize()
if not self.pool:
self.pool = self._create_connection_pool()
def _initialize(self):
cache_key = f"vector_initialize_{self.config.host}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
database_exist_cache_key = f"vector_initialize_{self.config.host}"
if redis_client.get(database_exist_cache_key):
return
self._initialize_vector_database()
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _create_connection_pool(self):
return psycopg2.pool.SimpleConnectionPool(
self.config.min_connection,
self.config.max_connection,
host=self.config.host,
port=self.config.port,
user=self.config.account,
password=self.config.account_password,
database=self.databaseName,
)
@contextmanager
def _get_cursor(self):
assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)
def _initialize_vector_database(self):
conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
user=self.config.account,
password=self.config.account_password,
database="postgres",
)
conn.autocommit = True
cur = conn.cursor()
try:
cur.execute(f"CREATE DATABASE {self.databaseName}")
except Exception as e:
if "already exists" not in str(e):
raise e
finally:
cur.close()
conn.close()
self.pool = self._create_connection_pool()
with self._get_cursor() as cur:
conn = cur.connection
try:
cur.execute("CREATE EXTENSION IF NOT EXISTS zhparser;")
except Exception as e:
conn.rollback()
raise RuntimeError(
"Failed to create zhparser extension. Please ensure it is available in your AnalyticDB."
) from e
try:
cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
except Exception as e:
conn.rollback()
if "already exists" not in str(e):
raise e
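# Helper UDF: tokenize the input with to_tsvector, strip the position markers, and OR the lexemes together so plain text can be passed to to_tsquery for full-text matching.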
cur.execute(
"CREATE OR REPLACE FUNCTION "
"public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
"RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
"SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
"FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
"AS words_only;$function$"
)
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute(
f"CREATE TABLE IF NOT EXISTS {self.table_name}("
f"id text PRIMARY KEY,"
f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
f"to_tsvector TSVECTOR"
f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
)
if embedding_dimension is not None:
index_name = f"{self._collection_name}_embedding_idx"
try:
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
cur.execute(
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
f"pq_enable=0, external_storage=0)"
)
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
except Exception as e:
if "already exists" not in str(e):
raise e
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
id_prefix = str(uuid.uuid4()) + "_"
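# page_content is bound twice below: stored verbatim and indexed as a zh_cn tsvector for full-text search.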
sql = f"""
INSERT INTO {self.table_name}
(id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
"""
for i, doc in enumerate(documents):
if doc.metadata is not None:
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_batch(cur, sql, values)
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
return cur.fetchone() is not None
def delete_by_ids(self, ids: list[str]):
if not ids:
return
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id = ANY(%s)", (ids,))
except Exception as e:
if "does not exist" not in str(e):
raise e
def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
except Exception as e:
if "does not exist" not in str(e):
raise e
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = "WHERE 1=1"
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
score_threshold = float(kwargs.get("score_threshold") or 0.0)
with self._get_cursor() as cur:
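# Re-encode the query vector from a JSON list ("[...]") into the Postgres array literal ("{...}") expected by the real[] column.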
query_vector_str = json.dumps(query_vector)
query_vector_str = "{" + query_vector_str[1:-1] + "}"
cur.execute(
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
f"t.page_content as page_content, t.metadata_ AS metadata_ "
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
(query_vector_str,),
)
documents = []
for record in cur:
_, vector, score, page_content, metadata = record
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
with self._get_cursor() as cur:
cur.execute(
f"""SELECT id, vector, page_content, metadata_,
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
ORDER BY score DESC, id DESC
LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"),
)
documents = []
for record in cur:
_, vector, page_content, metadata, score = record
metadata["score"] = score
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -0,0 +1,373 @@
import json
import logging
import time
import uuid
from typing import Any
import numpy as np
from pydantic import BaseModel, model_validator
from pymochow import MochowClient # type: ignore
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
from pymochow.configuration import Configuration # type: ignore
from pymochow.exception import ServerError # type: ignore
from pymochow.model.database import Database
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
from pymochow.model.schema import (
Field,
FilteringIndex,
HNSWParams,
InvertedIndex,
InvertedIndexAnalyzer,
InvertedIndexFieldAttribute,
InvertedIndexParams,
InvertedIndexParseMode,
Schema,
VectorIndex,
) # type: ignore
from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, Partition, Row # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class BaiduConfig(BaseModel):
endpoint: str
connection_timeout_in_mills: int = 30 * 1000
account: str
api_key: str
database: str
index_type: str = "HNSW"
metric_type: str = "IP"
shard: int = 1
replicas: int = 3
inverted_index_analyzer: str = "DEFAULT_ANALYZER"
inverted_index_parser_mode: str = "COARSE_MODE"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["endpoint"]:
raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
if not values["account"]:
raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required")
if not values["api_key"]:
raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required")
if not values["database"]:
raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required")
return values
class BaiduVector(BaseVector):
vector_index: str = "vector_idx"
filtering_index: str = "filtering_idx"
inverted_index: str = "content_inverted_idx"
def __init__(self, collection_name: str, config: BaiduConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._db = self._init_database()
def get_type(self) -> str:
return VectorType.BAIDU
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._create_table(len(embeddings[0]))
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
total_count = len(documents)
batch_size = 1000
# upsert texts and embeddings batch by batch
table = self._db.table(self._collection_name)
for start in range(0, total_count, batch_size):
end = min(start + batch_size, total_count)
rows = []
for i in range(start, end, 1):
metadata = documents[i].metadata
row = Row(
id=metadata.get("doc_id", str(uuid.uuid4())),
page_content=documents[i].page_content,
metadata=metadata,
vector=embeddings[i],
)
rows.append(row)
table.upsert(rows=rows)
# rebuild vector index after upsert finished
table.rebuild_index(self.vector_index)
timeout = 3600 # 1 hour timeout
start_time = time.time()
while True:
time.sleep(1)
index = table.describe_index(self.vector_index)
if index.state == IndexState.NORMAL:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
def text_exists(self, id: str) -> bool:
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
if res and res.code == 0:
return True
return False
def delete_by_ids(self, ids: list[str]):
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})")
def delete_by_metadata_field(self, key: str, value: str):
# Escape double quotes in value to prevent injection
escaped_value = value.replace('"', '\\"')
self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"')
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f'metadata["document_id"] IN({document_ids})'
anns = AnnSearch(
vector_field=VDBField.VECTOR,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 20), limit=kwargs.get("top_k", 4)),
filter=filter,
)
res = self._db.table(self._collection_name).search(
anns=anns,
projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY],
retrieve_vector=False,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# document ids filter
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f'metadata["document_id"] IN({document_ids})'
request = BM25SearchRequest(
index_name=self.inverted_index, search_text=query, limit=kwargs.get("top_k", 4), filter=filter
)
res = self._db.table(self._collection_name).bm25_search(
request=request, projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY]
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
def _get_search_res(self, res, score_threshold) -> list[Document]:
docs = []
for row in res.rows:
row_data = row.get("row", {})
score = row.get("score", 0.0)
meta = row_data.get(VDBField.METADATA_KEY, {})
# Handle both JSON string and dict formats for backward compatibility
if isinstance(meta, str):
try:
meta = json.loads(meta)
except (json.JSONDecodeError, TypeError):
meta = {}
elif not isinstance(meta, dict):
meta = {}
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), metadata=meta)
docs.append(doc)
return docs
def delete(self):
try:
self._db.drop_table(table_name=self._collection_name)
except ServerError as e:
if e.code == ServerErrCode.TABLE_NOT_EXIST:
pass
else:
raise
def _init_client(self, config) -> MochowClient:
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
client = MochowClient(config)
return client
def _init_database(self) -> Database:
exists = False
for db in self._client.list_databases():
if db.database_name == self._client_config.database:
exists = True
break
# Create database if not existed
if exists:
return self._client.database(self._client_config.database)
else:
try:
self._client.create_database(database_name=self._client_config.database)
except ServerError as e:
if e.code == ServerErrCode.DB_ALREADY_EXIST:
return self._client.database(self._client_config.database)
else:
raise
return self._client.database(self._client_config.database)
def _table_existed(self) -> bool:
tables = self._db.list_table()
return any(table.table_name == self._collection_name for table in tables)
def _create_table(self, dimension: int):
# Try to grab distributed lock and create table
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=60):
table_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(table_exist_cache_key):
return
if self._table_existed():
return
self.delete()
# check IndexType and MetricType
index_type = None
for k, v in IndexType.__members__.items():
if k == self._client_config.index_type:
index_type = v
if index_type is None:
raise ValueError("unsupported index_type")
metric_type = None
for k, v in MetricType.__members__.items():
if k == self._client_config.metric_type:
metric_type = v
if metric_type is None:
raise ValueError("unsupported metric_type")
# Construct field schema
fields = []
fields.append(
Field(
VDBField.PRIMARY_KEY,
FieldType.STRING,
primary_key=True,
partition_key=True,
auto_increment=False,
not_null=True,
)
)
fields.append(Field(VDBField.CONTENT_KEY, FieldType.TEXT, not_null=False))
fields.append(Field(VDBField.METADATA_KEY, FieldType.JSON, not_null=False))
fields.append(Field(VDBField.VECTOR, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
# Construct vector index params
indexes = []
indexes.append(
VectorIndex(
index_name=self.vector_index,
index_type=index_type,
field=VDBField.VECTOR,
metric_type=metric_type,
params=HNSWParams(m=16, efconstruction=200),
)
)
# Filtering index
indexes.append(
FilteringIndex(
index_name=self.filtering_index,
fields=[VDBField.METADATA_KEY],
)
)
# Get analyzer and parse_mode from config
analyzer = getattr(
InvertedIndexAnalyzer,
self._client_config.inverted_index_analyzer,
InvertedIndexAnalyzer.DEFAULT_ANALYZER,
)
parse_mode = getattr(
InvertedIndexParseMode,
self._client_config.inverted_index_parser_mode,
InvertedIndexParseMode.COARSE_MODE,
)
# Inverted index
indexes.append(
InvertedIndex(
index_name=self.inverted_index,
fields=[VDBField.CONTENT_KEY],
params=InvertedIndexParams(
analyzer=analyzer,
parse_mode=parse_mode,
case_sensitive=True,
),
field_attributes=[InvertedIndexFieldAttribute.ANALYZED],
)
)
# Create table
self._db.create_table(
table_name=self._collection_name,
replication=self._client_config.replicas,
partition=Partition(partition_num=self._client_config.shard),
schema=Schema(fields=fields, indexes=indexes),
description="Table for Dify",
)
# Wait for table created
timeout = 300 # 5 minutes timeout
start_time = time.time()
while True:
time.sleep(1)
table = self._db.describe_table(self._collection_name)
if table.state == TableState.NORMAL:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Table creation timeout after {timeout} seconds")
redis_client.set(table_exist_cache_key, 1, ex=3600)
class BaiduVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name))
return BaiduVector(
collection_name=collection_name,
config=BaiduConfig(
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "",
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "",
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "",
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
),
)

View File

@@ -0,0 +1,160 @@
import json
from typing import Any
import chromadb
from chromadb import QueryResult, Settings
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class ChromaConfig(BaseModel):
host: str
port: int
tenant: str
database: str
auth_provider: str | None = None
auth_credentials: str | None = None
def to_chroma_params(self):
settings = Settings(
# auth
chroma_client_auth_provider=self.auth_provider,
chroma_client_auth_credentials=self.auth_credentials,
)
return {
"host": self.host,
"port": self.port,
"ssl": False,
"tenant": self.tenant,
"database": self.database,
"settings": settings,
}
class ChromaVector(BaseVector):
def __init__(self, collection_name: str, config: ChromaConfig):
super().__init__(collection_name)
self._client_config = config
self._client = chromadb.HttpClient(**self._client_config.to_chroma_params())
def get_type(self) -> str:
return VectorType.CHROMA
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
# create collection
self.create_collection(self._collection_name)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str):
lock_name = f"vector_indexing_lock_{collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
self._client.get_or_create_collection(collection_name)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
collection = self._client.get_or_create_collection(self._collection_name)
# FIXME: chromadb using numpy array, fix the type error later
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
# FIXME: fix the type error later
collection.delete(where={key: {"$eq": value}}) # type: ignore
def delete(self):
self._client.delete_collection(self._collection_name)
def delete_by_ids(self, ids: list[str]):
if not ids:
return
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(ids=ids)
def text_exists(self, id: str) -> bool:
collection = self._client.get_or_create_collection(self._collection_name)
response = collection.get(ids=[id])
# collection.get returns a dict of result columns; check the matched ids rather than the dict itself
return len(response["ids"]) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
results: QueryResult = collection.query(
query_embeddings=query_vector,
n_results=kwargs.get("top_k", 4),
where={"document_id": {"$in": document_ids_filter}}, # type: ignore
)
else:
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# Check if results contain data
if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]:
return []
ids = results["ids"][0]
documents = results["documents"][0]
metadatas = results["metadatas"][0]
distances = results["distances"][0]
docs = []
for index in range(len(ids)):
distance = distances[index]
metadata = dict(metadatas[index])
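# Convert the returned distance into a similarity-style score (higher is better) before applying the threshold.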
score = 1 - distance
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=documents[index],
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# chroma does not support BM25 full text searching
return []
class ChromaVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
return ChromaVector(
collection_name=collection_name,
config=ChromaConfig(
host=dify_config.CHROMA_HOST or "",
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
),
)

View File

@@ -0,0 +1,201 @@
# Clickzetta Vector Database Integration
This module provides integration with Clickzetta Lakehouse as a vector database for Dify.
## Features
- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type
- **Vector Search**: Efficient similarity search using HNSW algorithm
- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities
- **Hybrid Search**: Combine vector similarity and full-text search for better results
- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing
- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments
## Configuration
### Required Environment Variables
All seven configuration parameters are required:
```bash
# Authentication
CLICKZETTA_USERNAME=your_username
CLICKZETTA_PASSWORD=your_password
# Instance configuration
CLICKZETTA_INSTANCE=your_instance_id
CLICKZETTA_SERVICE=api.clickzetta.com
CLICKZETTA_WORKSPACE=your_workspace
CLICKZETTA_VCLUSTER=your_vcluster
CLICKZETTA_SCHEMA=your_schema
```
### Optional Configuration
```bash
# Batch processing
CLICKZETTA_BATCH_SIZE=100
# Full-text search configuration
CLICKZETTA_ENABLE_INVERTED_INDEX=true
CLICKZETTA_ANALYZER_TYPE=chinese # Options: keyword, english, chinese, unicode
CLICKZETTA_ANALYZER_MODE=smart # Options: max_word, smart
# Vector search configuration
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # Options: l2_distance, cosine_distance
```
## Usage
### 1. Set Clickzetta as the Vector Store
In your Dify configuration, set:
```bash
VECTOR_STORE=clickzetta
```
### 2. Table Structure
Clickzetta will automatically create tables with the following structure:
```sql
CREATE TABLE <collection_name> (
id STRING NOT NULL,
content STRING NOT NULL,
metadata JSON,
vector VECTOR(FLOAT, <dimension>) NOT NULL,
PRIMARY KEY (id)
);
-- Vector index for similarity search
CREATE VECTOR INDEX idx_<collection_name>_vec
ON TABLE <schema>.<collection_name>(vector)
PROPERTIES (
"distance.function" = "cosine_distance",
"scalar.type" = "f32"
);
-- Inverted index for full-text search (if enabled)
CREATE INVERTED INDEX idx_<collection_name>_text
ON <schema>.<collection_name>(content)
PROPERTIES (
"analyzer" = "chinese",
"mode" = "smart"
);
```
## Full-Text Search Capabilities
Clickzetta supports advanced full-text search with multiple analyzers:
### Analyzer Types
1. **keyword**: No tokenization, treats the entire string as a single token
- Best for: Exact matching, IDs, codes
1. **english**: Designed for English text
- Features: Recognizes ASCII letters and numbers, converts to lowercase
- Best for: English content
1. **chinese**: Chinese text tokenizer
- Features: Recognizes Chinese and English characters, removes punctuation
- Best for: Chinese or mixed Chinese-English content
1. **unicode**: Multi-language tokenizer based on Unicode
- Features: Recognizes text boundaries in multiple languages
- Best for: Multi-language content
### Analyzer Modes
- **max_word**: Fine-grained tokenization (more tokens)
- **smart**: Intelligent tokenization (balanced)
### Full-Text Search Functions
- `MATCH_ALL(column, query)`: All terms must be present
- `MATCH_ANY(column, query)`: At least one term must be present
- `MATCH_PHRASE(column, query)`: Exact phrase matching
- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching
- `MATCH_REGEXP(column, pattern)`: Regular expression matching
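For example, an any-term query and an exact-phrase query against the `content` column might look like this; the table name `my_schema.my_collection` and the query text are purely illustrative:
```sql
-- Match rows containing any of the query terms
SELECT id, content
FROM my_schema.my_collection
WHERE MATCH_ANY(content, 'vector database');
-- Match rows containing the exact phrase
SELECT id, content
FROM my_schema.my_collection
WHERE MATCH_PHRASE(content, 'vector database');
```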
## Performance Optimization
### Vector Search
1. **Adjust exploration factor** for accuracy vs speed trade-off:
```sql
SET cz.vector.index.search.ef=64;
```
1. **Use appropriate distance functions**:
- `cosine_distance`: Best for normalized embeddings (e.g., from language models)
- `l2_distance`: Best for raw feature vectors
### Full-Text Search
1. **Choose the right analyzer**:
- Use `keyword` for exact matching
- Use language-specific analyzers for better tokenization
1. **Combine with vector search**:
- Pre-filter with full-text search for better performance
- Use hybrid search for improved relevance, as sketched below
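A minimal hybrid-search sketch, assuming the default table layout shown above; `my_schema.my_collection` is an illustrative name and `:query_vector` is a placeholder whose exact binding syntax depends on your SQL client:
```sql
-- Pre-filter candidates with the inverted index, then rank them by vector distance
SELECT id, content,
       cosine_distance(vector, :query_vector) AS distance
FROM my_schema.my_collection
WHERE MATCH_ANY(content, 'vector database')
ORDER BY cosine_distance(vector, :query_vector)
LIMIT 10;
```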
## Troubleshooting
### Connection Issues
1. Verify all 7 required configuration parameters are set
1. Check network connectivity to Clickzetta service
1. Ensure the user has proper permissions on the schema
### Search Performance
1. Verify vector index exists:
```sql
SHOW INDEX FROM <schema>.<table_name>;
```
1. Check if vector index is being used:
```sql
EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
```
Look for `vector_index_search_type` in the execution plan.
### Full-Text Search Not Working
1. Verify inverted index is created
1. Check analyzer configuration matches your content language
1. Use `TOKENIZE()` function to test tokenization:
```sql
SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
```
## Limitations
1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
1. Full-text search relevance scores are not provided by Clickzetta
1. Inverted index creation may fail for very large existing tables (the integration continues without raising an error)
1. Index naming constraints:
- Index names must be unique within a schema
- Only one vector index can be created per column at a time
- The implementation uses timestamps to ensure unique index names
## References
- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)

View File

@@ -0,0 +1 @@
# Clickzetta Vector Database Integration for Dify

File diff suppressed because it is too large

View File

@@ -0,0 +1,378 @@
import json
import logging
import time
import uuid
from datetime import timedelta
from typing import Any
from couchbase import search # type: ignore
from couchbase.auth import PasswordAuthenticator # type: ignore
from couchbase.cluster import Cluster # type: ignore
from couchbase.management.search import SearchIndex # type: ignore
# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
from couchbase.options import ClusterOptions, SearchOptions # type: ignore
from couchbase.vector_search import VectorQuery, VectorSearch # type: ignore
from flask import current_app
from pydantic import BaseModel, model_validator
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class CouchbaseConfig(BaseModel):
connection_string: str
user: str
password: str
bucket_name: str
scope_name: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values.get("connection_string"):
raise ValueError("config COUCHBASE_CONNECTION_STRING is required")
if not values.get("user"):
raise ValueError("config COUCHBASE_USER is required")
if not values.get("password"):
raise ValueError("config COUCHBASE_PASSWORD is required")
if not values.get("bucket_name"):
raise ValueError("config COUCHBASE_PASSWORD is required")
if not values.get("scope_name"):
raise ValueError("config COUCHBASE_SCOPE_NAME is required")
return values
class CouchbaseVector(BaseVector):
def __init__(self, collection_name: str, config: CouchbaseConfig):
super().__init__(collection_name)
self._client_config = config
"""Connect to couchbase"""
auth = PasswordAuthenticator(config.user, config.password)
options = ClusterOptions(auth)
self._cluster = Cluster(config.connection_string, options)
self._bucket = self._cluster.bucket(config.bucket_name)
self._scope = self._bucket.scope(config.scope_name)
self._bucket_name = config.bucket_name
self._scope_name = config.scope_name
# Wait until the cluster is ready for use.
self._cluster.wait_until_ready(timedelta(seconds=5))
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_id = str(uuid.uuid4()).replace("-", "")
self._create_collection(uuid=index_id, vector_length=len(embeddings[0]))
self.add_texts(texts, embeddings)
def _create_collection(self, vector_length: int, uuid: str):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
if self._collection_exists(self._collection_name):
return
manager = self._bucket.collections()
manager.create_collection(self._client_config.scope_name, self._collection_name)
index_manager = self._scope.search_indexes()
index_definition = json.loads("""
{
"type": "fulltext-index",
"name": "Embeddings._default.Vector_Search",
"uuid": "26d4db528e78b716",
"sourceType": "gocbcore",
"sourceName": "Embeddings",
"sourceUUID": "2242e4a25b4decd6650c9c7b3afa1dbf",
"planParams": {
"maxPartitionsPerPIndex": 1024,
"indexPartitions": 1
},
"params": {
"doc_config": {
"docid_prefix_delim": "",
"docid_regexp": "",
"mode": "scope.collection.type_field",
"type_field": "type"
},
"mapping": {
"analysis": { },
"default_analyzer": "standard",
"default_datetime_parser": "dateTimeOptional",
"default_field": "_all",
"default_mapping": {
"dynamic": true,
"enabled": true
},
"default_type": "_default",
"docvalues_dynamic": false,
"index_dynamic": true,
"store_dynamic": true,
"type_field": "_type",
"types": {
"collection_name": {
"dynamic": true,
"enabled": true,
"properties": {
"embedding": {
"dynamic": false,
"enabled": true,
"fields": [
{
"dims": 1536,
"index": true,
"name": "embedding",
"similarity": "dot_product",
"type": "vector",
"vector_index_optimized_for": "recall"
}
]
},
"metadata": {
"dynamic": true,
"enabled": true
},
"text": {
"dynamic": false,
"enabled": true,
"fields": [
{
"index": true,
"name": "text",
"store": true,
"type": "text"
}
]
}
}
}
}
},
"store": {
"indexType": "scorch",
"segmentVersion": 16
}
},
"sourceParams": { }
}
""")
index_definition["name"] = self._collection_name + "_search"
index_definition["uuid"] = uuid
index_definition["params"]["mapping"]["types"]["collection_name"]["properties"]["embedding"]["fields"][0][
"dims"
] = vector_length
index_definition["params"]["mapping"]["types"][self._scope_name + "." + self._collection_name] = (
index_definition["params"]["mapping"]["types"].pop("collection_name")
)
time.sleep(2)
index_manager.upsert_index(
SearchIndex(
index_definition["name"],
params=index_definition["params"],
source_name=self._bucket_name,
),
)
time.sleep(1)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _collection_exists(self, name: str):
scope_collection_map: dict[str, Any] = {}
# Get a list of all scopes in the bucket
for scope in self._bucket.collections().get_all_scopes():
scope_collection_map[scope.name] = []
# Get a list of all the collections in the scope
for collection in scope.collections:
scope_collection_map[scope.name].append(collection.name)
# Check if the collection exists in the scope
return self._collection_name in scope_collection_map[self._scope_name]
def get_type(self) -> str:
return VectorType.COUCHBASE
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
doc_ids = []
documents_to_insert = [
{"text": text, "embedding": vector, "metadata": metadata}
for _, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
]
for doc, id in zip(documents_to_insert, uuids):
_ = self._scope.collection(self._collection_name).upsert(id, doc)
doc_ids.extend(uuids)
return doc_ids
def text_exists(self, id: str) -> bool:
# Use a parameterized query for safety and correctness
query = f"""
SELECT COUNT(1) AS count FROM
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE META().id = $doc_id
"""
# Pass the id as a parameter to the query
result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
for row in result:
return bool(row["count"] > 0)
return False # Return False if no rows are returned
def delete_by_ids(self, ids: list[str]):
query = f"""
DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE META().id IN $doc_ids;
"""
try:
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception:
logger.exception("Failed to delete documents, ids: %s", ids)
def delete_by_document_id(self, document_id: str):
query = f"""
DELETE FROM
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE META().id = $doc_id;
"""
self._cluster.query(query, named_parameters={"doc_id": document_id}).execute()
# def get_ids_by_metadata_field(self, key: str, value: str):
# query = f"""
# SELECT id FROM
# `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
# WHERE `metadata.{key}` = $value;
# """
# result = self._cluster.query(query, named_parameters={'value':value})
# return [row['id'] for row in result.rows()]
def delete_by_metadata_field(self, key: str, value: str):
query = f"""
DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE metadata.{key} = $value;
"""
self._cluster.query(query, named_parameters={"value": value}).execute()
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") or 0.0
search_req = search.SearchRequest.create(
VectorSearch.from_vector_query(
VectorQuery(
"embedding",
query_vector,
top_k,
)
)
)
try:
search_iter = self._scope.search(
self._collection_name + "_search",
search_req,
SearchOptions(limit=top_k, collections=[self._collection_name], fields=["*"]),
)
docs = []
# Parse the results
for row in search_iter.rows():
text = row.fields.pop("text")
metadata = self._format_metadata(row.fields)
score = row.score
metadata["score"] = score
doc = Document(page_content=text, metadata=metadata)
if score >= score_threshold:
docs.append(doc)
except Exception as e:
raise ValueError(f"Search failed with error: {e}")
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # ty: ignore [too-many-positional-arguments]
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)
docs = []
for row in search_iter.rows():
text = row.fields.pop("text")
metadata = self._format_metadata(row.fields)
score = row.score
metadata["score"] = score
doc = Document(page_content=text, metadata=metadata)
docs.append(doc)
except Exception as e:
raise ValueError(f"Search failed with error: {e}")
return docs
def delete(self):
manager = self._bucket.collections()
scopes = manager.get_all_scopes()
for scope in scopes:
for collection in scope.collections:
if collection.name == self._collection_name:
manager.drop_collection("_default", self._collection_name)
def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]:
"""Helper method to format the metadata from the Couchbase Search API.
Args:
row_fields (Dict[str, Any]): The fields to format.
Returns:
Dict[str, Any]: The formatted metadata.
"""
metadata = {}
for key, value in row_fields.items():
# Couchbase Search returns the metadata key with a prefix
# `metadata.` We remove it to get the original metadata key
if key.startswith("metadata"):
new_key = key.split("metadata" + ".")[-1]
metadata[new_key] = value
else:
metadata[key] = value
return metadata
class CouchbaseVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.COUCHBASE, collection_name))
config = current_app.config
return CouchbaseVector(
collection_name=collection_name,
config=CouchbaseConfig(
connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""),
user=config.get("COUCHBASE_USER", ""),
password=config.get("COUCHBASE_PASSWORD", ""),
bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""),
scope_name=config.get("COUCHBASE_SCOPE_NAME", ""),
),
)

View File

@@ -0,0 +1,104 @@
import json
import logging
from typing import Any
from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class ElasticSearchJaVector(ElasticSearchVector):
def create_collection(
self,
embeddings: list[list[float]],
metadatas: list[dict[Any, Any]] | None = None,
index_params: dict | None = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info("Collection %s already exists.", self._collection_name)
return
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
settings = {
"analysis": {
"analyzer": {
"ja_analyzer": {
"type": "custom",
"char_filter": [
"icu_normalizer",
"kuromoji_iteration_mark",
],
"tokenizer": "kuromoji_tokenizer",
"filter": [
"kuromoji_baseform",
"kuromoji_part_of_speech",
"ja_stop",
"kuromoji_number",
"kuromoji_stemmer",
],
}
}
}
}
mappings = {
"properties": {
Field.CONTENT_KEY: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
},
},
}
}
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchJaVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)

View File

@@ -0,0 +1,357 @@
import json
import logging
import math
from typing import Any, cast
from urllib.parse import urlparse
from elasticsearch import ConnectionError as ElasticsearchConnectionError
from elasticsearch import Elasticsearch
from flask import current_app
from packaging.version import parse as parse_version
from pydantic import BaseModel, model_validator
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class ElasticSearchConfig(BaseModel):
# Regular Elasticsearch config
host: str | None = None
port: int | None = None
username: str | None = None
password: str | None = None
# Elastic Cloud specific config
cloud_url: str | None = None # Cloud URL for Elasticsearch Cloud
api_key: str | None = None
# Common config
use_cloud: bool = False
ca_certs: str | None = None
verify_certs: bool = False
request_timeout: int = 100000
retry_on_timeout: bool = True
max_retries: int = 10000
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
use_cloud = values.get("use_cloud", False)
cloud_url = values.get("cloud_url")
if use_cloud:
# Cloud configuration validation - requires cloud_url and api_key
if not cloud_url:
raise ValueError("cloud_url is required for Elastic Cloud")
api_key = values.get("api_key")
if not api_key:
raise ValueError("api_key is required for Elastic Cloud")
else:
# Regular Elasticsearch validation
if not values.get("host"):
raise ValueError("config HOST is required for regular Elasticsearch")
if not values.get("port"):
raise ValueError("config PORT is required for regular Elasticsearch")
if not values.get("username"):
raise ValueError("config USERNAME is required for regular Elasticsearch")
if not values.get("password"):
raise ValueError("config PASSWORD is required for regular Elasticsearch")
return values
class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
super().__init__(index_name.lower())
self._client = self._init_client(config)
self._version = self._get_version()
self._check_version()
self._attributes = attributes
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
"""
Initialize Elasticsearch client for both regular Elasticsearch and Elastic Cloud.
"""
try:
# Check if using Elastic Cloud
client_config: dict[str, Any]
if config.use_cloud and config.cloud_url:
client_config = {
"request_timeout": config.request_timeout,
"retry_on_timeout": config.retry_on_timeout,
"max_retries": config.max_retries,
"verify_certs": config.verify_certs,
}
# Parse cloud URL and configure hosts
parsed_url = urlparse(config.cloud_url)
host = f"{parsed_url.scheme}://{parsed_url.hostname}"
if parsed_url.port:
host += f":{parsed_url.port}"
client_config["hosts"] = [host]
# API key authentication for cloud
client_config["api_key"] = config.api_key
# SSL settings
if config.ca_certs:
client_config["ca_certs"] = config.ca_certs
else:
# Regular Elasticsearch configuration
parsed_url = urlparse(config.host or "")
if parsed_url.scheme in {"http", "https"}:
hosts = f"{config.host}:{config.port}"
use_https = parsed_url.scheme == "https"
else:
hosts = f"http://{config.host}:{config.port}"
use_https = False
client_config = {
"hosts": [hosts],
"basic_auth": (config.username, config.password),
"request_timeout": config.request_timeout,
"retry_on_timeout": config.retry_on_timeout,
"max_retries": config.max_retries,
}
# Only add SSL settings if using HTTPS
if use_https:
client_config["verify_certs"] = config.verify_certs
if config.ca_certs:
client_config["ca_certs"] = config.ca_certs
client = Elasticsearch(**client_config)
# Test connection
if not client.ping():
raise ConnectionError("Failed to connect to Elasticsearch")
except ElasticsearchConnectionError as e:
raise ConnectionError(f"Vector database connection error: {str(e)}")
except Exception as e:
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
return client
def _get_version(self) -> str:
info = self._client.info()
# remove any suffix like "-SNAPSHOT" from the version string
return cast(str, info["version"]["number"]).split("-")[0]
def _check_version(self):
if parse_version(self._version) < parse_version("8.0.0"):
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
def get_type(self) -> str:
return VectorType.ELASTICSEARCH
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
for i in range(len(documents)):
self._client.index(
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i] or None,
Field.METADATA_KEY: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
return uuids
def text_exists(self, id: str) -> bool:
return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]):
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)
def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)
def delete(self):
self._client.indices.delete(index=self._collection_name)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
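# Elasticsearch kNN requires num_candidates >= k; considering ~1.5x top_k candidates trades a little latency for better recall.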
num_candidates = math.ceil(top_k * 1.5)
knn = {"field": Field.VECTOR, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
docs_and_scores = []
for hit in results["hits"]["hits"]:
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
)
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query_str = {
"bool": {
"must": {"match": {Field.CONTENT_KEY: query}},
"filter": {"terms": {"metadata.document_id": document_ids_filter}},
}
}
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
)
)
return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(
self,
embeddings: list[list[float]],
metadatas: list[dict[Any, Any]] | None = None,
index_params: dict | None = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info("Collection %s already exists.", self._collection_name)
return
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
mappings = {
"properties": {
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type
"document_id": {"type": "keyword"}, # Map doc_id to keyword type
},
},
}
}
self._client.indices.create(index=self._collection_name, mappings=mappings)
logger.info("Created index %s with dimension %s", self._collection_name, dim)
else:
logger.info("Collection %s already exists.", self._collection_name)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class ElasticSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
# Check if ELASTICSEARCH_USE_CLOUD is explicitly set to false (boolean)
use_cloud_env = config.get("ELASTICSEARCH_USE_CLOUD", False)
if use_cloud_env is False:
# Use regular Elasticsearch with config values
config_dict = {
"use_cloud": False,
"host": config.get("ELASTICSEARCH_HOST", "elasticsearch"),
"port": config.get("ELASTICSEARCH_PORT", 9200),
"username": config.get("ELASTICSEARCH_USERNAME", "elastic"),
"password": config.get("ELASTICSEARCH_PASSWORD", "elastic"),
}
else:
# Check for cloud configuration
cloud_url = config.get("ELASTICSEARCH_CLOUD_URL")
if cloud_url:
config_dict = {
"use_cloud": True,
"cloud_url": cloud_url,
"api_key": config.get("ELASTICSEARCH_API_KEY"),
}
else:
# Fallback to regular Elasticsearch
config_dict = {
"use_cloud": False,
"host": config.get("ELASTICSEARCH_HOST", "localhost"),
"port": config.get("ELASTICSEARCH_PORT", 9200),
"username": config.get("ELASTICSEARCH_USERNAME", "elastic"),
"password": config.get("ELASTICSEARCH_PASSWORD", ""),
}
# Common configuration
config_dict.update(
{
"ca_certs": str(config.get("ELASTICSEARCH_CA_CERTS")) if config.get("ELASTICSEARCH_CA_CERTS") else None,
"verify_certs": bool(config.get("ELASTICSEARCH_VERIFY_CERTS", False)),
"request_timeout": int(config.get("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
"retry_on_timeout": bool(config.get("ELASTICSEARCH_RETRY_ON_TIMEOUT", True)),
"max_retries": int(config.get("ELASTICSEARCH_MAX_RETRIES", 10000)),
}
)
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(**config_dict),
attributes=[],
)

View File

@@ -0,0 +1,14 @@
from enum import StrEnum, auto
class Field(StrEnum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = auto()
# Sparse Vector aims to support full text search
SPARSE_VECTOR = auto()
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"
DOCUMENT_ID = "metadata.document_id"

View File

@@ -0,0 +1,215 @@
import json
import logging
import ssl
from typing import Any
from elasticsearch import Elasticsearch
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
def create_ssl_context() -> ssl.SSLContext:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
return ssl_context
class HuaweiCloudVectorConfig(BaseModel):
hosts: str
username: str | None = None
password: str | None = None
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["hosts"]:
raise ValueError("config HOSTS is required")
return values
def to_elasticsearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts.split(","),
"verify_certs": False,
"ssl_show_warn": False,
"request_timeout": 30000,
"retry_on_timeout": True,
"max_retries": 10,
}
if self.username and self.password:
params["basic_auth"] = (self.username, self.password)
return params
class HuaweiCloudVector(BaseVector):
def __init__(self, index_name: str, config: HuaweiCloudVectorConfig):
super().__init__(index_name.lower())
self._client = Elasticsearch(**config.to_elasticsearch_params())
def get_type(self) -> str:
return VectorType.HUAWEI_CLOUD
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
for i in range(len(documents)):
self._client.index(
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i] or None,
Field.METADATA_KEY: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
return uuids
def text_exists(self, id: str) -> bool:
return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]):
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)
def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)
def delete(self):
self._client.indices.delete(index=self._collection_name)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
query = {
"size": top_k,
"query": {
"vector": {
Field.VECTOR: {
"vector": query_vector,
"topk": top_k,
}
}
},
}
results = self._client.search(index=self._collection_name, body=query)
docs_and_scores = []
for hit in results["hits"]["hits"]:
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
)
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {"match": {Field.CONTENT_KEY: query}}
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
)
)
return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(
self,
embeddings: list[list[float]],
metadatas: list[dict[Any, Any]] | None = None,
index_params: dict | None = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info("Collection %s already exists.", self._collection_name)
return
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
mappings = {
"properties": {
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: { # Make sure the dimension is correct here
"type": "vector",
"dimension": dim,
"indexing": True,
"algorithm": "GRAPH",
"metric": "cosine",
"neighbors": 32,
"efc": 128,
},
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
},
},
}
}
settings = {"index.vector": True}
self._client.indices.create(index=self._collection_name, mappings=mappings, settings=settings)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class HuaweiCloudVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HuaweiCloudVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.HUAWEI_CLOUD, collection_name))
return HuaweiCloudVector(
index_name=collection_name,
config=HuaweiCloudVectorConfig(
hosts=dify_config.HUAWEI_CLOUD_HOSTS or "http://localhost:9200",
username=dify_config.HUAWEI_CLOUD_USER,
password=dify_config.HUAWEI_CLOUD_PASSWORD,
),
)
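A minimal usage sketch of the class above (the import path, endpoint, and all values are illustrative assumptions; it presumes a reachable CSS/Elasticsearch-compatible cluster and a configured Dify environment):

from core.rag.models.document import Document
# Hypothetical module path for the file above:
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig

config = HuaweiCloudVectorConfig(hosts="http://localhost:9200", username="admin", password="***")
store = HuaweiCloudVector(index_name="vector_index_demo", config=config)

docs = [Document(page_content="hello huawei cloud", metadata={"doc_id": "d1", "document_id": "doc-1"})]
store.create(docs, embeddings=[[0.1, 0.2, 0.3, 0.4]])    # builds the GRAPH/cosine index, then indexes the docs
hits = store.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=4, score_threshold=0.5)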

View File

@@ -0,0 +1,413 @@
import json
import logging
import time
from typing import Any
from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_exponential
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("lindorm").setLevel(logging.WARN)
ROUTING_FIELD = "routing_field"
UGC_INDEX_PREFIX = "ugc_index"
class LindormVectorStoreConfig(BaseModel):
hosts: str | None
username: str | None = None
password: str | None = None
using_ugc: bool | None = False
request_timeout: float | None = 1.0 # timeout units: s
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["hosts"]:
raise ValueError("config URL is required")
if not values["username"]:
raise ValueError("config USERNAME is required")
if not values["password"]:
raise ValueError("config PASSWORD is required")
return values
def to_opensearch_params(self) -> dict[str, Any]:
params: dict[str, Any] = {
"hosts": self.hosts,
"use_ssl": False,
"pool_maxsize": 128,
"timeout": 30,
}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
class LindormVectorStore(BaseVector):
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using_ugc: bool, **kwargs):
self._routing: str | None = None
if using_ugc:
routing_value: str | None = kwargs.get("routing_value")
if routing_value is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
self._routing = routing_value.lower()
super().__init__(collection_name.lower())
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())
self._using_ugc = using_ugc
self.kwargs = kwargs
def get_type(self) -> str:
return VectorType.LINDORM
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings)
def refresh(self):
self._client.indices.refresh(index=self._collection_name)
def add_texts(
self,
documents: list[Document],
embeddings: list[list[float]],
batch_size: int = 64,
timeout: int = 60,
**kwargs,
):
logger.info("Total documents to add: %s", len(documents))
uuids = self._get_uuids(documents)
total_docs = len(documents)
num_batches = (total_docs + batch_size - 1) // batch_size
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
)
def _bulk_with_retry(actions):
try:
response = self._client.bulk(actions, timeout=timeout)
if response["errors"]:
error_items = [item for item in response["items"] if "error" in item["index"]]
error_msg = f"Bulk indexing had {len(error_items)} errors"
logger.exception(error_msg)
raise Exception(error_msg)
return response
except Exception:
logger.exception("Bulk indexing error")
raise
for batch_num in range(num_batches):
start_idx = batch_num * batch_size
end_idx = min((batch_num + 1) * batch_size, total_docs)
actions = []
for i in range(start_idx, end_idx):
action_header = {
"index": {
"_index": self.collection_name,
"_id": uuids[i],
}
}
action_values: dict[str, Any] = {
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
action_values[ROUTING_FIELD] = self._routing
actions.append(action_header)
actions.append(action_values)
try:
_bulk_with_retry(actions)
# logger.info(f"Successfully processed batch {batch_num + 1}")
# brief pause between batches to avoid sending too many requests in a short time
if batch_num < num_batches - 1:
time.sleep(0.5)
except Exception:
logger.exception("Failed to process batch %s", batch_num + 1)
raise
def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
response = self._client.search(index=self._collection_name, body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_ids(ids)
def delete_by_ids(self, ids: list[str]):
"""Delete documents by their IDs in batch.
Args:
ids: List of document IDs to delete
"""
if not ids:
return
params = {"routing": self._routing} if self._using_ugc else {}
# 1. First check if collection exists
if not self._client.indices.exists(index=self._collection_name):
logger.warning("Collection %s does not exist", self._collection_name)
return
# 2. Batch process deletions
actions = []
for id in ids:
if self._client.exists(index=self._collection_name, id=id, params=params):
actions.append(
{
"_op_type": "delete",
"_index": self._collection_name,
"_id": id,
**params, # Include routing if using UGC
}
)
else:
logger.warning("DELETE BY ID: ID %s does not exist in the index.", id)
# 3. Perform bulk deletion if there are valid documents to delete
if actions:
try:
helpers.bulk(self._client, actions)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get("delete", {})
status = delete_error.get("status")
doc_id = delete_error.get("_id")
if status == 404:
logger.warning("Document not found for deletion: %s", doc_id)
else:
logger.exception("Error deleting document: %s", error)
def delete(self):
if self._using_ugc:
routing_filter_query = {
"query": {"bool": {"must": [{"term": {f"{ROUTING_FIELD}.keyword": self._routing}}]}}
}
self._client.delete_by_query(self._collection_name, body=routing_filter_query)
self.refresh()
else:
if self._client.indices.exists(index=self._collection_name):
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
logger.info("Delete index success")
else:
logger.warning("Index '%s' does not exist. No deletion performed.", self._collection_name)
def text_exists(self, id: str) -> bool:
try:
params: dict[str, Any] = {}
if self._using_ugc:
params["routing"] = self._routing
self._client.get(index=self._collection_name, id=id, params=params)
return True
except Exception:
return False
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
if not isinstance(query_vector, list):
raise ValueError("query_vector should be a list of floats")
if not all(isinstance(x, float) for x in query_vector):
raise ValueError("All elements in query_vector should be floats")
filters = []
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
if self._using_ugc:
filters.append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
top_k = kwargs.get("top_k", 5)
search_query: dict[str, Any] = {
"size": top_k,
"_source": True,
"query": {"knn": {Field.VECTOR: {"vector": query_vector, "k": top_k}}},
}
final_ext: dict[str, Any] = {"lvector": {}}
if filters is not None and len(filters) > 0:
# when filters are present, convert the list of filter dicts into the single dict format Lindorm expects
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][Field.VECTOR]["filter"] = filter_dict # filter should be Dict
final_ext["lvector"]["filter_type"] = "pre_filter"
if final_ext != {"lvector": {}}:
search_query["ext"] = final_ext
try:
params = {"timeout": self._client_config.request_timeout}
if self._using_ugc:
params["routing"] = self._routing # type: ignore
response = self._client.search(index=self._collection_name, body=search_query, params=params)
except Exception:
logger.exception("Error executing vector search, query: %s", search_query)
raise
docs_and_scores = []
for hit in response["hits"]["hits"]:
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
)
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}}
filters = []
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
if self._using_ugc:
filters.append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
if filters:
full_text_query["query"]["bool"]["filter"] = filters
try:
params: dict[str, Any] = {"timeout": self._client_config.request_timeout}
if self._using_ugc:
params["routing"] = self._routing
response = self._client.search(index=self._collection_name, body=full_text_query, params=params)
except Exception:
logger.exception("Error executing vector search, query: %s", full_text_query)
raise
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY)
vector = hit["_source"].get(Field.VECTOR)
page_content = hit["_source"].get(Field.CONTENT_KEY)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
return docs
def create_collection(
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
):
if not embeddings:
raise ValueError(f"Embeddings list cannot be empty for collection create '{self._collection_name}'")
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info("Collection %s already exists.", self._collection_name)
return
if not self._client.indices.exists(index=self._collection_name):
index_body = {
"settings": {"index": {"knn": True, "knn_routing": self._using_ugc}},
"mappings": {
"properties": {
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: {
"type": "knn_vector",
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
"method": {
"name": index_params.get("index_type", "hnsw")
if index_params
else dify_config.LINDORM_INDEX_TYPE,
"space_type": index_params.get("space_type", "l2")
if index_params
else dify_config.LINDORM_DISTANCE_TYPE,
"engine": "lvector",
},
},
}
},
}
logger.info("Creating Lindorm Search index %s", self._collection_name)
self._client.indices.create(index=self._collection_name, body=index_body)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL,
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.LINDORM_USING_UGC,
request_timeout=dify_config.LINDORM_QUERY_TIMEOUT,
)
using_ugc = dify_config.LINDORM_USING_UGC
if using_ugc is None:
raise ValueError("LINDORM_USING_UGC is not set")
routing_value = None
if dataset.index_struct:
# if an existing record's index_struct_dict doesn't contain the using_ugc field,
# its data is actually stored in the normal (non-UGC) index format
stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False)
using_ugc = stored_in_ugc
if stored_in_ugc:
dimension = dataset.index_struct_dict["dimension"]
index_type = dataset.index_struct_dict["index_type"]
distance_type = dataset.index_struct_dict["distance_type"]
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}".lower()
else:
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"].lower()
else:
embedding_vector = embeddings.embed_query("hello world")
dimension = len(embedding_vector)
class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
index_struct_dict = {
"type": VectorType.LINDORM,
"vector_store": {"class_prefix": class_prefix},
"index_type": dify_config.LINDORM_INDEX_TYPE,
"dimension": dimension,
"distance_type": dify_config.LINDORM_DISTANCE_TYPE,
"using_ugc": using_ugc,
}
dataset.index_struct = json.dumps(index_struct_dict)
if using_ugc:
index_type = dify_config.LINDORM_INDEX_TYPE
distance_type = dify_config.LINDORM_DISTANCE_TYPE
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}".lower()
routing_value = class_prefix.lower()
else:
index_name = class_prefix.lower()
return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value, using_ugc=using_ugc)
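The UGC branch above packs many datasets into one shared physical index and separates them by routing. A worked example of the naming it produces, assuming LINDORM_INDEX_TYPE=hnsw, LINDORM_DISTANCE_TYPE=l2, an embedding dimension of 1024, and an illustrative class prefix:

UGC_INDEX_PREFIX = "ugc_index"
class_prefix = "Vector_index_6a9d07a2_1234_5678_Node"   # hypothetical Dataset.gen_collection_name_by_id output
dimension, index_type, distance_type = 1024, "hnsw", "l2"

index_struct = {
    "type": "lindorm",                                   # assumed string value of VectorType.LINDORM
    "vector_store": {"class_prefix": class_prefix},
    "index_type": index_type,
    "dimension": dimension,
    "distance_type": distance_type,
    "using_ugc": True,
}

index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}".lower()
routing_value = class_prefix.lower()
print(index_name)      # ugc_index_1024_hnsw_l2 -> shared by every dataset with the same settings
print(routing_value)   # vector_index_6a9d07a2_1234_5678_node -> per-dataset routing key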

View File

@@ -0,0 +1,239 @@
import json
import logging
import uuid
from collections.abc import Callable
from functools import wraps
from typing import Any, Concatenate, ParamSpec, TypeVar
from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T", bound="MatrixoneVector")
def ensure_client(func: Callable[Concatenate[T, P], R]):
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneConfig(BaseModel):
host: str = "localhost"
port: int = 6001
user: str = "dump"
password: str = "111"
database: str = "dify"
metric: str = "l2"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config host is required")
if not values["port"]:
raise ValueError("config port is required")
if not values["user"]:
raise ValueError("config user is required")
if not values["password"]:
raise ValueError("config password is required")
if not values["database"]:
raise ValueError("config database is required")
return values
class MatrixoneVector(BaseVector):
"""
Matrixone vector storage implementation.
"""
def __init__(self, collection_name: str, config: MatrixoneConfig):
super().__init__(collection_name)
self.config = config
self.collection_name = collection_name.lower()
self.client = None
@property
def collection_name(self):
return self._collection_name
@collection_name.setter
def collection_name(self, value):
self._collection_name = value
def get_type(self) -> str:
return VectorType.MATRIXONE
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if self.client is None:
self.client = self._get_client(len(embeddings[0]), True)
return self.add_texts(texts, embeddings)
def _get_client(self, dimension: int | None = None, create_table: bool = False) -> MoVectorClient:
"""
Create a new client for the collection.
The collection will be created if it doesn't exist.
"""
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
client = MoVectorClient(
connection_string=f"mysql+pymysql://{self.config.user}:{self.config.password}@{self.config.host}:{self.config.port}/{self.config.database}",
table_name=self.collection_name,
vector_dimension=dimension,
create_table=create_table,
)
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return client
try:
client.create_full_text_index()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
except Exception:
logger.exception("Failed to create full text index")
return client
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
if self.client is None:
self.client = self._get_client(len(embeddings[0]), True)
assert self.client is not None
ids = []
for doc in documents:
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
ids.append(doc_id)
self.client.insert(
texts=[doc.page_content for doc in documents],
embeddings=embeddings,
metadatas=[doc.metadata for doc in documents],
ids=ids,
)
return ids
@ensure_client
def text_exists(self, id: str) -> bool:
assert self.client is not None
result = self.client.get(ids=[id])
return len(result) > 0
@ensure_client
def delete_by_ids(self, ids: list[str]):
assert self.client is not None
if not ids:
return
self.client.delete(ids=ids)
@ensure_client
def get_ids_by_metadata_field(self, key: str, value: str):
assert self.client is not None
results = self.client.query_by_metadata(filter={key: value})
return [result.id for result in results]
@ensure_client
def delete_by_metadata_field(self, key: str, value: str):
assert self.client is not None
self.client.delete(filter={key: value})
@ensure_client
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
assert self.client is not None
top_k = kwargs.get("top_k", 5)
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = {"document_id": {"$in": document_ids_filter}}
results = self.client.query(
query_vector=query_vector,
k=top_k,
filter=filter,
)
docs = []
# TODO: add the score threshold to the query
for result in results:
metadata = result.metadata
docs.append(
Document(
page_content=result.document,
metadata=metadata,
)
)
return docs
@ensure_client
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
assert self.client is not None
top_k = kwargs.get("top_k", 5)
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = {"document_id": {"$in": document_ids_filter}}
score_threshold = float(kwargs.get("score_threshold") or 0.0)
results = self.client.full_text_query(
keywords=[query],
k=top_k,
filter=filter,
)
docs = []
for result in results:
metadata = result.metadata
if isinstance(metadata, str):
import json
metadata = json.loads(metadata)
score = 1 - result.distance
if score >= score_threshold:
metadata["score"] = score
docs.append(
Document(
page_content=result.document,
metadata=metadata,
)
)
return docs
@ensure_client
def delete(self):
assert self.client is not None
self.client.delete()
class MatrixoneVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MATRIXONE, collection_name))
config = MatrixoneConfig(
host=dify_config.MATRIXONE_HOST or "localhost",
port=dify_config.MATRIXONE_PORT or 6001,
user=dify_config.MATRIXONE_USER or "dump",
password=dify_config.MATRIXONE_PASSWORD or "111",
database=dify_config.MATRIXONE_DATABASE or "dify",
metric=dify_config.MATRIXONE_METRIC or "l2",
)
return MatrixoneVector(collection_name=collection_name, config=config)
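The ensure_client decorator above defers opening the MoVectorClient connection until a method actually needs it. A stripped-down sketch of the same lazy-initialization pattern (the class and names here are illustrative, not from the original file):

from functools import wraps

def ensure_client(func):
    """Create the client on first use, so constructing the store never opens a connection."""
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if self.client is None:
            self.client = self._get_client()
        return func(self, *args, **kwargs)
    return wrapper

class LazyStore:
    def __init__(self):
        self.client = None              # nothing is connected yet

    def _get_client(self):
        print("connecting once")
        return object()                 # stand-in for MoVectorClient(...)

    @ensure_client
    def text_exists(self, id: str) -> bool:
        return self.client is not None  # the client is guaranteed to exist here

store = LazyStore()
store.text_exists("d1")                 # prints "connecting once"
store.text_exists("d2")                 # reuses the existing client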

View File

@@ -0,0 +1,416 @@
import json
import logging
from typing import Any
from packaging import version
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
"""
Configuration class for Milvus connection.
"""
uri: str # Milvus server URI
token: str | None = None # Optional token for authentication
user: str | None = None # Username for authentication
password: str | None = None # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
analyzer_params: str | None = None # Analyzer params
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("token"):
if not values.get("user"):
raise ValueError("config MILVUS_USER is required")
if not values.get("password"):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
"uri": self.uri,
"token": self.token,
"user": self.user,
"password": self.password,
"db_name": self.database,
"analyzer_params": self.analyzer_params,
}
class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] # List of fields in the collection
if self._client.has_collection(collection_name):
self._load_collection_fields()
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def _load_collection_fields(self, fields: list[str] | None = None):
if fields is None:
# Load collection fields from remote server
collection_info = self._client.describe_collection(self._collection_name)
fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False
try:
milvus_version = self._client.get_server_version()
# Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility
if "Zilliz Cloud" in milvus_version:
return True
# For standard Milvus installations, check version number
return version.parse(milvus_version) >= version.parse("2.5.0")
except Exception as e:
logger.warning("Failed to check Milvus version: %s. Disabling hybrid search.", str(e))
return False
def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
"""
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
"""
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
# No need to insert the sparse_vector field separately; the text_bm25_emb
# function automatically converts the raw text into a sparse vector for us.
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
}
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
# Insert into the collection.
batch_insert_list = insert_dict_list[i : i + 1000]
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.exception("Failed to insert batch starting at entity: %s/%s", i, total_count)
raise e
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
"""
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
if result:
return [item["id"] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
"""
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, ids: list[str]):
"""
Delete documents by their IDs.
"""
if self._client.has_collection(self._collection_name):
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
)
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self):
"""
Delete the entire collection.
"""
if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
"""
if not self._client.has_collection(self._collection_name):
return False
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"]
)
return len(result) > 0
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
"""
return field in self._fields
def _process_search_results(
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
) -> list[Document]:
"""
Common method to process search results
:param results: Search results
:param output_fields: Fields to be output
:param score_threshold: Score threshold for filtering
:return: List of documents
"""
docs = []
for result in results[0]:
metadata = result["entity"].get(output_fields[1], {})
metadata["score"] = result["distance"]
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
docs.append(doc)
return docs
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search for documents by vector similarity.
"""
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f'"{id}"' for id in document_ids_filter)
filter = f'metadata["document_id"] in [{document_ids}]'
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
filter=filter,
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""
Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled:
logger.warning(
"Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)."
)
return []
if not self.field_exists(Field.SPARSE_VECTOR):
logger.warning(
"Full-text search unavailable: collection missing 'sparse_vector' field; "
"recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index."
)
return []
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f'metadata["document_id"] in [{document_ids}]'
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
filter=filter,
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def create_collection(
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
):
"""
Create a new collection in Milvus with the specified schema and index parameters.
"""
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
# Create the collection only if it does not already exist
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim
dim = len(embeddings[0])
fields = []
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY, DataType.JSON, max_length=65_535))
# Create the text field; enable_analyzer is set to True so that Milvus can automatically
# convert the text into a sparse vector. Reference: https://milvus.io/docs/full-text-search.md
content_field_kwargs: dict[str, Any] = {
"max_length": 65_535,
"enable_analyzer": self._hybrid_search_enabled,
}
if (
self._hybrid_search_enabled
and self._client_config.analyzer_params is not None
and self._client_config.analyzer_params.strip()
):
content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params
fields.append(FieldSchema(Field.CONTENT_KEY, DataType.VARCHAR, **content_field_kwargs))
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR, DataType.SPARSE_FLOAT_VECTOR))
schema = CollectionSchema(fields)
# Create custom function to support text to sparse vector by BM25
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY],
output_field_names=[Field.SPARSE_VECTOR],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
self._load_collection_fields([f.name for f in schema.fields])
# Create Index params for the collection
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection
self._client.create_collection(
collection_name=self._collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config: MilvusConfig) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
if config.token:
client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
else:
client = MilvusClient(
uri=config.uri,
user=config.user or "",
password=config.password or "",
db_name=config.database,
)
return client
class MilvusVectorFactory(AbstractVectorFactory):
"""
Factory class for creating MilvusVector instances.
"""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
uri=dify_config.MILVUS_URI or "",
token=dify_config.MILVUS_TOKEN or "",
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "",
),
)
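For reference, a sketch of the result shape _process_search_results expects from MilvusClient.search: one inner list per query vector, each hit carrying a distance and the requested output fields under "entity" (field contents below are illustrative, and the hit layout is assumed to follow the pymilvus MilvusClient):

results = [[
    {"id": 1, "distance": 0.83, "entity": {"page_content": "chunk text", "metadata": {"doc_id": "d1"}}},
    {"id": 2, "distance": 0.42, "entity": {"page_content": "other chunk", "metadata": {"doc_id": "d2"}}},
]]
output_fields = ["page_content", "metadata"]

# Mirrors the loop above: results[0] is iterated, entity[output_fields[0]] becomes the page content,
# entity[output_fields[1]] the metadata (plus a "score"), and hits at or below score_threshold are dropped.
kept = [hit for hit in results[0] if hit["distance"] > 0.5]
assert len(kept) == 1 and kept[0]["entity"]["page_content"] == "chunk text"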

View File

@@ -0,0 +1,183 @@
import json
import logging
import uuid
from enum import StrEnum
from typing import Any
from clickhouse_connect import get_client
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class MyScaleConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
fts_params: str
class SortOrder(StrEnum):
ASC = "ASC"
DESC = "DESC"
class MyScaleVector(BaseVector):
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
super().__init__(collection_name)
self._config = config
self._metric = metric
self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC
self._client = get_client(
host=config.host,
port=config.port,
username=config.user,
password=config.password,
)
self._client.command("SET allow_experimental_object_type=1")
def get_type(self) -> str:
return VectorType.MYSCALE
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(documents=texts, embeddings=embeddings, **kwargs)
def _create_collection(self, dimension: int):
logger.info("create MyScale collection %s with dimension %s", self._collection_name, dimension)
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
sql = f"""
CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
id String,
text String,
vector Array(Float32),
metadata JSON,
CONSTRAINT cons_vec_len CHECK length(vector) = {dimension},
VECTOR INDEX vidx vector TYPE DEFAULT('metric_type = {self._metric}'),
INDEX text_idx text TYPE fts{fts_params}
) ENGINE = MergeTree ORDER BY id
"""
self._client.command(sql)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = []
columns = ["id", "text", "vector", "metadata"]
values = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
row = (
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else "{}",
)
values.append(str(row))
ids.append(doc_id)
sql = f"""
INSERT INTO {self._config.database}.{self._collection_name}
({",".join(columns)}) VALUES {",".join(values)}
"""
self._client.command(sql)
return ids
@staticmethod
def escape_str(value: Any) -> str:
return "".join(" " if c in {"\\", "'"} else c for c in str(value))
def text_exists(self, id: str) -> bool:
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
return results.row_count > 0
def delete_by_ids(self, ids: list[str]):
if not ids:
return
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)
def get_ids_by_metadata_field(self, key: str, value: str):
rows = self._client.query(
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
).result_rows
return [row[0] for row in rows]
def delete_by_metadata_field(self, key: str, value: str):
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
score_threshold = float(kwargs.get("score_threshold") or 0.0)
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
else ""
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
sql = f"""
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
"""
try:
return [
Document(
page_content=r["text"],
vector=r["vector"],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()
]
except Exception:
logger.exception("Vector search operation failed")
return []
def delete(self):
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
class MyScaleVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
return MyScaleVector(
collection_name=collection_name,
config=MyScaleConfig(
host=dify_config.MYSCALE_HOST,
port=dify_config.MYSCALE_PORT,
user=dify_config.MYSCALE_USER,
password=dify_config.MYSCALE_PASSWORD,
database=dify_config.MYSCALE_DATABASE,
fts_params=dify_config.MYSCALE_FTS_PARAMS,
),
)
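A worked example of the SQL that _search builds for a cosine vector query (query vector truncated, table names illustrative); because MyScale's cosine distance is 1 - similarity, a similarity threshold of score_threshold becomes the predicate dist < 1 - score_threshold:

database, collection_name = "dify", "vector_index_demo"
dist = "distance(vector, [0.1, 0.2, 0.3])"
score_threshold = 0.4
where_str = f"WHERE dist < {1 - score_threshold}"   # keeps rows with similarity > 0.4

sql = f"""
SELECT text, vector, metadata, {dist} as dist FROM {database}.{collection_name}
{where_str} ORDER BY dist ASC LIMIT 4
"""
print(sql)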

View File

@@ -0,0 +1,326 @@
import json
import logging
import math
from typing import Any
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient, l2_distance # type: ignore
from sqlalchemy import JSON, Column, String
from sqlalchemy.dialects.mysql import LONGTEXT
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"
class OceanBaseVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
enable_hybrid_search: bool = False
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config OCEANBASE_VECTOR_HOST is required")
if not values["port"]:
raise ValueError("config OCEANBASE_VECTOR_PORT is required")
if not values["user"]:
raise ValueError("config OCEANBASE_VECTOR_USER is required")
if not values["database"]:
raise ValueError("config OCEANBASE_VECTOR_DATABASE is required")
return values
class OceanBaseVector(BaseVector):
def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
super().__init__(collection_name)
self._config = config
self._hnsw_ef_search = -1
self._client = ObVecClient(
uri=f"{self._config.host}:{self._config.port}",
user=self._config.user,
password=self._config.password,
db_name=self._config.database,
)
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def get_type(self) -> str:
return VectorType.OCEANBASE
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._vec_dim = len(embeddings[0])
self._create_collection()
self.add_texts(texts, embeddings)
def _create_collection(self):
lock_name = "vector_indexing_lock_" + self._collection_name
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_" + self._collection_name
if redis_client.get(collection_exist_cache_key):
return
if self._client.check_table_exists(self._collection_name):
return
self.delete()
vals = []
params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
for row in params:
val = int(row[6])
vals.append(val)
if len(vals) == 0:
raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
if any(val == 0 for val in vals):
try:
self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
except Exception as e:
raise Exception(
"Failed to set ob_vector_memory_limit_percentage. "
+ "Maybe the database user has insufficient privilege.",
e,
)
cols = [
Column("id", String(36), primary_key=True, autoincrement=False),
Column("vector", VECTOR(self._vec_dim)),
Column("text", LONGTEXT),
Column("metadata", JSON),
]
vidx_params = self._client.prepare_index_params()
vidx_params.add_index(
field_name="vector",
index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
index_name="vector_index",
metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
)
self._client.create_table_with_index_params(
table_name=self._collection_name,
columns=cols,
vidxs=vidx_params,
)
logger.debug("DEBUG: Table '%s' created successfully", self._collection_name)
if self._hybrid_search_enabled:
# Get parser from config or use default ik parser
parser_name = dify_config.OCEANBASE_FULLTEXT_PARSER or "ik"
allowed_parsers = ["ngram", "beng", "space", "ngram2", "ik", "japanese_ftparser", "thai_ftparser"]
if parser_name not in allowed_parsers:
raise ValueError(
f"Invalid OceanBase full-text parser: {parser_name}. "
f"Allowed values are: {', '.join(allowed_parsers)}"
)
logger.debug("Hybrid search is enabled, parser_name='%s'", parser_name)
logger.debug(
"About to create fulltext index for collection '%s' using parser '%s'",
self._collection_name,
parser_name,
)
try:
sql_command = f"""ALTER TABLE {self._collection_name}
ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER {parser_name}"""
logger.debug("DEBUG: Executing SQL: %s", sql_command)
self._client.perform_raw_text_sql(sql_command)
logger.debug("DEBUG: Fulltext index created successfully for '%s'", self._collection_name)
except Exception as e:
logger.exception("Exception occurred while creating fulltext index")
raise Exception(
"Failed to add fulltext index to the target table, your OceanBase version must be "
"4.3.5.1 or above to support fulltext index and vector index in the same table"
) from e
else:
logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
self._client.refresh_metadata([self._collection_name])
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _check_hybrid_search_support(self) -> bool:
"""
Check if the current OceanBase version supports hybrid search.
Returns True if the version is >= 4.3.5.1, otherwise False.
"""
if not self._config.enable_hybrid_search:
return False
try:
from packaging import version
# Example result: OceanBase_CE 4.3.5.1 (r101000042025031818-bxxxx) (Built Mar 18 2025 18:13:36)
result = self._client.perform_raw_text_sql("SELECT @@version_comment AS version")
ob_full_version = result.fetchone()[0]
ob_version = ob_full_version.split()[1]
logger.debug("Current OceanBase version is %s", ob_version)
return version.parse(ob_version) >= version.parse("4.3.5.1")
except Exception as e:
logger.warning("Failed to check OceanBase version: %s. Disabling hybrid search.", str(e))
return False
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = self._get_uuids(documents)
for id, doc, emb in zip(ids, documents, embeddings):
self._client.insert(
table_name=self._collection_name,
data={
"id": id,
"vector": emb,
"text": doc.page_content,
"metadata": doc.metadata,
},
)
def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, ids=id)
return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]):
if not ids:
return
self._client.delete(table_name=self._collection_name, ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
from sqlalchemy import text
cur = self._client.get(
table_name=self._collection_name,
ids=None,
where_clause=[text(f"metadata->>'$.{key}' = '{value}'")],
output_column_name=["id"],
)
return [row[0] for row in cur]
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
if not self._hybrid_search_enabled:
return []
try:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'$.document_id' IN ({document_ids})"
full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score
FROM {self._collection_name}
WHERE MATCH (text) AGAINST (:query) > 0
{where_clause}
ORDER BY score DESC
LIMIT {top_k}"""
with self._client.engine.connect() as conn:
with conn.begin():
from sqlalchemy import text
result = conn.execute(text(full_sql), {"query": query})
rows = result.fetchall()
docs = []
for row in rows:
metadata_str, _text, score = row
try:
metadata = json.loads(metadata_str)
except json.JSONDecodeError:
logger.warning("Invalid JSON metadata: %s", metadata_str)
metadata = {}
metadata["score"] = score
docs.append(Document(page_content=_text, metadata=metadata))
return docs
except Exception as e:
logger.warning("Failed to fulltext search: %s.", str(e))
return []
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
_where_clause = None
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
from sqlalchemy import text
_where_clause = [text(where_clause)]
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
if ef_search != self._hnsw_ef_search:
self._client.set_ob_hnsw_ef_search(ef_search)
self._hnsw_ef_search = ef_search
topk = kwargs.get("top_k", 10)
try:
cur = self._client.ann_search(
table_name=self._collection_name,
vec_column_name="vector",
vec_data=query_vector,
topk=topk,
distance_func=l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
where_clause=_where_clause,
)
except Exception as e:
raise Exception("Failed to search by vector. ", e)
docs = []
for _text, metadata, distance in cur:
metadata = json.loads(metadata)
metadata["score"] = 1 - distance / math.sqrt(2)
docs.append(
Document(
page_content=_text,
metadata=metadata,
)
)
return docs
def delete(self):
self._client.drop_table_if_exist(self._collection_name)
class OceanBaseVectorFactory(AbstractVectorFactory):
def init_vector(
self,
dataset: Dataset,
attributes: list,
embeddings: Embeddings,
) -> BaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OCEANBASE, collection_name))
return OceanBaseVector(
collection_name,
OceanBaseVectorConfig(
host=dify_config.OCEANBASE_VECTOR_HOST or "",
port=dify_config.OCEANBASE_VECTOR_PORT or 0,
user=dify_config.OCEANBASE_VECTOR_USER or "",
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
database=dify_config.OCEANBASE_VECTOR_DATABASE or "",
enable_hybrid_search=dify_config.OCEANBASE_ENABLE_HYBRID_SEARCH or False,
),
)
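The vector scores returned above are derived from L2 distance as 1 - distance / sqrt(2). A quick check of that mapping, assuming unit-normalized embeddings (identical vectors score 1.0, orthogonal vectors 0.0, opposite vectors go negative):

import math

def l2_to_score(distance: float) -> float:
    # Same formula as search_by_vector above.
    return 1 - distance / math.sqrt(2)

print(l2_to_score(0.0))            # 1.0  -> identical vectors
print(l2_to_score(math.sqrt(2)))   # 0.0  -> orthogonal unit vectors
print(l2_to_score(2.0))            # about -0.41 -> opposite unit vectors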

View File

@@ -0,0 +1,264 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class OpenGaussConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
min_connection: int
max_connection: int
enable_pq: bool = False # Enable PQ acceleration
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config OPENGAUSS_HOST is required")
if not values["port"]:
raise ValueError("config OPENGAUSS_PORT is required")
if not values["user"]:
raise ValueError("config OPENGAUSS_USER is required")
if not values["password"]:
raise ValueError("config OPENGAUSS_PASSWORD is required")
if not values["database"]:
raise ValueError("config OPENGAUSS_DATABASE is required")
if not values["min_connection"]:
raise ValueError("config OPENGAUSS_MIN_CONNECTION is required")
if not values["max_connection"]:
raise ValueError("config OPENGAUSS_MAX_CONNECTION is required")
if values["min_connection"] > values["max_connection"]:
raise ValueError("config OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION")
return values
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding vector({dimension}) NOT NULL
);
"""
SQL_CREATE_INDEX_PQ = """
CREATE INDEX IF NOT EXISTS embedding_{table_name}_pq_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64, enable_pq=on, pq_m={pq_m});
"""
SQL_CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""
class OpenGauss(BaseVector):
def __init__(self, collection_name: str, config: OpenGaussConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"
self.pq_enabled = config.enable_pq
def get_type(self) -> str:
return VectorType.OPENGAUSS
def _create_connection_pool(self, config: OpenGaussConfig):
return psycopg2.pool.SimpleConnectionPool(
config.min_connection,
config.max_connection,
host=config.host,
port=config.port,
user=config.user,
password=config.password,
database=config.database,
)
@contextmanager
def _get_cursor(self):
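# Hand out a pooled cursor; the commit in `finally` runs even if the caller raised, then the connection is returned to the pool.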
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
self.add_texts(texts, embeddings)
self._create_index(dimension)
def _create_index(self, dimension: int):
index_cache_key = f"vector_index_{self._collection_name}"
lock_name = f"{index_cache_key}_lock"
with redis_client.lock(lock_name, timeout=60):
index_exist_cache_key = f"vector_index_{self._collection_name}"
if redis_client.get(index_exist_cache_key):
return
with self._get_cursor() as cur:
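# The HNSW index is only built for dimensions <= 2000; with PQ enabled, pq_m = dimension / 4 subquantizers and an early-stop threshold are configured.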
if dimension <= 2000:
if self.pq_enabled:
cur.execute(SQL_CREATE_INDEX_PQ.format(table_name=self.table_name, pq_m=int(dimension / 4)))
cur.execute("SET hnsw_earlystop_threshold = 320")
if not self.pq_enabled:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
redis_client.set(index_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_values(
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
)
return pks
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
return cur.fetchone() is not None
def get_by_ids(self, ids: list[str]) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
def delete_by_ids(self, ids: list[str]):
# Avoid errors from running a delete with an empty ID list.
# Scenario 1: extracting a document fails, so the table is never created;
# clicking the retry button then triggers a delete on an empty list.
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector.
:param query_vector: The input vector to search for similar items.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
f" ORDER BY distance LIMIT {top_k}",
(json.dumps(query_vector),),
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
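# Treat <=> as cosine distance here (the index uses vector_cosine_ops); 1 - distance then gives a similarity-style score where higher is better.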
score = 1 - distance
metadata["score"] = score
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
with self._get_cursor() as cur:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
docs = []
for record in cur:
metadata, text, score = record
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class OpenGaussFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenGauss:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENGAUSS, collection_name))
return OpenGauss(
collection_name=collection_name,
config=OpenGaussConfig(
host=dify_config.OPENGAUSS_HOST or "localhost",
port=dify_config.OPENGAUSS_PORT,
user=dify_config.OPENGAUSS_USER or "postgres",
password=dify_config.OPENGAUSS_PASSWORD or "",
database=dify_config.OPENGAUSS_DATABASE or "dify",
min_connection=dify_config.OPENGAUSS_MIN_CONNECTION,
max_connection=dify_config.OPENGAUSS_MAX_CONNECTION,
enable_pq=dify_config.OPENGAUSS_ENABLE_PQ or False,
),
)

View File

@@ -0,0 +1,304 @@
import json
import logging
from typing import Any
from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from configs import dify_config
from configs.middleware.vdb.opensearch_config import AuthMethod
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class OpenSearchConfig(BaseModel):
host: str
port: int
secure: bool = False # use_ssl
verify_certs: bool = True
auth_method: AuthMethod = AuthMethod.BASIC
user: str | None = None
password: str | None = None
aws_region: str | None = None
aws_service: str | None = None
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values.get("host"):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
if values.get("auth_method") == "aws_managed_iam":
if not values.get("aws_region"):
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
if not values.get("aws_service"):
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
if not values.get("OPENSEARCH_SECURE") and values.get("OPENSEARCH_VERIFY_CERTS"):
raise ValueError("verify_certs=True requires secure (HTTPS) connection")
return values
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
import boto3
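# boto3's default credential chain (env vars, shared config, instance/role credentials) supplies the credentials used for SigV4 request signing.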
return Urllib3AWSV4SignerAuth(
credentials=boto3.Session().get_credentials(),
region=self.aws_region,
service=self.aws_service, # type: ignore[arg-type]
)
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.verify_certs,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
}
if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB")
params["http_auth"] = (self.user, self.password)
elif self.auth_method == "aws_managed_iam":
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
params["http_auth"] = self.create_aws_managed_iam_auth()
return params
class OpenSearchVector(BaseVector):
def __init__(self, collection_name: str, config: OpenSearchConfig):
super().__init__(collection_name)
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())
def get_type(self) -> str:
return VectorType.OPENSEARCH
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
for i in range(len(documents)):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_source": {
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY: documents[i].metadata,
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
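# OpenSearch Serverless (aoss) rejects client-supplied document IDs, so _id is only set for managed/self-hosted clusters.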
if self._client_config.aws_service != "aoss":
action["_id"] = uuid4().hex
actions.append(action)
helpers.bulk(
client=self._client,
actions=actions,
timeout=30,
max_retries=3,
)
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}}
response = self._client.search(index=self._collection_name.lower(), body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_ids(ids)
def delete_by_ids(self, ids: list[str]):
index_name = self._collection_name.lower()
if not self._client.indices.exists(index=index_name):
logger.warning("Index %s does not exist", index_name)
return
# Collect the actual OpenSearch document IDs that match the given metadata doc_ids
actual_ids = []
for doc_id in ids:
es_ids = self.get_ids_by_metadata_field("doc_id", doc_id)
if es_ids:
actual_ids.extend(es_ids)
else:
logger.warning("Document with metadata doc_id %s not found for deletion", doc_id)
if actual_ids:
actions = [{"_op_type": "delete", "_index": index_name, "_id": es_id} for es_id in actual_ids]
try:
helpers.bulk(self._client, actions)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get("delete", {})
status = delete_error.get("status")
doc_id = delete_error.get("_id")
if status == 404:
logger.warning("Document not found for deletion: %s", doc_id)
else:
logger.exception("Error deleting document: %s", error)
def delete(self):
self._client.indices.delete(index=self._collection_name.lower(), ignore_unavailable=True)
def text_exists(self, id: str) -> bool:
try:
self._client.get(index=self._collection_name.lower(), id=id)
return True
except Exception:
return False
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Make sure query_vector is a list
if not isinstance(query_vector, list):
raise ValueError("query_vector should be a list of floats")
# Check whether query_vector is a floating-point number list
if not all(isinstance(x, float) for x in query_vector):
raise ValueError("All elements in query_vector should be floats")
query = {
"size": kwargs.get("top_k", 4),
"query": {"knn": {Field.VECTOR: {Field.VECTOR: query_vector, "k": kwargs.get("top_k", 4)}}},
}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
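# With a document_id filter, switch to an exact knn_score script query so scoring runs only over the filtered documents.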
query["query"] = {
"script_score": {
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID: document_ids_filter}}]}},
"script": {
"source": "knn_score",
"lang": "knn",
"params": {"field": Field.VECTOR, "query_value": query_vector, "space_type": "l2"},
},
}
}
try:
response = self._client.search(index=self._collection_name.lower(), body=query)
except Exception:
logger.exception("Error executing vector search, query: %s", query)
raise
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY, {})
# Make sure metadata is a dictionary
if metadata is None:
metadata = {}
metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] >= score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY), metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
full_text_query["query"]["bool"]["filter"] = [{"terms": {"metadata.document_id": document_ids_filter}}]
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY)
vector = hit["_source"].get(Field.VECTOR)
page_content = hit["_source"].get(Field.CONTENT_KEY)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
return docs
def create_collection(
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
):
lock_name = f"vector_indexing_lock_{self._collection_name.lower()}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}"
if redis_client.get(collection_exist_cache_key):
logger.info("Collection %s already exists.", self._collection_name.lower())
return
if not self._client.indices.exists(index=self._collection_name.lower()):
index_body = {
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: {
"type": "knn_vector",
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
"method": {
"name": "hnsw",
"space_type": "l2",
"engine": "faiss",
"parameters": {"ef_construction": 64, "m": 8},
},
},
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type
"document_id": {"type": "keyword"},
},
},
}
},
}
logger.info("Creating OpenSearch index %s", self._collection_name.lower())
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class OpenSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
aws_region=dify_config.OPENSEARCH_AWS_REGION,
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)

View File

@@ -0,0 +1,387 @@
import array
import json
import logging
import re
import uuid
from typing import Any
import jieba.posseg as pseg # type: ignore
import numpy
import oracledb
from oracledb.connection import Connection
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
oracledb.defaults.fetch_lobs = False
class OracleVectorConfig(BaseModel):
user: str
password: str
dsn: str
config_dir: str | None = None
wallet_location: str | None = None
wallet_password: str | None = None
is_autonomous: bool = False
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["user"]:
raise ValueError("config ORACLE_USER is required")
if not values["password"]:
raise ValueError("config ORACLE_PASSWORD is required")
if not values["dsn"]:
raise ValueError("config ORACLE_DSN is required")
if values.get("is_autonomous", False):
if not values.get("config_dir"):
raise ValueError("config_dir is required for autonomous database")
if not values.get("wallet_location"):
raise ValueError("wallet_location is required for autonomous database")
if not values.get("wallet_password"):
raise ValueError("wallet_password is required for autonomous database")
return values
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id varchar2(100)
,text CLOB NOT NULL
,meta JSON
,embedding vector NOT NULL
)
"""
SQL_CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS idx_docs_{table_name} ON {table_name}(text)
INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER world_lexer')
"""
class OracleVector(BaseVector):
def __init__(self, collection_name: str, config: OracleVectorConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"
self.config = config
def get_type(self) -> str:
return VectorType.ORACLE
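# The converters below map numpy arrays to/from Oracle's DB_TYPE_VECTOR via array.array, choosing the typecode from the dtype ('d' float64, 'f' float32, 'b' int8).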
def numpy_converter_in(self, value):
if value.dtype == numpy.float64:
dtype = "d"
elif value.dtype == numpy.float32:
dtype = "f"
else:
dtype = "b"
return array.array(dtype, value)
def input_type_handler(self, cursor, value, arraysize):
if isinstance(value, numpy.ndarray):
return cursor.var(
oracledb.DB_TYPE_VECTOR,
arraysize=arraysize,
inconverter=self.numpy_converter_in,
)
def numpy_converter_out(self, value):
if value.typecode == "b":
return numpy.array(value, copy=False, dtype=numpy.int8)
elif value.typecode == "f":
return numpy.array(value, copy=False, dtype=numpy.float32)
else:
return numpy.array(value, copy=False, dtype=numpy.float64)
def output_type_handler(self, cursor, metadata):
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
return cursor.var(
metadata.type_code,
arraysize=cursor.arraysize,
outconverter=self.numpy_converter_out,
)
def _get_connection(self) -> Connection:
if self.config.is_autonomous:
connection = oracledb.connect(
user=self.config.user,
password=self.config.password,
dsn=self.config.dsn,
config_dir=self.config.config_dir,
wallet_location=self.config.wallet_location,
wallet_password=self.config.wallet_password,
)
return connection
else:
connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
return connection
def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = {
"user": config.user,
"password": config.password,
"dsn": config.dsn,
"min": 1,
"max": 5,
"increment": 1,
}
if config.is_autonomous:
pool_params.update(
{
"config_dir": config.config_dir,
"wallet_location": config.wallet_location,
"wallet_password": config.wallet_password,
}
)
return oracledb.create_pool(**pool_params)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
)
)
with self._get_connection() as conn:
conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler
# with conn.cursor() as cur:
# cur.executemany(
# f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
# )
# conn.commit()
for value in values:
with conn.cursor() as cur:
try:
cur.execute(
f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
VALUES (:1, :2, :3, :4)""",
value,
)
conn.commit()
except Exception:
logger.exception("Failed to insert record %s into %s", value[0], self.table_name)
conn.close()
return pks
def text_exists(self, id: str) -> bool:
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,))
return cur.fetchone() is not None
def get_by_ids(self, ids: list[str]) -> list[Document]:
if not ids:
return []
with self._get_connection() as conn:
with conn.cursor() as cur:
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
self.pool.release(connection=conn)
conn.close()
return docs
def delete_by_ids(self, ids: list[str]):
if not ids:
return
with self._get_connection() as conn:
with conn.cursor() as cur:
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
conn.commit()
conn.close()
def delete_by_metadata_field(self, key: str, value: str):
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,))
conn.commit()
conn.close()
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector.
:param query_vector: The input vector to search for similar items.
:param top_k: The number of nearest neighbors to return, default is 4.
:return: List of Documents that are nearest to the query vector.
"""
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 4 # Use default if invalid
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params = [numpy.array(query_vector)]
if document_ids_filter:
placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter)))
where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
params.extend(document_ids_filter)
with self._get_connection() as conn:
conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler
with conn.cursor() as cur:
cur.execute(
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
AS distance FROM {self.table_name}
{where_clause} ORDER BY distance fetch first {top_k} rows only""",
params,
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
conn.close()
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# lazy import
import nltk # type: ignore
from nltk.corpus import stopwords # type: ignore
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 5 # Use default if invalid
# Fetching by score_threshold is not implemented yet; it may be added later.
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
match = zh_pattern.search(query)
entities = []
# If the query contains Chinese characters, segment it with jieba; otherwise tokenize it with NLTK.
if match:
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名ns: 地名nt: 机构名
current_entity += word
else:
if current_entity:
entities.append(current_entity)
current_entity = ""
if current_entity:
entities.append(current_entity)
else:
try:
nltk.data.find("tokenizers/punkt")
nltk.data.find("corpora/stopwords")
except LookupError:
raise LookupError("Unable to find the required NLTK data package: punkt and stopwords")
e_str = re.sub(r"[^\w ]", "", query)
all_tokens = nltk.word_tokenize(e_str)
stop_words = stopwords.words("english")
for token in all_tokens:
if token not in stop_words:
entities.append(token)
with self._get_connection() as conn:
with conn.cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params: dict[str, Any] = {"kk": " ACCUM ".join(entities)}
if document_ids_filter:
placeholders = []
for i, doc_id in enumerate(document_ids_filter):
param_name = f"doc_id_{i}"
placeholders.append(f":{param_name}")
params[param_name] = doc_id
where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) "
cur.execute(
f"""select meta, text, embedding FROM {self.table_name}
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
order by score(1) desc fetch first {top_k} rows only""",
params,
)
docs = []
for record in cur:
metadata, text, embedding = record
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
conn.close()
return docs
else:
return [Document(page_content="", metadata={})]
def delete(self):
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
conn.commit()
conn.close()
def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)
with conn.cursor() as cur:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
conn.commit()
conn.close()
class OracleVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OracleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
return OracleVector(
collection_name=collection_name,
config=OracleVectorConfig(
user=dify_config.ORACLE_USER or "system",
password=dify_config.ORACLE_PASSWORD or "oracle",
dsn=dify_config.ORACLE_DSN or "oracle:1521/freepdb1",
config_dir=dify_config.ORACLE_CONFIG_DIR,
wallet_location=dify_config.ORACLE_WALLET_LOCATION,
wallet_password=dify_config.ORACLE_WALLET_PASSWORD,
is_autonomous=dify_config.ORACLE_IS_AUTONOMOUS,
),
)

View File

@@ -0,0 +1,12 @@
from uuid import UUID
from numpy import ndarray
from sqlalchemy.orm import DeclarativeBase, Mapped
class CollectionORM(DeclarativeBase):
__tablename__: str
id: Mapped[UUID]
text: Mapped[str]
meta: Mapped[dict]
vector: Mapped[ndarray]

View File

@@ -0,0 +1,235 @@
import json
import logging
from typing import Any
from uuid import UUID, uuid4
from numpy import ndarray
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
from pydantic import BaseModel, model_validator
from sqlalchemy import Float, create_engine, insert, select, text
from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class PgvectoRSConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config PGVECTO_RS_HOST is required")
if not values["port"]:
raise ValueError("config PGVECTO_RS_PORT is required")
if not values["user"]:
raise ValueError("config PGVECTO_RS_USER is required")
if not values["password"]:
raise ValueError("config PGVECTO_RS_PASSWORD is required")
if not values["database"]:
raise ValueError("config PGVECTO_RS_DATABASE is required")
return values
class PGVectoRS(BaseVector):
def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int):
super().__init__(collection_name)
self._client_config = config
self._url = (
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self._client = create_engine(self._url)
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
session.commit()
self._fields: list[str] = []
class _Table(CollectionORM):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
id: Mapped[UUID] = mapped_column(
postgresql.UUID(as_uuid=True),
primary_key=True,
)
text: Mapped[str]
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
vector: Mapped[ndarray] = mapped_column(VECTOR(dim))
self._table = _Table
self._distance_op = "<=>"
def get_type(self) -> str:
return VectorType.PGVECTO_RS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.create_collection(len(embeddings[0]))
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self._client) as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
vector vector({dimension}) NOT NULL
) using heap;
""")
session.execute(create_statement)
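# Build the pgvecto.rs index; the WITH options block tunes segment sizing and the HNSW parameters (m, ef_construction).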
index_statement = sql_text(f"""
CREATE INDEX IF NOT EXISTS {index_name}
ON {self._collection_name} USING vectors(vector vector_l2_ops)
WITH (options = $$
optimizing.optimizing_threads = 30
segment.max_growing_segment_size = 2000
segment.max_sealed_segment_size = 30000000
[indexing.hnsw]
m=30
ef_construction=500
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
pks = []
with Session(self._client) as session:
for document, embedding in zip(documents, embeddings):
pk = uuid4()
session.execute(
insert(self._table).values(
id=pk,
text=document.page_content,
meta=document.metadata,
vector=embedding,
),
)
pks.append(pk)
session.commit()
return pks
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self._client) as session:
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ")
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {"ids": ids})
session.commit()
def delete_by_ids(self, ids: list[str]):
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); "
)
result = session.execute(select_statement, {"doc_ids": ids}).fetchall()
if result:
ids = [item[0] for item in result]
if ids:
with Session(self._client) as session:
select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)")
session.execute(select_statement, {"ids": ids})
session.commit()
def delete(self):
with Session(self._client) as session:
session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}"))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
)
result = session.execute(select_statement).fetchall()
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
with Session(self._client) as session:
stmt = (
select(
self._table,
self._table.vector.op(self._distance_op, return_type=Float)(
query_vector,
).label("distance"),
)
.limit(kwargs.get("top_k", 4))
.order_by("distance")
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
res = session.execute(stmt)
results = [(row[0], row[1]) for row in res]
# Organize results.
docs = []
for record, dis in results:
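# Convert the returned distance into a similarity-style score (higher is better) before applying the threshold.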
metadata = record.meta
score = 1 - dis
metadata["score"] = score
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score >= score_threshold:
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
class PGVectoRSFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTO_RS, collection_name))
dim = len(embeddings.embed_query("pgvecto_rs"))
return PGVectoRS(
collection_name=collection_name,
config=PgvectoRSConfig(
host=dify_config.PGVECTO_RS_HOST or "localhost",
port=dify_config.PGVECTO_RS_PORT or 5432,
user=dify_config.PGVECTO_RS_USER or "postgres",
password=dify_config.PGVECTO_RS_PASSWORD or "",
database=dify_config.PGVECTO_RS_DATABASE or "postgres",
),
dim=dim,
)

View File

@@ -0,0 +1,291 @@
import hashlib
import json
import logging
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.errors
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class PGVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
min_connection: int
max_connection: int
pg_bigm: bool = False
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")
if not values["port"]:
raise ValueError("config PGVECTOR_PORT is required")
if not values["user"]:
raise ValueError("config PGVECTOR_USER is required")
if not values["password"]:
raise ValueError("config PGVECTOR_PASSWORD is required")
if not values["database"]:
raise ValueError("config PGVECTOR_DATABASE is required")
if not values["min_connection"]:
raise ValueError("config PGVECTOR_MIN_CONNECTION is required")
if not values["max_connection"]:
raise ValueError("config PGVECTOR_MAX_CONNECTION is required")
if values["min_connection"] > values["max_connection"]:
raise ValueError("config PGVECTOR_MIN_CONNECTION should less than PGVECTOR_MAX_CONNECTION")
return values
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding vector({dimension}) NOT NULL
) using heap;
"""
SQL_CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx_{index_hash} ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""
SQL_CREATE_INDEX_PG_BIGM = """
CREATE INDEX IF NOT EXISTS bigm_idx_{index_hash} ON {table_name}
USING gin (text gin_bigm_ops);
"""
class PGVector(BaseVector):
def __init__(self, collection_name: str, config: PGVectorConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"
self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8]
self.pg_bigm = config.pg_bigm
def get_type(self) -> str:
return VectorType.PGVECTOR
def _create_connection_pool(self, config: PGVectorConfig):
return psycopg2.pool.SimpleConnectionPool(
config.min_connection,
config.max_connection,
host=config.host,
port=config.port,
user=config.user,
password=config.password,
database=config.database,
)
@contextmanager
def _get_cursor(self):
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_values(
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
)
return pks
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
return cur.fetchone() is not None
def get_by_ids(self, ids: list[str]) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
def delete_by_ids(self, ids: list[str]):
# Avoid errors from running a delete with an empty ID list.
# Scenario 1: extracting a document fails, so the table is never created;
# clicking the retry button then triggers a delete on an empty list.
if not ids:
return
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
except psycopg2.errors.UndefinedTable:
# table not exists
logger.warning("Table %s not found, skipping delete operation.", self.table_name)
return
except Exception:
raise
def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector.
:param query_vector: The input vector to search for similar items.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" WHERE meta->>'document_id' in ({document_ids}) "
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
f" {where_clause}"
f" ORDER BY distance LIMIT {top_k}",
(json.dumps(query_vector),),
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
with self._get_cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND meta->>'document_id' in ({document_ids}) "
if self.pg_bigm:
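# pg_bigm path: effectively disable the similarity cutoff so short queries still match, then rank by bigm_similarity.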
cur.execute("SET pg_bigm.similarity_limit TO 0.000001")
cur.execute(
f"""SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score
FROM {self.table_name}
WHERE text =%% unistr(%s)
{where_clause}
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
else:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
{where_clause}
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
docs = []
for record in cur:
metadata, text, score = record
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# PG hnsw index only support 2000 dimension or less
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
if dimension <= 2000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name, index_hash=self.index_hash))
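# Optionally add a pg_bigm GIN index on the text column to accelerate bigram full-text search.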
if self.pg_bigm:
cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name, index_hash=self.index_hash))
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class PGVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
return PGVector(
collection_name=collection_name,
config=PGVectorConfig(
host=dify_config.PGVECTOR_HOST or "localhost",
port=dify_config.PGVECTOR_PORT,
user=dify_config.PGVECTOR_USER or "postgres",
password=dify_config.PGVECTOR_PASSWORD or "",
database=dify_config.PGVECTOR_DATABASE or "postgres",
min_connection=dify_config.PGVECTOR_MIN_CONNECTION,
max_connection=dify_config.PGVECTOR_MAX_CONNECTION,
pg_bigm=dify_config.PGVECTOR_PG_BIGM,
),
)

View File

@@ -0,0 +1,243 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class VastbaseVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
min_connection: int
max_connection: int
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config VASTBASE_HOST is required")
if not values["port"]:
raise ValueError("config VASTBASE_PORT is required")
if not values["user"]:
raise ValueError("config VASTBASE_USER is required")
if not values["password"]:
raise ValueError("config VASTBASE_PASSWORD is required")
if not values["database"]:
raise ValueError("config VASTBASE_DATABASE is required")
if not values["min_connection"]:
raise ValueError("config VASTBASE_MIN_CONNECTION is required")
if not values["max_connection"]:
raise ValueError("config VASTBASE_MAX_CONNECTION is required")
if values["min_connection"] > values["max_connection"]:
raise ValueError("config VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION")
return values
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding floatvector({dimension}) NOT NULL
);
"""
SQL_CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
USING hnsw (embedding floatvector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""
class VastbaseVector(BaseVector):
def __init__(self, collection_name: str, config: VastbaseVectorConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"
def get_type(self) -> str:
return VectorType.VASTBASE
def _create_connection_pool(self, config: VastbaseVectorConfig):
return psycopg2.pool.SimpleConnectionPool(
config.min_connection,
config.max_connection,
host=config.host,
port=config.port,
user=config.user,
password=config.password,
database=config.database,
)
@contextmanager
def _get_cursor(self):
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_values(
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
)
return pks
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
return cur.fetchone() is not None
def get_by_ids(self, ids: list[str]) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
def delete_by_ids(self, ids: list[str]):
# Avoid errors from running a delete with an empty ID list.
# Scenario 1: extracting a document fails, so the table is never created;
# clicking the retry button then triggers a delete on an empty list.
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector.
:param query_vector: The input vector to search for similar items.
:param top_k: The number of nearest neighbors to return, default is 4.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
f" ORDER BY distance LIMIT {top_k}",
(json.dumps(query_vector),),
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
with self._get_cursor() as cur:
cur.execute(
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
FROM {self.table_name}
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
ORDER BY score DESC
LIMIT {top_k}""",
# f"'{query}'" is required in order to account for whitespace in query
(f"'{query}'", f"'{query}'"),
)
docs = []
for record in cur:
metadata, text, score = record
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# Vastbase supports vector dimensions in the range [1, 16000]
if dimension <= 16000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class VastbaseVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VastbaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VASTBASE, collection_name))
return VastbaseVector(
collection_name=collection_name,
config=VastbaseVectorConfig(
host=dify_config.VASTBASE_HOST or "localhost",
port=dify_config.VASTBASE_PORT,
user=dify_config.VASTBASE_USER or "dify",
password=dify_config.VASTBASE_PASSWORD or "",
database=dify_config.VASTBASE_DATABASE or "dify",
min_connection=dify_config.VASTBASE_MIN_CONNECTION,
max_connection=dify_config.VASTBASE_MAX_CONNECTION,
),
)

View File

@@ -0,0 +1,487 @@
import json
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Union
import qdrant_client
from flask import current_app
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
FilterSelector,
HnswConfigDiff,
PayloadSchemaType,
TextIndexParams,
TextIndexType,
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetCollectionBinding
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
class PathQdrantParams(BaseModel):
path: str
class UrlQdrantParams(BaseModel):
url: str
api_key: str | None
timeout: float
verify: bool
grpc_port: int
prefer_grpc: bool
class QdrantConfig(BaseModel):
endpoint: str
api_key: str | None = None
timeout: float = 20
root_path: str | None = None
grpc_port: int = 6334
prefer_grpc: bool = False
replication_factor: int = 1
write_consistency_factor: int = 1
def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams:
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
if not self.root_path:
raise ValueError("Root path is not set")
path = os.path.join(self.root_path, path)
return PathQdrantParams(path=path)
else:
return UrlQdrantParams(
url=self.endpoint,
api_key=self.api_key,
timeout=self.timeout,
verify=self.endpoint.startswith("https"),
grpc_port=self.grpc_port,
prefer_grpc=self.prefer_grpc,
)
class QdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump())
self._distance_func = distance_func.upper()
self._group_id = group_id
def get_type(self) -> str:
return VectorType.QDRANT
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
# get embedding vector size
vector_size = len(embeddings[0])
# get collection name
collection_name = self._collection_name
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
lock_name = f"vector_indexing_lock_{collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
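# m=0 skips building a global HNSW graph while payload_m=16 builds per-group graphs, pairing with the group_id payload index for multitenancy.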
hnsw_config = HnswConfigDiff(
m=0,
payload_m=16,
ef_construct=100,
full_scan_threshold=10000,
max_indexing_threads=0,
on_disk=False,
)
self._client.create_collection(
collection_name=collection_name,
vectors_config=vectors_config,
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
replication_factor=self._client_config.replication_factor,
write_consistency_factor=self._client_config.write_consistency_factor,
)
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
# create document_id payload index
self._client.create_payload_index(
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
added_ids = []
# Filter out None values from metadatas list to match expected type
filtered_metadatas = [m for m in metadatas if m is not None]
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: list[dict] | None = None,
ids: Sequence[str] | None = None,
batch_size: int = 64,
group_id: str | None = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = list(islice(embeddings_iterator, batch_size))
points = [
rest.PointStruct(
id=point_id,
vector=vector,
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY,
Field.METADATA_KEY,
group_id or "", # Ensure group_id is never None
Field.GROUP_KEY,
),
)
]
yield batch_ids, points
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: list[dict] | None,
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str,
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
if text is None:
raise ValueError(
"At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete(self):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete_by_ids(self, ids: list[str]):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchAny(any=ids),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if self._collection_name not in all_collection_name:
return False
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score_threshold >= 1:
# return an empty list because some versions of Qdrant may respond with 400 Bad Request
# for a score_threshold >= 1, even though a threshold of 1 can be valid for other vector stores
return []
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
if filter.must:
filter.must.append(
models.FieldCondition(
key="metadata.document_id",
match=models.MatchAny(any=document_ids_filter),
)
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
query_filter=filter,
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=score_threshold,
)
docs = []
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY) or {}
# score_threshold is already applied server-side; re-check it here defensively
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs most similar by bm25.
Returns:
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
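# The MatchText condition below relies on the full-text payload index created on
# Field.CONTENT_KEY ("page_content") in create_collection.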
scroll_filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
),
]
)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
if scroll_filter.must:
scroll_filter.must.append(
models.FieldCondition(
key="metadata.document_id",
match=models.MatchAny(any=document_ids_filter),
)
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get("top_k", 2),
with_payload=True,
with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document)
return documents
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client._load()
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
vector=scored_point.vector,
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
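# Minimal usage sketch (hypothetical values; real instances are built by the
# QdrantVectorFactory below):
#   vector = QdrantVector(
#       collection_name="Vector_index_example_Node",
#       group_id="dataset-uuid",
#       config=QdrantConfig(endpoint="http://localhost:6333"),
#   )
#   vector.create(texts=documents, embeddings=embeddings)
#   hits = vector.search_by_vector(query_embedding, top_k=4, score_threshold=0.5)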
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id)
dataset_collection_binding = db.session.scalars(stmt).one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError("Dataset Collection Bindings does not exist!")
else:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
if not dataset.index_struct_dict:
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
return QdrantVector(
collection_name=collection_name,
group_id=dataset.id,
config=QdrantConfig(
endpoint=dify_config.QDRANT_URL or "",
api_key=dify_config.QDRANT_API_KEY,
root_path=str(current_app.config.root_path),
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
replication_factor=dify_config.QDRANT_REPLICATION_FACTOR,
),
)

View File

@@ -0,0 +1,318 @@
import json
import logging
import uuid
from typing import Any
from pydantic import BaseModel, model_validator
from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from models.dataset import Dataset
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
Base = declarative_base() # type: Any
class RelytConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config RELYT_HOST is required")
if not values["port"]:
raise ValueError("config RELYT_PORT is required")
if not values["user"]:
raise ValueError("config RELYT_USER is required")
if not values["password"]:
raise ValueError("config RELYT_PASSWORD is required")
if not values["database"]:
raise ValueError("config RELYT_DATABASE is required")
return values
class RelytVector(BaseVector):
def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
super().__init__(collection_name)
self.embedding_dimension = 1536
self._client_config = config
self._url = (
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self.client = create_engine(self._url)
self._fields: list[str] = []
self._group_id = group_id
def get_type(self) -> str:
return VectorType.RELYT
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.create_collection(len(embeddings[0]))
self.embedding_dimension = len(embeddings[0])
self.add_texts(texts, embeddings)
def create_collection(self, dimension: int):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self.client) as session:
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
session.execute(drop_statement)
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS "{self._collection_name}" (
id TEXT PRIMARY KEY,
document TEXT NOT NULL,
metadata JSON NOT NULL,
embedding vector({dimension}) NOT NULL
) using heap;
""")
session.execute(create_statement)
index_statement = sql_text(f"""
CREATE INDEX {index_name}
ON "{self._collection_name}" USING vectors(embedding vector_l2_ops)
WITH (options = $$
optimizing.optimizing_threads = 30
segment.max_growing_segment_size = 2000
segment.max_sealed_segment_size = 30000000
[indexing.hnsw]
m=30
ef_construction=500
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents if d.metadata is not None]
for metadata in metadatas:
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]
# Define the table schema
chunks_table = Table(
self._collection_name,
Base.metadata,
Column("id", TEXT, primary_key=True),
Column("embedding", VECTOR(len(embeddings[0]))),
Column("document", String, nullable=True),
Column("metadata", JSON, nullable=True),
extend_existing=True,
)
chunks_table_data = []
with self.client.connect() as conn, conn.begin():
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
chunks_table_data.append(
{
"id": chunk_id,
"embedding": embedding,
"document": document,
"metadata": metadata,
}
)
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(chunks_table).values(chunks_table_data))
return ids
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self.client) as session:
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
)
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]
else:
return None
def delete_by_uuids(self, ids: list[str] | None = None):
"""Delete by vector IDs.
Args:
ids: List of ids to delete.
"""
from pgvecto_rs.sqlalchemy import VECTOR
if ids is None:
raise ValueError("No ids provided to delete.")
# Define the table schema
chunks_table = Table(
self._collection_name,
Base.metadata,
Column("id", TEXT, primary_key=True),
Column("embedding", VECTOR(self.embedding_dimension)),
Column("document", String, nullable=True),
Column("metadata", JSON, nullable=True),
extend_existing=True,
)
try:
with self.client.connect() as conn, conn.begin():
delete_condition = chunks_table.c.id.in_(ids)
conn.execute(chunks_table.delete().where(delete_condition))
return True
except Exception:
logger.exception("Delete operation failed for collection %s", self._collection_name)
return False
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_uuids(ids)
def delete_by_ids(self, ids: list[str]):
with Session(self.client) as session:
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
)
result = session.execute(select_statement).fetchall()
if result:
ids = [item[0] for item in result]
self.delete_by_uuids(ids)
def delete(self):
with Session(self.client) as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self.client) as session:
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
)
result = session.execute(select_statement).fetchall()
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = kwargs.get("filter", {})
if document_ids_filter:
filter["document_id"] = document_ids_filter
results = self.similarity_search_with_score_by_vector(
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
)
# Organize results.
docs = []
for document, score in results:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if 1 - score >= score_threshold:
docs.append(document)
return docs
def similarity_search_with_score_by_vector(
self,
embedding: list[float],
k: int = 4,
filter: dict | None = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
filter_condition = ""
if filter is not None:
conditions = [
f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})"
if len(value) > 1
else f"metadata->>'{key!r}' = {value[0]!r}"
for key, value in filter.items()
]
filter_condition = f"WHERE {' AND '.join(conditions)}"
# Define the base query
sql_query = f"""
set vectors.enable_search_growing = on;
set vectors.enable_search_write = on;
SELECT document, metadata, embedding <-> :embedding as distance
FROM "{self._collection_name}"
{filter_condition}
ORDER BY embedding <-> :embedding
LIMIT :k
"""
# Set up the query parameters
embedding_str = ", ".join(format(x) for x in embedding)
embedding_str = "[" + embedding_str + "]"
params = {"embedding": embedding_str, "k": k}
# Execute the query and fetch the results
with self.client.connect() as conn:
results = conn.execute(sql_text(sql_query), params).fetchall()
documents_with_scores = [
(
Document(
page_content=result.document,
metadata=result.metadata,
),
result.distance,
)
for result in results
]
return documents_with_scores
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# Relyt does not support BM25 full-text search, so return an empty list
return []
class RelytVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name))
return RelytVector(
collection_name=collection_name,
config=RelytConfig(
host=dify_config.RELYT_HOST or "localhost",
port=dify_config.RELYT_PORT,
user=dify_config.RELYT_USER or "",
password=dify_config.RELYT_PASSWORD or "",
database=dify_config.RELYT_DATABASE or "default",
),
group_id=dataset.id,
)

View File

@@ -0,0 +1,413 @@
import json
import logging
import math
from collections.abc import Iterable
from typing import Any
import tablestore # type: ignore
from pydantic import BaseModel, model_validator
from tablestore import BatchGetRowRequest, TableInBatchGetRowItem
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models import Dataset
logger = logging.getLogger(__name__)
class TableStoreConfig(BaseModel):
access_key_id: str | None = None
access_key_secret: str | None = None
instance_name: str | None = None
endpoint: str | None = None
normalize_full_text_bm25_score: bool | None = False
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["access_key_id"]:
raise ValueError("config ACCESS_KEY_ID is required")
if not values["access_key_secret"]:
raise ValueError("config ACCESS_KEY_SECRET is required")
if not values["instance_name"]:
raise ValueError("config INSTANCE_NAME is required")
if not values["endpoint"]:
raise ValueError("config ENDPOINT is required")
return values
class TableStoreVector(BaseVector):
def __init__(self, collection_name: str, config: TableStoreConfig):
super().__init__(collection_name)
self._config = config
self._tablestore_client = tablestore.OTSClient(
config.endpoint,
config.access_key_id,
config.access_key_secret,
config.instance_name,
)
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
self._table_name = f"{collection_name}"
self._index_name = f"{collection_name}_idx"
self._tags_field = f"{Field.METADATA_KEY}_tags"
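# Metadata key/value pairs are additionally flattened into "key=value" strings and
# stored in this keyword-array field so they can be matched with TermQuery/TermsQuery
# (see _write_row and _search_by_metadata below).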
def create_collection(self, embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
def get_by_ids(self, ids: list[str]) -> list[Document]:
docs = []
request = BatchGetRowRequest()
columns_to_get = [Field.METADATA_KEY, Field.CONTENT_KEY]
rows_to_get = [[("id", _id)] for _id in ids]
request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))
result = self._tablestore_client.batch_get_row(request)
table_result = result.get_result_by_table(self._table_name)
for item in table_result:
if item.is_ok and item.row:
kv = {k: v for k, v, _ in item.row.attribute_columns}
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
return docs
def get_type(self) -> str:
return VectorType.TABLESTORE
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
self.add_texts(documents=texts, embeddings=embeddings, **kwargs)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
for i in range(len(documents)):
self._write_row(
primary_key=uuids[i],
attributes={
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
},
)
return uuids
def text_exists(self, id: str) -> bool:
result = self._tablestore_client.get_row(
table_name=self._table_name, primary_key=[("id", id)], columns_to_get=["id"]
)
assert isinstance(result, tuple | list)
# Unpack the tuple result
_, return_row, _ = result
return return_row is not None
def delete_by_ids(self, ids: list[str]):
if not ids:
return
for id in ids:
self._delete_row(id=id)
def get_ids_by_metadata_field(self, key: str, value: str):
return self._search_by_metadata(key, value)
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
filtered_list = None
if document_ids_filter:
filtered_list = ["document_id=" + item for item in document_ids_filter]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
filtered_list = None
if document_ids_filter:
filtered_list = ["document_id=" + item for item in document_ids_filter]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._search_by_full_text(query, filtered_list, top_k, score_threshold)
def delete(self):
self._delete_table_if_exist()
def _create_collection(self, dimension: int):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info("Collection %s already exists.", self._collection_name)
return
self._create_table_if_not_exist()
self._create_search_index_if_not_exist(dimension)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _create_table_if_not_exist(self):
table_list = self._tablestore_client.list_table()
if self._table_name in table_list:
logger.info("Tablestore system table[%s] already exists", self._table_name)
return None
schema_of_primary_key = [("id", "STRING")]
table_meta = tablestore.TableMeta(self._table_name, schema_of_primary_key)
table_options = tablestore.TableOptions()
reserved_throughput = tablestore.ReservedThroughput(tablestore.CapacityUnit(0, 0))
self._tablestore_client.create_table(table_meta, table_options, reserved_throughput)
logger.info("Tablestore create table[%s] successfully.", self._table_name)
def _create_search_index_if_not_exist(self, dimension: int):
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
assert isinstance(search_index_list, Iterable)
if self._index_name in [t[1] for t in search_index_list]:
logger.info("Tablestore system index[%s] already exists", self._index_name)
return None
field_schemas = [
tablestore.FieldSchema(
Field.CONTENT_KEY,
tablestore.FieldType.TEXT,
analyzer=tablestore.AnalyzerType.MAXWORD,
index=True,
enable_sort_and_agg=False,
store=False,
),
tablestore.FieldSchema(
Field.VECTOR,
tablestore.FieldType.VECTOR,
vector_options=tablestore.VectorOptions(
data_type=tablestore.VectorDataType.VD_FLOAT_32,
dimension=dimension,
metric_type=tablestore.VectorMetricType.VM_COSINE,
),
),
tablestore.FieldSchema(
Field.METADATA_KEY,
tablestore.FieldType.KEYWORD,
index=True,
store=False,
),
tablestore.FieldSchema(
self._tags_field,
tablestore.FieldType.KEYWORD,
index=True,
store=False,
is_array=True,
),
]
index_meta = tablestore.SearchIndexMeta(field_schemas)
self._tablestore_client.create_search_index(self._table_name, self._index_name, index_meta)
logger.info("Tablestore create system index[%s] successfully.", self._index_name)
def _delete_table_if_exist(self):
search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
assert isinstance(search_index_list, Iterable)
for resp_tuple in search_index_list:
self._tablestore_client.delete_search_index(resp_tuple[0], resp_tuple[1])
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
self._tablestore_client.delete_table(self._table_name)
logger.info("Tablestore delete system table[%s] successfully.", self._index_name)
def _delete_search_index(self):
self._tablestore_client.delete_search_index(self._table_name, self._index_name)
logger.info("Tablestore delete index[%s] successfully.", self._index_name)
def _write_row(self, primary_key: str, attributes: dict[str, Any]):
pk = [("id", primary_key)]
tags = []
for key, value in attributes[Field.METADATA_KEY].items():
tags.append(str(key) + "=" + str(value))
attribute_columns = [
(Field.CONTENT_KEY, attributes[Field.CONTENT_KEY]),
(Field.VECTOR, json.dumps(attributes[Field.VECTOR])),
(
Field.METADATA_KEY,
json.dumps(attributes[Field.METADATA_KEY]),
),
(self._tags_field, json.dumps(tags)),
]
row = tablestore.Row(pk, attribute_columns)
self._tablestore_client.put_row(self._table_name, row)
def _delete_row(self, id: str):
primary_key = [("id", id)]
row = tablestore.Row(primary_key)
self._tablestore_client.delete_row(self._table_name, row, None)
def _search_by_metadata(self, key: str, value: str) -> list[str]:
query = tablestore.SearchQuery(
tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)),
limit=1000,
get_total_count=False,
)
rows: list[str] = []
next_token = None
while True:
if next_token is not None:
query.next_token = next_token
search_response = self._tablestore_client.search(
table_name=self._table_name,
index_name=self._index_name,
search_query=query,
columns_to_get=tablestore.ColumnsToGet(
column_names=[Field.PRIMARY_KEY], return_type=tablestore.ColumnReturnType.SPECIFIED
),
)
if search_response is not None:
rows.extend([row[0][0][1] for row in list(search_response.rows)])
if search_response is None or search_response.next_token == b"":
break
else:
next_token = search_response.next_token
return rows
def _search_by_vector(
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
knn_vector_query = tablestore.KnnVectorQuery(
field_name=Field.VECTOR,
top_k=top_k,
float32_query_vector=query_vector,
)
if document_ids_filter:
knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter)
sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)])
search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort)
search_response = self._tablestore_client.search(
table_name=self._table_name,
index_name=self._index_name,
search_query=search_query,
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
)
documents = []
for search_hit in search_response.search_hits:
if search_hit.score >= score_threshold:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
vector_str = ots_column_map.get(Field.VECTOR)
metadata_str = ots_column_map.get(Field.METADATA_KEY)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
metadata["score"] = search_hit.score
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)
)
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
@staticmethod
def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float:
"""
Args:
score: BM25 search score.
k: decay factor, the larger the k, the steeper the low score end
"""
normalized_score = 1 - math.exp(-k * score)
return max(0.0, min(1.0, normalized_score))
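# Illustrative values with the default k=0.15: a BM25 score of 5 maps to
# 1 - exp(-0.75) ≈ 0.53 and a score of 20 to ≈ 0.95, so raw scores are squashed
# into (0, 1) with diminishing returns for very high scores.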
def _search_by_full_text(
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY))
if document_ids_filter:
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
search_query = tablestore.SearchQuery(
query=bool_query,
sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]),
limit=top_k,
)
search_response = self._tablestore_client.search(
table_name=self._table_name,
index_name=self._index_name,
search_query=search_query,
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
)
documents = []
for search_hit in search_response.search_hits:
score = None
if self._normalize_full_text_bm25_score:
score = self._normalize_score_exp_decay(search_hit.score)
# when normalized scores are enabled, skip results at or below the threshold
if score and score <= score_threshold:
continue
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
metadata_str = ots_column_map.get(Field.METADATA_KEY)
metadata = json.loads(metadata_str) if metadata_str else {}
vector_str = ots_column_map.get(Field.VECTOR)
vector = json.loads(vector_str) if vector_str else None
if score:
metadata["score"] = score
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)
)
if self._normalize_full_text_bm25_score:
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
class TableStoreVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TableStoreVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TABLESTORE, collection_name))
return TableStoreVector(
collection_name=collection_name,
config=TableStoreConfig(
endpoint=dify_config.TABLESTORE_ENDPOINT,
instance_name=dify_config.TABLESTORE_INSTANCE_NAME,
access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID,
access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET,
normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE,
),
)

View File

@@ -0,0 +1,331 @@
import json
import logging
import math
from typing import Any
from pydantic import BaseModel
from tcvdb_text.encoder import BM25Encoder # type: ignore
from tcvectordb import RPCVectorDBClient, VectorDBException # type: ignore
from tcvectordb.model import document, enum # type: ignore
from tcvectordb.model import index as vdb_index # type: ignore
from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class TencentConfig(BaseModel):
url: str
api_key: str | None = None
timeout: float = 30
username: str | None = None
database: str | None = None
index_type: str = "HNSW"
metric_type: str = "IP"
shard: int = 1
replicas: int = 2
max_upsert_batch_size: int = 128
enable_hybrid_search: bool = False # Flag to enable hybrid search
def to_tencent_params(self):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
bm25 = BM25Encoder.default("zh")
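# Module-level BM25 encoder used to build sparse vectors for hybrid search; the
# "zh" preset from tcvdb_text targets Chinese text, which is a fixed assumption of
# this integration rather than a per-dataset setting.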
class TencentVector(BaseVector):
field_id: str = "id"
field_vector: str = "vector"
field_text: str = "text"
field_metadata: str = "metadata"
def __init__(self, collection_name: str, config: TencentConfig):
super().__init__(collection_name)
self._client_config = config
self._client = RPCVectorDBClient(**self._client_config.to_tencent_params())
self._enable_hybrid_search = False
self._dimension = 1024
self._init_database()
self._load_collection()
def _load_collection(self):
"""
Check if the collection supports hybrid search.
"""
if self._client_config.enable_hybrid_search:
self._enable_hybrid_search = True
if self._has_collection():
coll = self._client.describe_collection(
database_name=self._client_config.database, collection_name=self.collection_name
)
has_hybrid_search = False
for idx in coll.indexes:
if idx.name == "sparse_vector":
has_hybrid_search = True
elif idx.name == "vector":
self._dimension = idx.dimension
if not has_hybrid_search:
self._enable_hybrid_search = False
def _init_database(self):
return self._client.create_database_if_not_exists(database_name=self._client_config.database)
def get_type(self) -> str:
return VectorType.TENCENT
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def _has_collection(self) -> bool:
return bool(
self._client.exists_collection(
database_name=self._client_config.database, collection_name=self.collection_name
)
)
def _create_collection(self, dimension: int):
self._dimension = dimension
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
if self._has_collection():
return
index_type = None
for k, v in enum.IndexType.__members__.items():
if k == self._client_config.index_type:
index_type = v
if index_type is None:
raise ValueError("unsupported index_type")
metric_type = None
for k, v in enum.MetricType.__members__.items():
if k == self._client_config.metric_type:
metric_type = v
if metric_type is None:
raise ValueError("unsupported metric_type")
params = vdb_index.HNSWParams(m=16, efconstruction=200)
index_id = vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY)
index_vector = vdb_index.VectorIndex(
self.field_vector,
dimension,
index_type,
metric_type,
params,
)
index_metadate = vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER)
index_sparse_vector = vdb_index.SparseIndex(
name="sparse_vector",
field_type=enum.FieldType.SparseVector,
index_type=enum.IndexType.SPARSE_INVERTED,
metric_type=enum.MetricType.IP,
)
indexes = [index_id, index_vector, index_metadate]
if self._enable_hybrid_search:
indexes.append(index_sparse_vector)
try:
self._client.create_collection(
database_name=self._client_config.database,
collection_name=self._collection_name,
shard=self._client_config.shard,
replicas=self._client_config.replicas,
description="Collection for Dify",
indexes=indexes,
)
except VectorDBException as e:
if "fieldType:json" not in e.message:
raise e
# this VectorDB version does not support the JSON field type; fall back to a string filter index
index_metadate = vdb_index.FilterIndex(
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
)
indexes = [index_id, index_vector, index_metadate]
if self._enable_hybrid_search:
indexes.append(index_sparse_vector)
self._client.create_collection(
database_name=self._client_config.database,
collection_name=self._collection_name,
shard=self._client_config.shard,
replicas=self._client_config.replicas,
description="Collection for Dify",
indexes=indexes,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._create_collection(len(embeddings[0]))
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
total_count = len(embeddings)
batch_size = self._client_config.max_upsert_batch_size
batch = math.ceil(total_count / batch_size)
for j in range(batch):
docs = []
start_idx = j * batch_size
end_idx = min(total_count, (j + 1) * batch_size)
for i in range(start_idx, end_idx):
if metadatas is None:
continue
metadata = metadatas[i] or {}
doc = document.Document(
id=metadata.get("doc_id"),
vector=embeddings[i],
text=texts[i],
metadata=metadata,
)
if self._enable_hybrid_search:
doc.__dict__["sparse_vector"] = bm25.encode_texts(texts[i])
docs.append(doc)
self._client.upsert(
database_name=self._client_config.database,
collection_name=self.collection_name,
documents=docs,
timeout=self._client_config.timeout,
)
def text_exists(self, id: str) -> bool:
docs = self._client.query(
database_name=self._client_config.database, collection_name=self.collection_name, document_ids=[id]
)
if docs and len(docs) > 0:
return True
return False
def delete_by_ids(self, ids: list[str]):
if not ids:
return
total_count = len(ids)
batch_size = self._client_config.max_upsert_batch_size
batch = math.ceil(total_count / batch_size)
for j in range(batch):
start_idx = j * batch_size
end_idx = min(total_count, (j + 1) * batch_size)
batch_ids = ids[start_idx:end_idx]
self._client.delete(
database_name=self._client_config.database, collection_name=self.collection_name, document_ids=batch_ids
)
def delete_by_metadata_field(self, key: str, value: str):
self._client.delete(
database_name=self._client_config.database,
collection_name=self.collection_name,
filter=Filter(Filter.In(f"metadata.{key}", [value])),
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
res = self._client.search(
database_name=self._client_config.database,
collection_name=self.collection_name,
vectors=[query_vector],
filter=filter,
params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
if not self._enable_hybrid_search:
return []
res = self._client.hybrid_search(
database_name=self._client_config.database,
collection_name=self.collection_name,
ann=[
AnnSearch(
field_name="vector",
data=[0.0] * self._dimension,
)
],
match=[
KeywordSearch(
field_name="sparse_vector",
data=bm25.encode_queries(query),
),
],
rerank=WeightedRerank(
field_list=["vector", "sparse_vector"],
weight=[0, 1],
),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),
filter=filter,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)
def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]:
docs: list[Document] = []
if res is None or len(res) == 0:
return docs
for result in res[0]:
meta = result.get(self.field_metadata)
if isinstance(meta, str):
# Compatible with version 1.1.3 and below.
meta = json.loads(meta)
score = 1 - result.get("score", 0.0)
else:
score = result.get("score", 0.0)
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)
docs.append(doc)
return docs
def delete(self):
if self._has_collection():
self._client.drop_collection(
database_name=self._client_config.database, collection_name=self.collection_name
)
class TencentVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
return TencentVector(
collection_name=collection_name,
config=TencentConfig(
url=dify_config.TENCENT_VECTOR_DB_URL or "",
api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
username=dify_config.TENCENT_VECTOR_DB_USERNAME,
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
enable_hybrid_search=dify_config.TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH or False,
),
)

View File

@@ -0,0 +1,535 @@
import json
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Union
import httpx
import qdrant_client
from flask import current_app
from httpx import DigestAuth
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
FilterSelector,
HnswConfigDiff,
PayloadSchemaType,
TextIndexParams,
TextIndexType,
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, TidbAuthBinding
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
class TidbOnQdrantConfig(BaseModel):
endpoint: str
api_key: str | None = None
timeout: float = 20
root_path: str | None = None
grpc_port: int = 6334
prefer_grpc: bool = False
replication_factor: int = 1
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
if self.root_path:
path = os.path.join(self.root_path, path)
else:
raise ValueError("root_path is required")
return {"path": path}
else:
return {
"url": self.endpoint,
"api_key": self.api_key,
"timeout": self.timeout,
"verify": False,
"grpc_port": self.grpc_port,
"prefer_grpc": self.prefer_grpc,
}
class TidbConfig(BaseModel):
api_url: str
public_key: str
private_key: str
class TidbOnQdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
self._distance_func = distance_func.upper()
self._group_id = group_id
def get_type(self) -> str:
return VectorType.TIDB_ON_QDRANT
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
# get embedding vector size
vector_size = len(embeddings[0])
# get collection name
collection_name = self._collection_name
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
lock_name = f"vector_indexing_lock_{collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(
m=0,
payload_m=16,
ef_construct=100,
full_scan_threshold=10000,
max_indexing_threads=0,
on_disk=False,
)
self._client.create_collection(
collection_name=collection_name,
vectors_config=vectors_config,
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
replication_factor=self._client_config.replication_factor,
)
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
# create document_id payload index
self._client.create_payload_index(
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents if d.metadata is not None]
added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: list[dict] | None = None,
ids: Sequence[str] | None = None,
batch_size: int = 64,
group_id: str | None = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = list(islice(embeddings_iterator, batch_size))
points = [
rest.PointStruct(
id=point_id,
vector=vector,
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY,
Field.METADATA_KEY,
group_id or "",
Field.GROUP_KEY,
),
)
]
yield batch_ids, points
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: list[dict] | None,
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str,
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
if text is None:
raise ValueError(
"At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete(self):
from qdrant_client.http.exceptions import UnexpectedResponse
try:
self._client.delete_collection(collection_name=self._collection_name)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete_by_ids(self, ids: list[str]):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(filter=filter),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if self._collection_name not in all_collection_name:
return False
response = self._client.retrieve(collection_name=self._collection_name, ids=[id])
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
filter = None
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchAny(any=document_ids_filter),
)
],
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
query_filter=filter,
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", 0.0),
)
docs = []
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY) or {}
# score_threshold is already applied server-side; re-check it here defensively
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs most similar by bm25.
Returns:
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
scroll_filter = None
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
scroll_filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchAny(any=document_ids_filter),
)
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get("top_k", 2),
with_payload=True,
with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document)
return documents
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client._load()
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
vector=scored_point.vector,
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
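# Resolve the per-tenant TiDB credentials: reuse an existing binding if present;
# otherwise, under a distributed lock, either claim an idle pre-provisioned cluster
# or create a new TiDB Serverless cluster and bind it to this tenant.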
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "",
dify_config.TIDB_API_URL or "",
dify_config.TIDB_IAM_API_URL or "",
dify_config.TIDB_PUBLIC_KEY or "",
dify_config.TIDB_PRIVATE_KEY or "",
dify_config.TIDB_REGION or "",
)
new_tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
tenant_id=dataset.tenant_id,
active=True,
status="ACTIVE",
)
db.session.add(new_tidb_auth_binding)
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name))
config = current_app.config
return TidbOnQdrantVector(
collection_name=collection_name,
group_id=dataset.id,
config=TidbOnQdrantConfig(
endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
api_key=TIDB_ON_QDRANT_API_KEY,
root_path=str(config.root_path),
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
replication_factor=dify_config.QDRANT_REPLICATION_FACTOR,
),
)
def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str):
"""
Creates a new TiDB Serverless cluster.
:param tidb_config: The configuration for the TiDB Cloud API.
:param display_name: The user-friendly display name of the cluster (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": "1372813089454548012",
}
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
response = httpx.post(
f"{tidb_config.api_url}/clusters",
json=cluster_data,
auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str):
"""
Changes the root password of a specific TiDB Serverless cluster.
:param tidb_config: The configuration for the TiDB Cloud API.
:param cluster_id: The ID of the cluster for which the password is to be changed (required).
:param new_password: The new password for the root user (required).
:return: The response from the API.
"""
body = {"password": new_password}
response = httpx.put(
f"{tidb_config.api_url}/clusters/{cluster_id}/password",
json=body,
auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
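
# Hedged usage sketch (illustrative only, not part of the upstream file). It assumes the
# two cluster helpers above are methods of TidbOnQdrantVectorFactory (as their `self`
# parameter suggests) and that TidbConfig only needs the api_url / public_key /
# private_key fields used above. All credentials and URLs below are placeholders.
if __name__ == "__main__":
    factory = TidbOnQdrantVectorFactory()
    demo_config = TidbConfig(
        api_url="https://api.tidbcloud.com/api/v1beta/projects/<project-id>",  # placeholder
        public_key="<public-key>",  # placeholder
        private_key="<private-key>",  # placeholder
    )
    cluster = factory.create_tidb_serverless_cluster(demo_config, display_name="dify-demo", region="regions/aws-us-east-1")
    # Assuming the API echoes the new cluster's ID back as "clusterId".
    factory.change_tidb_serverless_root_password(demo_config, cluster_id=cluster["clusterId"], new_password="<new-password>")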

View File

@@ -0,0 +1,247 @@
import time
import uuid
from collections.abc import Sequence
import httpx
from httpx import DigestAuth
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
class TidbService:
@staticmethod
def create_tidb_serverless_cluster(
project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
):
"""
Creates a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": project_id,
}
spending_limit = {
"monthly": dify_config.TIDB_SPEND_LIMIT,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")[:16]
cluster_data = {
"displayName": display_name,
"region": region_object,
"labels": labels,
"spendingLimit": spending_limit,
"rootPassword": password,
}
response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
cluster_id = response_data["clusterId"]
retry_count = 0
max_retries = 30
while retry_count < max_retries:
cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
if cluster_response["state"] == "ACTIVE":
user_prefix = cluster_response["userPrefix"]
return {
"cluster_id": cluster_id,
"cluster_name": display_name,
"account": f"{user_prefix}.root",
"password": password,
}
time.sleep(30) # wait 30 seconds before retrying
retry_count += 1
else:
response.raise_for_status()
@staticmethod
def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
"""
Deletes a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param cluster_id: The ID of the cluster to be deleted (required).
:return: The response from the API.
"""
response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
"""
        Gets the details of a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
        :param cluster_id: The ID of the cluster to retrieve (required).
:return: The response from the API.
"""
response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def change_tidb_serverless_root_password(
api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str
):
"""
Changes the root password of a specific TiDB Serverless cluster.
:param api_url: The URL of the TiDB Cloud API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
        :param cluster_id: The ID of the cluster for which the password is to be changed (required).
:param account: The account for which the password is to be changed (required).
:param new_password: The new password for the root user (required).
:return: The response from the API.
"""
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
response = httpx.patch(
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
json=body,
auth=DigestAuth(public_key, private_key),
)
if response.status_code == 200:
return response.json()
else:
response.raise_for_status()
@staticmethod
def batch_update_tidb_serverless_cluster_status(
tidb_serverless_list: Sequence[TidbAuthBinding],
project_id: str,
api_url: str,
iam_url: str,
public_key: str,
private_key: str,
):
"""
        Updates the status of the given TiDB Serverless clusters in batch.
:param tidb_serverless_list: The TiDB serverless list (required).
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:return: The response from the API.
"""
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"}
response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
for item in response_data["clusters"]:
state = item["state"]
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = "ACTIVE"
cluster_info.account = f"{userPrefix}.root"
db.session.add(cluster_info)
db.session.commit()
else:
response.raise_for_status()
@staticmethod
def batch_create_tidb_serverless_cluster(
batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
) -> list[dict]:
"""
        Creates multiple TiDB Serverless clusters in a single batch.
:param batch_size: The batch size (required).
:param project_id: The project ID of the TiDB Cloud project (required).
:param api_url: The URL of the TiDB Cloud API (required).
:param iam_url: The URL of the TiDB Cloud IAM API (required).
:param public_key: The public key for the API (required).
:param private_key: The private key for the API (required).
:param region: The region where the cluster will be created (required).
:return: The response from the API.
"""
clusters = []
for _ in range(batch_size):
region_object = {
"name": region,
}
labels = {
"tidb.cloud/project": project_id,
}
spending_limit = {
"monthly": dify_config.TIDB_SPEND_LIMIT,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")
cluster_data = {
"cluster": {
"displayName": display_name,
"region": region_object,
"labels": labels,
"spendingLimit": spending_limit,
"rootPassword": password,
}
}
cache_key = f"tidb_serverless_cluster_password:{display_name}"
redis_client.setex(cache_key, 3600, password)
clusters.append(cluster_data)
request_body = {"requests": clusters}
response = httpx.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key)
)
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
cached_password = redis_client.get(cache_key)
if not cached_password:
continue
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
"password": cached_password.decode("utf-8"),
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
response.raise_for_status()
return []
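
# Hedged usage sketch (illustrative only): provision a single serverless cluster with the
# static helpers above and read it back. Project ID, endpoints, keys and region are
# placeholders; create_tidb_serverless_cluster returns None if the cluster never reaches
# the ACTIVE state within its retry window.
if __name__ == "__main__":
    api_url = "https://serverless.tidbapi.com/v1beta1"  # placeholder endpoint
    cluster = TidbService.create_tidb_serverless_cluster(
        project_id="<project-id>",
        api_url=api_url,
        iam_url="https://iam.tidbapi.com/v1beta1",  # placeholder endpoint
        public_key="<public-key>",
        private_key="<private-key>",
        region="regions/aws-us-east-1",  # placeholder region
    )
    if cluster:
        detail = TidbService.get_tidb_serverless_cluster(
            api_url, "<public-key>", "<private-key>", cluster["cluster_id"]
        )
        print(detail["state"])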

View File

@@ -0,0 +1,276 @@
import json
import logging
from typing import Any
import sqlalchemy
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class TiDBVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
program_name: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["host"]:
raise ValueError("config TIDB_VECTOR_HOST is required")
if not values["port"]:
raise ValueError("config TIDB_VECTOR_PORT is required")
if not values["user"]:
raise ValueError("config TIDB_VECTOR_USER is required")
if not values["database"]:
raise ValueError("config TIDB_VECTOR_DATABASE is required")
if not values["program_name"]:
raise ValueError("config APPLICATION_NAME is required")
return values
class TiDBVector(BaseVector):
def get_type(self) -> str:
return VectorType.TIDB_VECTOR
def _table(self, dim: int) -> Table:
from tidb_vector.sqlalchemy import VectorType # type: ignore
return Table(
self._collection_name,
self._orm_base.metadata,
Column(Field.PRIMARY_KEY, String(36), primary_key=True, nullable=False),
Column(
Field.VECTOR,
VectorType(dim),
nullable=False,
),
Column(Field.TEXT_KEY, TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column(
"update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
),
extend_existing=True,
)
def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"):
super().__init__(collection_name)
self._client_config = config
self._url = (
f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}"
)
self._distance_func = distance_func.lower()
self._engine = create_engine(self._url)
self._orm_base = declarative_base()
self._dimension = 1536
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
logger.info("create collection and add texts, collection_name: %s", self._collection_name)
self._create_collection(len(embeddings[0]))
self.add_texts(texts, embeddings)
self._dimension = len(embeddings[0])
def _create_collection(self, dimension: int):
logger.info("_create_collection, collection_name %s", self._collection_name)
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
tidb_dist_func = self._get_distance_func()
with Session(self._engine) as session:
session.begin()
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id CHAR(36) PRIMARY KEY,
text TEXT NOT NULL,
meta JSON NOT NULL,
doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED,
document_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id'))) STORED,
vector VECTOR<FLOAT>({dimension}) NOT NULL,
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
KEY (doc_id),
KEY (document_id),
VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW
);
""")
session.execute(create_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
table = self._table(len(embeddings[0]))
ids = self._get_uuids(documents)
metas = [d.metadata for d in documents]
texts = [d.page_content for d in documents]
chunks_table_data = []
with self._engine.connect() as conn, conn.begin():
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(table).values(chunks_table_data))
return ids
def text_exists(self, id: str) -> bool:
result = self.get_ids_by_metadata_field("doc_id", id)
return bool(result)
def delete_by_ids(self, ids: list[str]):
with Session(self._engine) as session:
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """
)
result = session.execute(select_statement).fetchall()
if result:
ids = [item[0] for item in result]
self._delete_by_ids(ids)
def _delete_by_ids(self, ids: list[str]) -> bool:
if ids is None:
raise ValueError("No ids provided to delete.")
table = self._table(self._dimension)
try:
with self._engine.connect() as conn, conn.begin():
delete_condition = table.c.id.in_(ids)
conn.execute(table.delete().where(delete_condition))
return True
except Exception:
logger.exception("Delete operation failed for collection %s", self._collection_name)
return False
def get_ids_by_metadata_field(self, key: str, value: str):
with Session(self._engine) as session:
select_statement = sql_text(
f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.{key}' = '{value}'; """
)
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._delete_by_ids(ids)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
distance = 1 - score_threshold
query_vector_str = ", ".join(format(x) for x in query_vector)
query_vector_str = "[" + query_vector_str + "]"
logger.debug(
"_collection_name: %s, score_threshold: %s, distance: %s", self._collection_name, score_threshold, distance
)
docs = []
tidb_dist_func = self._get_distance_func()
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "
with Session(self._engine) as session:
select_statement = sql_text(f"""
SELECT meta, text, distance
FROM (
SELECT
meta,
text,
{tidb_dist_func}(vector, :query_vector_str) AS distance
FROM {self._collection_name}
{where_clause}
ORDER BY distance ASC
LIMIT :top_k
) t
WHERE distance <= :distance
""")
res = session.execute(
select_statement,
params={
"query_vector_str": query_vector_str,
"distance": distance,
"top_k": top_k,
},
)
results = [(row[0], row[1], row[2]) for row in res]
for meta, text, distance in results:
metadata = json.loads(meta)
metadata["score"] = 1 - distance
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# tidb doesn't support bm25 search
return []
def delete(self):
with Session(self._engine) as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit()
def _get_distance_func(self) -> str:
match self._distance_func:
case "l2":
tidb_dist_func = "VEC_L2_DISTANCE"
case "cosine":
tidb_dist_func = "VEC_COSINE_DISTANCE"
case _:
tidb_dist_func = "VEC_COSINE_DISTANCE"
return tidb_dist_func
class TiDBVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
return TiDBVector(
collection_name=collection_name,
config=TiDBVectorConfig(
host=dify_config.TIDB_VECTOR_HOST or "",
port=dify_config.TIDB_VECTOR_PORT or 0,
user=dify_config.TIDB_VECTOR_USER or "",
password=dify_config.TIDB_VECTOR_PASSWORD or "",
database=dify_config.TIDB_VECTOR_DATABASE or "",
program_name=dify_config.APPLICATION_NAME,
),
)
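
# Hedged usage sketch (illustrative only): builds TiDBVector by hand with placeholder
# connection settings and a toy 4-dimensional embedding. It needs a reachable TiDB
# instance with the vector feature enabled; production code goes through
# TiDBVectorFactory.init_vector with a Dataset and a real embedding model.
if __name__ == "__main__":
    demo_store = TiDBVector(
        collection_name="demo_collection",
        config=TiDBVectorConfig(
            host="127.0.0.1",  # placeholder host
            port=4000,
            user="root",
            password="",
            database="test",
            program_name="dify-demo",
        ),
    )
    demo_doc = Document(page_content="hello tidb", metadata={"doc_id": "doc-1", "document_id": "d-1"})
    demo_store.create([demo_doc], [[0.1, 0.2, 0.3, 0.4]])
    print(demo_store.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=1))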

View File

@@ -0,0 +1,143 @@
import json
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, model_validator
from upstash_vector import Index, Vector
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset
class UpstashVectorConfig(BaseModel):
url: str
token: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values["url"]:
raise ValueError("Upstash URL is required")
if not values["token"]:
raise ValueError("Upstash Token is required")
return values
class UpstashVector(BaseVector):
def __init__(self, collection_name: str, config: UpstashVectorConfig):
super().__init__(collection_name)
self._table_name = collection_name
self.index = Index(url=config.url, token=config.token)
def _get_index_dimension(self) -> int:
index_info = self.index.info()
if index_info and index_info.dimension:
return index_info.dimension
else:
return 1536
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
vectors = [
Vector(
id=str(uuid4()),
vector=embedding,
metadata=doc.metadata,
data=doc.page_content,
)
for doc, embedding in zip(documents, embeddings)
]
self.index.upsert(vectors=vectors)
def text_exists(self, id: str) -> bool:
response = self.get_ids_by_metadata_field("doc_id", id)
return len(response) > 0
def delete_by_ids(self, ids: list[str]):
        item_ids = []
        for doc_id in ids:
            matched_ids = self.get_ids_by_metadata_field("doc_id", doc_id)
            if matched_ids:
                item_ids += matched_ids
        self._delete_by_ids(ids=item_ids)
def _delete_by_ids(self, ids: list[str]):
if ids:
self.index.delete(ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
query_result = self.index.query(
vector=[1.001 * i for i in range(self._get_index_dimension())],
include_metadata=True,
top_k=1000,
filter=f"{key} = '{value}'",
)
return [result.id for result in query_result]
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._delete_by_ids(ids)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f"document_id in ({document_ids})"
else:
filter = ""
result = self.index.query(
vector=query_vector,
top_k=top_k,
include_metadata=True,
include_data=True,
include_vectors=False,
filter=filter,
)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in result:
metadata = record.metadata
text = record.data
score = record.score
if metadata is not None and text is not None:
metadata["score"] = score
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def delete(self):
self.index.reset()
def get_type(self) -> str:
return VectorType.UPSTASH
class UpstashVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> UpstashVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.UPSTASH, collection_name))
return UpstashVector(
collection_name=collection_name,
config=UpstashVectorConfig(
url=dify_config.UPSTASH_VECTOR_URL or "",
token=dify_config.UPSTASH_VECTOR_TOKEN or "",
),
)
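
# Hedged usage sketch (illustrative only): indexes one chunk with a toy 3-dimensional
# embedding and queries it back. The URL and token are placeholders for a real Upstash
# Vector index whose dimension matches the embedding length.
if __name__ == "__main__":
    demo_store = UpstashVector(
        collection_name="demo_collection",
        config=UpstashVectorConfig(url="https://example-index.upstash.io", token="<token>"),
    )
    demo_doc = Document(page_content="hello upstash", metadata={"doc_id": "doc-1", "document_id": "d-1"})
    demo_store.add_texts([demo_doc], [[0.1, 0.2, 0.3]])
    print(demo_store.search_by_vector([0.1, 0.2, 0.3], top_k=1))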

View File

@@ -0,0 +1,67 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name
@abstractmethod
def get_type(self) -> str:
raise NotImplementedError
@abstractmethod
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]):
raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str):
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str):
raise NotImplementedError
@abstractmethod
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
raise NotImplementedError
@abstractmethod
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
raise NotImplementedError
@abstractmethod
def delete(self):
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
if text.metadata and "doc_id" in text.metadata:
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]
@property
def collection_name(self):
return self._collection_name
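
# Hedged sketch (illustrative only, not shipped with Dify): a minimal in-memory subclass
# showing which methods a new backend must implement to satisfy this interface. The
# "in_memory" type name and the naive dot-product ranking are made up for the example.
class InMemoryVector(BaseVector):
    def __init__(self, collection_name: str):
        super().__init__(collection_name)
        self._rows: dict[str, tuple[list[float], Document]] = {}

    def get_type(self) -> str:
        return "in_memory"

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.add_texts(texts, embeddings, **kwargs)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        for doc, emb in zip(documents, embeddings):
            doc_id = (doc.metadata or {}).get("doc_id", str(len(self._rows)))
            self._rows[doc_id] = (emb, doc)

    def text_exists(self, id: str) -> bool:
        return id in self._rows

    def delete_by_ids(self, ids: list[str]):
        for doc_id in ids:
            self._rows.pop(doc_id, None)

    def delete_by_metadata_field(self, key: str, value: str):
        self._rows = {k: v for k, v in self._rows.items() if (v[1].metadata or {}).get(key) != value}

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        # Rank by dot product against the stored embeddings.
        scored = sorted(
            self._rows.values(),
            key=lambda row: sum(a * b for a, b in zip(row[0], query_vector)),
            reverse=True,
        )
        return [doc for _, doc in scored[: kwargs.get("top_k", 4)]]

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return [doc for _, doc in self._rows.values() if query in doc.page_content]

    def delete(self):
        self._rows.clear()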

View File

@@ -0,0 +1,265 @@
import logging
import time
from abc import ABC, abstractmethod
from typing import Any
from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Whitelist
logger = logging.getLogger(__name__)
class AbstractVectorFactory(ABC):
@abstractmethod
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
raise NotImplementedError
@staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
return index_struct_dict
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
self._vector_processor = self._init_vector()
def _init_vector(self) -> BaseVector:
vector_type = dify_config.VECTOR_STORE
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict["type"]
else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
stmt = select(Whitelist).where(
Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db"
)
whitelist = db.session.scalars(stmt).one_or_none()
if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT
if not vector_type:
raise ValueError("Vector store must be specified.")
vector_factory_cls = self.get_vector_factory(vector_type)
return vector_factory_cls().init_vector(self._dataset, self._attributes, self._embeddings)
@staticmethod
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
match vector_type:
case VectorType.CHROMA:
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
return ChromaVectorFactory
case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.ALIBABACLOUD_MYSQL:
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVectorFactory,
)
return AlibabaCloudMySQLVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
return MyScaleVectorFactory
case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
case VectorType.VASTBASE:
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory
return VastbaseVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
return PGVectoRSFactory
case VectorType.QDRANT:
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
return QdrantVectorFactory
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
return TiDBVectorFactory
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case VectorType.ORACLE:
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
return OracleVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory
case VectorType.ANALYTICDB:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case VectorType.COUCHBASE:
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
return CouchbaseVectorFactory
case VectorType.BAIDU:
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
return BaiduVectorFactory
case VectorType.VIKINGDB:
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
return VikingDBVectorFactory
case VectorType.UPSTASH:
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
return UpstashVectorFactory
case VectorType.TIDB_ON_QDRANT:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
return TidbOnQdrantVectorFactory
case VectorType.LINDORM:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
case VectorType.OCEANBASE:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory
case VectorType.OPENGAUSS:
from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
return OpenGaussFactory
case VectorType.TABLESTORE:
from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
return TableStoreVectorFactory
case VectorType.HUAWEI_CLOUD:
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
return HuaweiCloudVectorFactory
case VectorType.MATRIXONE:
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
return MatrixoneVectorFactory
case VectorType.CLICKZETTA:
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
return ClickzettaVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
def create(self, texts: list | None = None, **kwargs):
if texts:
start = time.time()
logger.info("start embedding %s texts %s", len(texts), start)
batch_size = 1000
            total_batches = (len(texts) + batch_size - 1) // batch_size
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
batch_start = time.time()
logger.info("Processing batch %s/%s (%s texts)", i // batch_size + 1, total_batches, len(batch))
batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch])
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]):
self._vector_processor.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str):
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self):
self._vector_processor.delete()
# delete collection redis cache
if self._vector_processor.collection_name:
collection_exist_cache_key = f"vector_indexing_{self._vector_processor.collection_name}"
redis_client.delete(collection_exist_cache_key)
def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model,
)
return CacheEmbedding(embedding_model)
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
if text.metadata is None:
continue
doc_id = text.metadata["doc_id"]
if doc_id:
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def __getattr__(self, name):
if self._vector_processor is not None:
method = getattr(self._vector_processor, name)
if callable(method):
return method
raise AttributeError(f"'vector_processor' object has no attribute '{name}'")
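
# Hedged usage sketch (illustrative only): the high-level entry point most callers use.
# It assumes an existing Dataset row with an embedding model configured, since Vector()
# resolves both the embedding model and the concrete vector store from the dataset.
if __name__ == "__main__":
    demo_dataset = db.session.query(Dataset).first()  # placeholder lookup for the example
    if demo_dataset:
        vector = Vector(demo_dataset)
        demo_doc = Document(
            page_content="hello dify",
            metadata={"doc_id": "doc-1", "doc_hash": "hash-1", "document_id": "d-1", "dataset_id": demo_dataset.id},
        )
        vector.add_texts([demo_doc], duplicate_check=True)
        print(vector.search_by_vector("hello", top_k=4))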

View File

@@ -0,0 +1,34 @@
from enum import StrEnum
class VectorType(StrEnum):
ALIBABACLOUD_MYSQL = "alibabacloud_mysql"
ANALYTICDB = "analyticdb"
CHROMA = "chroma"
MILVUS = "milvus"
MYSCALE = "myscale"
PGVECTOR = "pgvector"
VASTBASE = "vastbase"
PGVECTO_RS = "pgvecto-rs"
QDRANT = "qdrant"
RELYT = "relyt"
TIDB_VECTOR = "tidb_vector"
WEAVIATE = "weaviate"
OPENSEARCH = "opensearch"
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
ELASTICSEARCH_JA = "elasticsearch-ja"
LINDORM = "lindorm"
COUCHBASE = "couchbase"
BAIDU = "baidu"
VIKINGDB = "vikingdb"
UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"
OCEANBASE = "oceanbase"
OPENGAUSS = "opengauss"
TABLESTORE = "tablestore"
HUAWEI_CLOUD = "huawei_cloud"
MATRIXONE = "matrixone"
CLICKZETTA = "clickzetta"
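
# Hedged note (illustrative only): because VectorType is a StrEnum, members compare equal
# to their plain string values, which is what lets the factory match on the raw "type"
# string stored in dataset.index_struct_dict.
if __name__ == "__main__":
    assert VectorType.WEAVIATE == "weaviate"
    assert VectorType("tidb_on_qdrant") is VectorType.TIDB_ON_QDRANT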

View File

@@ -0,0 +1,244 @@
import json
from typing import Any
from pydantic import BaseModel
from volcengine.viking_db import ( # type: ignore
Data,
DistanceType,
Field,
FieldType,
IndexType,
QuantType,
VectorIndexParams,
VikingDBService,
)
from configs import dify_config
from core.rag.datasource.vdb.field import Field as vdb_Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class VikingDBConfig(BaseModel):
access_key: str
secret_key: str
host: str
region: str
scheme: str
connection_timeout: int
socket_timeout: int
index_type: str = str(IndexType.HNSW)
distance: str = str(DistanceType.L2)
quant: str = str(QuantType.Float)
class VikingDBVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: VikingDBConfig):
super().__init__(collection_name)
self._group_id = group_id
self._client_config = config
self._index_name = f"{self._collection_name}_idx"
self._client = VikingDBService(
host=config.host,
region=config.region,
scheme=config.scheme,
connection_timeout=config.connection_timeout,
socket_timeout=config.socket_timeout,
ak=config.access_key,
sk=config.secret_key,
)
def _has_collection(self) -> bool:
try:
self._client.get_collection(self._collection_name)
except Exception:
return False
return True
def _has_index(self) -> bool:
try:
self._client.get_index(self._collection_name, self._index_name)
except Exception:
return False
return True
def _create_collection(self, dimension: int):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
if not self._has_collection():
fields = [
Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=dimension),
]
self._client.create_collection(
collection_name=self._collection_name,
fields=fields,
description="Collection For Dify",
)
if not self._has_index():
vector_index = VectorIndexParams(
distance=self._client_config.distance,
index_type=self._client_config.index_type,
quant=self._client_config.quant,
)
self._client.create_index(
collection_name=self._collection_name,
index_name=self._index_name,
vector_index=vector_index,
partition_by=vdb_Field.GROUP_KEY,
description="Index For Dify",
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def get_type(self) -> str:
return VectorType.VIKINGDB
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
self.add_texts(texts, embeddings, **kwargs)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
page_contents = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
docs = []
for i, page_content in enumerate(page_contents):
metadata = {}
if metadatas is not None:
for key, val in (metadatas[i] or {}).items():
metadata[key] = val
# FIXME: fix the type of metadata later
doc = Data(
{
vdb_Field.PRIMARY_KEY: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY: page_content,
vdb_Field.METADATA_KEY: json.dumps(metadata),
vdb_Field.GROUP_KEY: self._group_id,
}
)
docs.append(doc)
self._client.get_collection(self._collection_name).upsert_data(docs)
def text_exists(self, id: str) -> bool:
docs = self._client.get_collection(self._collection_name).fetch_data(id)
not_exists_str = "data does not exist"
if docs is not None and not_exists_str not in docs.fields.get("message", ""):
return True
return False
def delete_by_ids(self, ids: list[str]):
self._client.get_collection(self._collection_name).delete_data(ids)
def get_ids_by_metadata_field(self, key: str, value: str):
        # Note: the metadata value is a dict, but VikingDB fields do not support a JSON
        # type, so metadata is stored as a JSON string and filtered client-side below.
results = self._client.get_index(self._collection_name, self._index_name).search(
filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]},
# max value is 5000
limit=5000,
)
if not results:
return []
ids = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if metadata.get(key) == value:
ids.append(result.id)
return ids
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
results = self._client.get_index(self._collection_name, self._index_name).search_by_vector(
query_vector, limit=kwargs.get("top_k", 4)
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
docs = self._get_search_res(results, score_threshold)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
return docs
def _get_search_res(self, results, score_threshold) -> list[Document]:
if len(results) == 0:
return []
docs = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata)
docs.append(doc)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def delete(self):
if self._has_index():
self._client.drop_index(self._collection_name, self._index_name)
if self._has_collection():
self._client.drop_collection(self._collection_name)
class VikingDBVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VikingDBVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VIKINGDB, collection_name))
if dify_config.VIKINGDB_ACCESS_KEY is None:
raise ValueError("VIKINGDB_ACCESS_KEY should not be None")
if dify_config.VIKINGDB_SECRET_KEY is None:
raise ValueError("VIKINGDB_SECRET_KEY should not be None")
if dify_config.VIKINGDB_HOST is None:
raise ValueError("VIKINGDB_HOST should not be None")
if dify_config.VIKINGDB_REGION is None:
raise ValueError("VIKINGDB_REGION should not be None")
if dify_config.VIKINGDB_SCHEME is None:
raise ValueError("VIKINGDB_SCHEME should not be None")
return VikingDBVector(
collection_name=collection_name,
group_id=dataset.id,
config=VikingDBConfig(
access_key=dify_config.VIKINGDB_ACCESS_KEY,
secret_key=dify_config.VIKINGDB_SECRET_KEY,
host=dify_config.VIKINGDB_HOST,
region=dify_config.VIKINGDB_REGION,
scheme=dify_config.VIKINGDB_SCHEME,
connection_timeout=dify_config.VIKINGDB_CONNECTION_TIMEOUT,
socket_timeout=dify_config.VIKINGDB_SOCKET_TIMEOUT,
),
)
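
# Hedged usage sketch (illustrative only): builds the config by hand with placeholder
# credentials instead of going through VikingDBVectorFactory and dify_config.
if __name__ == "__main__":
    demo_store = VikingDBVector(
        collection_name="demo_collection",
        group_id="demo-dataset-id",
        config=VikingDBConfig(
            access_key="<access-key>",  # placeholder
            secret_key="<secret-key>",  # placeholder
            host="api-vikingdb.volces.com",  # placeholder host
            region="cn-beijing",  # placeholder region
            scheme="https",
            connection_timeout=30,
            socket_timeout=30,
        ),
    )
    demo_doc = Document(page_content="hello vikingdb", metadata={"doc_id": "doc-1", "document_id": "d-1"})
    demo_store.create([demo_doc], [[0.1, 0.2, 0.3, 0.4]])
    print(demo_store.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=1))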

View File

@@ -0,0 +1,460 @@
"""
Weaviate vector database implementation for Dify's RAG system.
This module provides integration with Weaviate vector database for storing and retrieving
document embeddings used in retrieval-augmented generation workflows.
"""
import datetime
import json
import logging
import uuid as _uuid
from typing import Any
from urllib.parse import urlparse
import weaviate
import weaviate.classes.config as wc
from pydantic import BaseModel, model_validator
from weaviate.classes.data import DataObject
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter, MetadataQuery
from weaviate.exceptions import UnexpectedStatusCodeError
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class WeaviateConfig(BaseModel):
"""
Configuration model for Weaviate connection settings.
Attributes:
endpoint: Weaviate server endpoint URL
grpc_endpoint: Optional Weaviate gRPC server endpoint URL
api_key: Optional API key for authentication
batch_size: Number of objects to batch per insert operation
"""
endpoint: str
grpc_endpoint: str | None = None
api_key: str | None = None
batch_size: int = 100
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""Validates that required configuration values are present."""
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
"""
Weaviate vector database implementation for document storage and retrieval.
Handles creation, insertion, deletion, and querying of document embeddings
in a Weaviate collection.
"""
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
"""
Initializes the Weaviate vector store.
Args:
collection_name: Name of the Weaviate collection
config: Weaviate configuration settings
attributes: List of metadata attributes to store
"""
super().__init__(collection_name)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
"""
Initializes and returns a connected Weaviate client.
Configures both HTTP and gRPC connections with proper authentication.
"""
p = urlparse(config.endpoint)
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
# Parse gRPC configuration
if config.grpc_endpoint:
# Urls without scheme won't be parsed correctly in some python versions,
# see https://bugs.python.org/issue27657
grpc_endpoint_with_scheme = (
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
)
grpc_p = urlparse(grpc_endpoint_with_scheme)
grpc_host = grpc_p.hostname or "localhost"
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
grpc_secure = grpc_p.scheme == "grpcs"
else:
# Infer from HTTP endpoint as fallback
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client = weaviate.connect_to_custom(
http_host=host,
http_port=http_port,
http_secure=http_secure,
grpc_host=grpc_host,
grpc_port=grpc_port,
grpc_secure=grpc_secure,
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests
)
if not client.is_ready():
raise ConnectionError("Vector database is not ready")
return client
def get_type(self) -> str:
"""Returns the vector database type identifier."""
return VectorType.WEAVIATE
def get_collection_name(self, dataset: Dataset) -> str:
"""
Retrieves or generates the collection name for a dataset.
Uses existing index structure if available, otherwise generates from dataset ID.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
if not class_prefix.endswith("_Node"):
class_prefix += "_Node"
return class_prefix
dataset_id = dataset.id
return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self) -> dict:
"""Returns the index structure dictionary for persistence."""
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Creates a new collection and adds initial documents with embeddings.
"""
self._create_collection()
self.add_texts(texts, embeddings)
def _create_collection(self):
"""
Creates the Weaviate collection with required schema if it doesn't exist.
Uses Redis locking to prevent concurrent creation attempts.
"""
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(cache_key):
return
try:
if not self._client.collections.exists(self._collection_name):
tokenization = (
wc.Tokenization(dify_config.WEAVIATE_TOKENIZATION)
if dify_config.WEAVIATE_TOKENIZATION
else wc.Tokenization.WORD
)
self._client.collections.create(
name=self._collection_name,
properties=[
wc.Property(
name=Field.TEXT_KEY.value,
data_type=wc.DataType.TEXT,
tokenization=tokenization,
),
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
],
vector_config=wc.Configure.Vectors.self_provided(),
)
self._ensure_properties()
redis_client.set(cache_key, 1, ex=3600)
            except Exception:
logger.exception("Error creating collection %s", self._collection_name)
raise
def _ensure_properties(self) -> None:
"""
Ensures all required properties exist in the collection schema.
Adds missing properties if the collection exists but lacks them.
"""
if not self._client.collections.exists(self._collection_name):
return
col = self._client.collections.use(self._collection_name)
cfg = col.config.get()
existing = {p.name for p in (cfg.properties or [])}
to_add = []
if "document_id" not in existing:
to_add.append(wc.Property(name="document_id", data_type=wc.DataType.TEXT))
if "doc_id" not in existing:
to_add.append(wc.Property(name="doc_id", data_type=wc.DataType.TEXT))
if "chunk_index" not in existing:
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))
for prop in to_add:
try:
col.config.add_property(prop)
except Exception as e:
logger.warning("Could not add property %s: %s", prop.name, e)
def _get_uuids(self, documents: list[Document]) -> list[str]:
"""
Generates deterministic UUIDs for documents based on their content.
Uses UUID5 with URL namespace to ensure consistent IDs for identical content.
"""
URL_NAMESPACE = _uuid.UUID("6ba7b811-9dad-11d1-80b4-00c04fd430c8")
uuids = []
for doc in documents:
uuid_val = _uuid.uuid5(URL_NAMESPACE, doc.page_content)
uuids.append(str(uuid_val))
return uuids
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Adds documents with their embeddings to the collection.
Batches insertions for efficiency and returns the list of inserted object IDs.
"""
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
col = self._client.collections.use(self._collection_name)
objs: list[DataObject] = []
ids_out: list[str] = []
for i, text in enumerate(texts):
props: dict[str, Any] = {Field.TEXT_KEY.value: text}
meta = metadatas[i] or {}
for k, v in meta.items():
props[k] = self._json_serializable(v)
candidate = uuids[i] if uuids else None
uid = candidate if (candidate and self._is_uuid(candidate)) else str(_uuid.uuid4())
ids_out.append(uid)
vec_payload = None
if embeddings and i < len(embeddings) and embeddings[i]:
vec_payload = {"default": embeddings[i]}
objs.append(
DataObject(
uuid=uid,
properties=props, # type: ignore[arg-type] # mypy incorrectly infers DataObject signature
vector=vec_payload,
)
)
with col.batch.dynamic() as batch:
for obj in objs:
batch.add_object(properties=obj.properties, uuid=obj.uuid, vector=obj.vector)
return ids_out
def _is_uuid(self, val: str) -> bool:
"""Validates whether a string is a valid UUID format."""
try:
_uuid.UUID(str(val))
return True
except Exception:
return False
def delete_by_metadata_field(self, key: str, value: str) -> None:
"""Deletes all objects matching a specific metadata field value."""
if not self._client.collections.exists(self._collection_name):
return
col = self._client.collections.use(self._collection_name)
col.data.delete_many(where=Filter.by_property(key).equal(value))
def delete(self):
"""Deletes the entire collection from Weaviate."""
if self._client.collections.exists(self._collection_name):
self._client.collections.delete(self._collection_name)
def text_exists(self, id: str) -> bool:
"""Checks if a document with the given doc_id exists in the collection."""
if not self._client.collections.exists(self._collection_name):
return False
col = self._client.collections.use(self._collection_name)
res = col.query.fetch_objects(
filters=Filter.by_property("doc_id").equal(id),
limit=1,
return_properties=["doc_id"],
)
return len(res.objects) > 0
def delete_by_ids(self, ids: list[str]) -> None:
"""
Deletes objects by their UUID identifiers.
Silently ignores 404 errors for non-existent IDs.
"""
if not self._client.collections.exists(self._collection_name):
return
col = self._client.collections.use(self._collection_name)
for uid in ids:
try:
col.data.delete_by_id(uid)
except UnexpectedStatusCodeError as e:
if getattr(e, "status_code", None) != 404:
raise
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Performs vector similarity search using the provided query vector.
Filters by document IDs if provided and applies score threshold.
Returns documents sorted by relevance score.
"""
if not self._client.collections.exists(self._collection_name):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
res = col.query.near_vector(
near_vector=query_vector,
limit=top_k,
return_properties=props,
return_metadata=MetadataQuery(distance=True),
include_vector=False,
filters=where,
target_vector="default",
)
docs: list[Document] = []
for obj in res.objects:
properties = dict(obj.properties or {})
text = properties.pop(Field.TEXT_KEY.value, "")
if obj.metadata and obj.metadata.distance is not None:
distance = obj.metadata.distance
else:
distance = 1.0
score = 1.0 - distance
if score > score_threshold:
properties["score"] = score
docs.append(Document(page_content=text, metadata=properties))
docs.sort(key=lambda d: d.metadata.get("score", 0.0), reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""
Performs BM25 full-text search on document content.
Filters by document IDs if provided and returns matching documents with vectors.
"""
if not self._client.collections.exists(self._collection_name):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
top_k = int(kwargs.get("top_k", 4))
res = col.query.bm25(
query=query,
query_properties=[Field.TEXT_KEY.value],
limit=top_k,
return_properties=props,
include_vector=True,
filters=where,
)
docs: list[Document] = []
for obj in res.objects:
properties = dict(obj.properties or {})
text = properties.pop(Field.TEXT_KEY.value, "")
vec = obj.vector
if isinstance(vec, dict):
vec = vec.get("default") or next(iter(vec.values()), None)
docs.append(Document(page_content=text, vector=vec, metadata=properties))
return docs
def _json_serializable(self, value: Any) -> Any:
"""Converts values to JSON-serializable format, handling datetime objects."""
if isinstance(value, datetime.datetime):
return value.isoformat()
return value
class WeaviateVectorFactory(AbstractVectorFactory):
"""Factory class for creating WeaviateVector instances."""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
"""
Initializes a WeaviateVector instance for the given dataset.
Uses existing collection name from dataset index structure or generates a new one.
Updates dataset index structure if not already set.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
grpc_endpoint=dify_config.WEAVIATE_GRPC_ENDPOINT or "",
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
attributes=attributes,
)
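
# Hedged usage sketch (illustrative only): connects to a local Weaviate with placeholder
# endpoints, indexes one chunk with a toy embedding, then runs a vector query and a BM25
# query. A reachable Weaviate instance is required.
if __name__ == "__main__":
    demo_store = WeaviateVector(
        collection_name="Vector_index_demo_Node",
        config=WeaviateConfig(
            endpoint="http://localhost:8080",  # placeholder endpoint
            grpc_endpoint="localhost:50051",  # placeholder endpoint
            api_key=None,
        ),
        attributes=["doc_id", "dataset_id", "document_id", "doc_hash"],
    )
    demo_doc = Document(page_content="hello weaviate", metadata={"doc_id": "doc-1", "document_id": "d-1"})
    demo_store.create([demo_doc], [[0.1, 0.2, 0.3]])
    print(demo_store.search_by_vector([0.1, 0.2, 0.3], top_k=1))
    print(demo_store.search_by_full_text("hello", top_k=1))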