Initial commit: AI 知识库文档智能分块工具
This commit is contained in:
218
chunker.py
Normal file
218
chunker.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""AI 分块器,通过 DeepSeek API 进行语义级智能分块"""
|
||||
|
||||
import re
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from api_client import ApiClient
|
||||
from exceptions import ApiError
|
||||
from models import Chunk
|
||||
from prompts import get_system_prompt, get_user_prompt, CONTENT_TYPE_DOCUMENT, CONTENT_TYPE_IMAGE
|
||||
|
||||
# Matches a title of the form "[tag] heading": group 1 captures the tag
# name (non-greedy, so nested brackets stop early), group 2 the heading.
_TAG_PATTERN = re.compile(r"^\[(.+?)\]\s*(.+)$")
|
||||
|
||||
|
||||
class AIChunker:
    """Semantic chunker that delegates splitting to the DeepSeek API.

    Long texts are pre-split on paragraph boundaries so that each API
    call stays within ``PRE_SPLIT_SIZE`` characters; the model's response
    is parsed back into :class:`Chunk` objects.
    """

    # Default pre-split size in characters, sized for Chinese documents.
    DEFAULT_PRE_SPLIT_SIZE = 12000

    def __init__(
        self,
        api_client: ApiClient,
        delimiter: str = "---",
        pre_split_size: Optional[int] = None,
    ):
        """
        Args:
            api_client: Client used to call the DeepSeek API.
            delimiter: Separator line the model is asked to emit between chunks.
            pre_split_size: Maximum characters per API call. Falls back to
                DEFAULT_PRE_SPLIT_SIZE when None. The ``or`` (rather than an
                ``is None`` test) is deliberate: it also rejects 0, which
                would otherwise make ``_hard_split`` loop forever.
        """
        self._api_client = api_client
        self._delimiter = delimiter
        self.PRE_SPLIT_SIZE = pre_split_size or self.DEFAULT_PRE_SPLIT_SIZE

    def chunk(
        self,
        text: str,
        content_type: str = CONTENT_TYPE_DOCUMENT,
        source_file: str = "",
        on_progress: Optional[Callable[[int, int], None]] = None,
    ) -> List[Chunk]:
        """Split *text* into semantic chunks via the AI model.

        Image-type content is returned directly as a single chunk (it is
        usually short and needs no further splitting). Text longer than
        PRE_SPLIT_SIZE is first pre-split on paragraph boundaries, and
        each segment is sent to the API separately.

        Args:
            text: Text to be chunked.
            content_type: Content type (document/table/qa/image).
            source_file: Source file name (helps the AI pick a business tag).
            on_progress: Optional progress callback, invoked as
                ``on_progress(current, total)`` after each segment.

        Returns:
            List of Chunk objects.

        Raises:
            ApiError: Propagated from ``_call_api`` when the API response
                is empty or unparseable.
        """
        # Image content: single chunk, tag inferred from the file name.
        if content_type == CONTENT_TYPE_IMAGE:
            tag = self._infer_image_tag(source_file)
            return [Chunk(title="图片内容识别", content=text.strip(), tag=tag)]

        # Short texts (< 800 chars) go to the API in one call to keep the
        # model from over-splitting them.
        if len(text.strip()) < 800:
            return self._call_api(text, content_type, source_file)

        if len(text) <= self.PRE_SPLIT_SIZE:
            return self._call_api(text, content_type, source_file)

        segments = self._pre_split(text)
        total = len(segments)
        all_chunks: List[Chunk] = []

        for i, segment in enumerate(segments, start=1):
            all_chunks.extend(self._call_api(segment, content_type, source_file))
            if on_progress is not None:
                on_progress(i, total)

        return all_chunks

    @staticmethod
    def _infer_image_tag(source_file: str) -> str:
        """Infer a business tag from an image file name.

        Strips the extension and any trailing digit suffix (e.g.
        "CC套装2" → "CC套装") and uses the remainder as the tag, falling
        back to the generic "产品图片" when nothing usable remains.
        """
        if not source_file:
            return "产品图片"
        import os

        name = os.path.splitext(source_file)[0]
        # Drop a trailing number suffix (e.g. "CC套装2" → "CC套装").
        name = re.sub(r"\d+$", "", name).strip()
        # Bug fix: the original returned "产品图片" in BOTH branches of its
        # conditional, silently discarding the name it had just derived.
        return name if name else "产品图片"

    def _pre_split(self, text: str) -> List[str]:
        """Pre-split *text* on paragraph boundaries into segments of at
        most PRE_SPLIT_SIZE characters.

        Strategy:
            1. Split into paragraphs on blank lines.
            2. Paragraphs still over the limit are split on single newlines.
            3. Single lines still over the limit are hard-split by
               character count, preferring sentence boundaries.
            4. Greedily re-merge the pieces (rejoined with blank lines)
               up to PRE_SPLIT_SIZE per segment.
        """
        paragraphs = text.split("\n\n")

        # Expand over-long paragraphs line by line.
        lines: List[str] = []
        for para in paragraphs:
            if len(para) <= self.PRE_SPLIT_SIZE:
                lines.append(para)
            else:
                for line in para.split("\n"):
                    if len(line) <= self.PRE_SPLIT_SIZE:
                        lines.append(line)
                    else:
                        # Hard-split a single over-long line.
                        lines.extend(self._hard_split(line))

        # Greedy merge, rejoining pieces with a blank line.
        segments: List[str] = []
        current = ""

        for line in lines:
            candidate = f"{current}\n\n{line}" if current else line
            if len(candidate) <= self.PRE_SPLIT_SIZE:
                current = candidate
            else:
                if current:
                    segments.append(current)
                current = line

        if current:
            segments.append(current)

        return segments

    def _hard_split(self, text: str) -> List[str]:
        """Hard-split over-long text by character count, preferring to cut
        at sentence boundaries."""
        result: List[str] = []
        remaining = text

        while len(remaining) > self.PRE_SPLIT_SIZE:
            window = remaining[: self.PRE_SPLIT_SIZE]
            # Search backwards in the window for sentence-ending punctuation.
            cut = self._find_sentence_boundary(window)
            result.append(remaining[:cut])
            remaining = remaining[cut:]

        if remaining:
            result.append(remaining)

        return result

    @staticmethod
    def _find_sentence_boundary(text: str) -> int:
        """Return the index just past the last sentence ending in *text*.

        Searches backwards from the end, but keeps at least half of the
        text; returns ``len(text)`` when no boundary is found, so the
        caller cuts at the window edge.
        """
        sentence_endings = ("。", "!", "?", ".", "!", "?", "\n")
        # Never cut away more than half the window.
        search_start = len(text) // 2
        for i in range(len(text) - 1, search_start - 1, -1):
            if text[i] in sentence_endings:
                return i + 1
        # No sentence boundary found: cut exactly at PRE_SPLIT_SIZE.
        return len(text)

    def _call_api(
        self,
        text_segment: str,
        content_type: str = CONTENT_TYPE_DOCUMENT,
        source_file: str = "",
    ) -> List[Chunk]:
        """Run one DeepSeek API call over a single text segment and parse
        the response into chunks."""
        system_prompt = get_system_prompt(self._delimiter, content_type=content_type)
        user_prompt = get_user_prompt(text_segment, source_file=source_file)
        response = self._api_client.chat(system_prompt, user_prompt)
        return self._parse_response(response)

    def _parse_response(self, response: str) -> List[Chunk]:
        """Parse the API response text into a list of chunks.

        The response is split only where the delimiter occupies a whole
        line of its own (a plain substring split would collide with the
        ``---`` inside Markdown table syntax such as ``| --- | --- |``).
        Within each part, the first non-empty line is the title, with an
        optional "[tag]" prefix; the remaining lines are the chunk body.

        Raises:
            ApiError: If the response is empty or yields no valid chunks.
        """
        if not response or not response.strip():
            raise ApiError("API 返回空响应")

        # Pad with newlines so a delimiter on the very first/last line of
        # the response still matches the line-on-its-own pattern.
        padded = f"\n{response}\n"
        parts = padded.split(f"\n{self._delimiter}\n")
        chunks: List[Chunk] = []

        for part in parts:
            part = part.strip()
            if not part:
                continue

            lines = part.split("\n")
            # First non-empty line becomes the title.
            title = ""
            content_start = 0
            for j, line in enumerate(lines):
                if line.strip():
                    title = line.strip()
                    content_start = j + 1
                    break

            if not title:
                continue

            # Extract an optional "[tag]" prefix from the title.
            tag = ""
            match = _TAG_PATTERN.match(title)
            if match:
                tag = match.group(1)
                title = match.group(2)

            # Remaining lines form the chunk body.
            content = "\n".join(lines[content_start:]).strip()
            chunks.append(Chunk(title=title, content=content, tag=tag))

        if not chunks:
            raise ApiError(f"无法解析 API 响应为有效分块: {response[:200]}")

        return chunks
|
||||
Reference in New Issue
Block a user