dify

0    dify/api/core/rag/splitter/__init__.py    Normal file
151  dify/api/core/rag/splitter/fixed_text_splitter.py    Normal file
@@ -0,0 +1,151 @@
"""Functionality for splitting text."""

from __future__ import annotations

import re
from typing import Any

from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from core.rag.splitter.text_splitter import (
    TS,
    Collection,
    Literal,
    RecursiveCharacterTextSplitter,
    Set,
    Union,
)


class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """
    Implements from_encoder so that splitters can be built without depending on tiktoken.
    """

    @classmethod
    def from_encoder(
        cls: type[TS],
        embedding_model_instance: ModelInstance | None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),  # noqa: UP037
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",  # noqa: UP037
        **kwargs: Any,
    ):
        def _token_encoder(texts: list[str]) -> list[int]:
            if not texts:
                return []

            if embedding_model_instance:
                return embedding_model_instance.get_text_embedding_num_tokens(texts=texts)
            else:
                return [GPT2Tokenizer.get_num_tokens(text) for text in texts]

        def _character_encoder(texts: list[str]) -> list[int]:
            if not texts:
                return []

            return [len(text) for text in texts]

        # Note: lengths are measured in characters; _token_encoder is defined
        # above but not wired into the returned splitter.
        return cls(length_function=_character_encoder, **kwargs)


class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
    def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""]

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        if self._fixed_separator:
            chunks = text.split(self._fixed_separator)
        else:
            chunks = [text]

        final_chunks = []
        chunks_lengths = self._length_function(chunks)
        for chunk, chunk_length in zip(chunks, chunks_lengths):
            if chunk_length > self._chunk_size:
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)

        return final_chunks

    def recursive_split_text(self, text: str) -> list[str]:
        """Recursively split an over-long chunk and return the resulting chunks."""
        final_chunks = []
        separator = self._separators[-1]
        new_separators = []

        # Pick the first separator that actually occurs in the text.
        for i, _s in enumerate(self._separators):
            if _s == "":
                separator = _s
                break
            if _s in text:
                separator = _s
                new_separators = self._separators[i + 1 :]
                break

        # Now that we have the separator, split the text
        if separator:
            if separator == " ":
                splits = re.split(r" +", text)
            else:
                splits = text.split(separator)
            # Re-attach the separator to every piece except the last.
            splits = [item + separator if i < len(splits) - 1 else item for i, item in enumerate(splits)]
        else:
            splits = list(text)
        if separator == "\n":
            splits = [s for s in splits if s != ""]
        else:
            splits = [s for s in splits if s not in {"", "\n"}]
        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
        _separator = separator if self._keep_separator else ""
        s_lens = self._length_function(splits)
        if separator != "":
            for s, s_len in zip(splits, s_lens):
                if s_len < self._chunk_size:
                    _good_splits.append(s)
                    _good_splits_lengths.append(s_len)
                else:
                    if _good_splits:
                        merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                        final_chunks.extend(merged_text)
                        _good_splits = []
                        _good_splits_lengths = []
                    if not new_separators:
                        final_chunks.append(s)
                    else:
                        other_info = self._split_text(s, new_separators)
                        final_chunks.extend(other_info)

            if _good_splits:
                merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                final_chunks.extend(merged_text)
        else:
            # Character-level fallback: pack characters into chunks, carrying
            # up to chunk_overlap characters into the start of the next chunk.
            current_part = ""
            current_length = 0
            overlap_part = ""
            overlap_part_length = 0
            for s, s_len in zip(splits, s_lens):
                if current_length + s_len <= self._chunk_size - self._chunk_overlap:
                    current_part += s
                    current_length += s_len
                elif current_length + s_len <= self._chunk_size:
                    current_part += s
                    current_length += s_len
                    overlap_part += s
                    overlap_part_length += s_len
                else:
                    final_chunks.append(current_part)
                    current_part = overlap_part + s
                    current_length = s_len + overlap_part_length
                    overlap_part = ""
                    overlap_part_length = 0
            if current_part:
                final_chunks.append(current_part)

        return final_chunks
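
A minimal usage sketch for the file above (an editor's illustration, not part of the commit; it assumes dify's `api` directory is on the import path, and the sample text and sizes are made up). Note that `from_encoder` wires in the character-based length function, so `chunk_size` here counts characters:

    from core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter

    splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
        embedding_model_instance=None,
        fixed_separator="\n\n",   # primary split point
        chunk_size=50,            # measured in characters via _character_encoder
        chunk_overlap=10,
    )

    text = "Short paragraph.\n\nA much longer paragraph that exceeds fifty characters and therefore gets re-split on the fallback separators."
    for chunk in splitter.split_text(text):
        print(repr(chunk))

Paragraphs at or under `chunk_size` pass through unchanged; longer ones go through `recursive_split_text`, which tries the fallback separators in order before packing character by character.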
297  dify/api/core/rag/splitter/text_splitter.py    Normal file
@@ -0,0 +1,297 @@
from __future__ import annotations

import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import (
    Any,
    Literal,
    TypeVar,
    Union,
)

from core.rag.models.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({re.escape(separator)})", text)
            splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 != 0:
                splits += _splits[-1:]
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s not in {"", "\n"}]
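
# Editor's illustration (not part of the original commit):
#   _split_text_with_regex("a. b. c", ". ", keep_separator=True)
#     -> ["a. ", "b. ", "c"]   (delimiters stay attached to the preceding piece)
#   _split_text_with_regex("a. b. c", ". ", keep_separator=False)
#     -> ["a", "b", "c"]       (separator is treated as a regex and dropped)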


class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[list[str]], list[int]] = lambda texts: [len(t) for t in texts],
        keep_separator: bool = False,
        add_start_index: bool = False,
    ):
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the lengths of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes each chunk's start index in its metadata
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size ({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata or {})
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> str | None:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function([separator])[0]

        docs = []
        current_doc: list[str] = []
        total = 0
        for d, _len in zip(splits, lengths):
            if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s", total, self._chunk_size
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
                    ):
                        total -= self._length_function([current_doc[0]])[0] + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
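
    # Editor's illustration (not part of the original commit): with chunk_size=10,
    # chunk_overlap=3, character lengths, and separator " ", merging the splits
    # ["abc", "def", "ghi"] first emits "abc def", then re-uses "def" as overlap,
    # giving ["abc def", "def ghi"].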

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses a HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. Please install it with `pip install transformers`."
            )
        return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)

    def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Transform a sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        raise NotImplementedError


# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        # Advance by chunk size minus overlap so consecutive chunks share
        # `chunk_overlap` tokens.
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
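
# Editor's sketch (not part of the original commit): a toy character-level
# tokenizer makes the windowing easy to see. With tokens_per_chunk=4 and
# chunk_overlap=1, each window advances by 3 tokens:
#   toy = Tokenizer(
#       chunk_overlap=1,
#       tokens_per_chunk=4,
#       decode=lambda ids: "".join(chr(i) for i in ids),
#       encode=lambda s: [ord(c) for c in s],
#   )
#   split_text_on_tokens(text="abcdefgh", tokenizer=toy)  # -> ["abcd", "defg", "gh"]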


class TokenTextSplitter(TextSplitter):
    """Splits text into tokens using a model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: str | None = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order for TokenTextSplitter to work. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
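
    # Editor's sketch (not part of the original commit; requires tiktoken):
    #   splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=100, chunk_overlap=20)
    #   chunks = splitter.split_text(long_text)  # 100-token windows sharing 20 tokens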


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splits text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: list[str] | None = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ):
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        final_chunks = []
        separator = separators[-1]
        new_separators = []

        # Pick the first separator (treated as a regex) found in the text.
        for i, _s in enumerate(separators):
            if _s == "":
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)
        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
        _separator = "" if self._keep_separator else separator
        s_lens = self._length_function(splits)
        for s, s_len in zip(splits, s_lens):
            if s_len < self._chunk_size:
                _good_splits.append(s)
                _good_splits_lengths.append(s_len)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                    _good_splits_lengths = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)

        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
            final_chunks.extend(merged_text)

        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)
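
A short usage sketch for the recursive splitter (an editor's illustration, not part of the commit; it assumes dify's `api` directory is importable and uses made-up sizes). The default length function counts characters, so a small chunk_size keeps the example readable:

    from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=0)
    text = "one two three\n\nfour five six seven eight nine"
    print(splitter.split_text(text))

The splitter first splits on "\n\n"; any piece still longer than 20 characters is re-split on "\n", then " ", then character by character, and the small pieces are merged back toward the chunk size by _merge_splits.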