219 lines
7.9 KiB
Python
219 lines
7.9 KiB
Python
|
|
"""AI 分块器,通过 DeepSeek API 进行语义级智能分块"""
|
|||
|
|
|
|||
|
|
import re
|
|||
|
|
from typing import Callable, List, Optional
|
|||
|
|
|
|||
|
|
from api_client import ApiClient
|
|||
|
|
from exceptions import ApiError
|
|||
|
|
from models import Chunk
|
|||
|
|
from prompts import get_system_prompt, get_user_prompt, CONTENT_TYPE_DOCUMENT, CONTENT_TYPE_IMAGE
|
|||
|
|
|
|||
|
|
# Matches the "[tag] title" heading format emitted by the model.
_TAG_PATTERN = re.compile(r"^\[(.+?)\]\s*(.+)$")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AIChunker:
    """Semantic-level intelligent chunking via the DeepSeek API."""

    # Default pre-split size in characters; sized for Chinese documents.
    DEFAULT_PRE_SPLIT_SIZE = 12000
|
|||
|
|
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
api_client: ApiClient,
|
|||
|
|
delimiter: str = "---",
|
|||
|
|
pre_split_size: int = None,
|
|||
|
|
):
|
|||
|
|
self._api_client = api_client
|
|||
|
|
self._delimiter = delimiter
|
|||
|
|
self.PRE_SPLIT_SIZE = pre_split_size or self.DEFAULT_PRE_SPLIT_SIZE
|
|||
|
|
|
|||
|
|
def chunk(
|
|||
|
|
self,
|
|||
|
|
text: str,
|
|||
|
|
content_type: str = CONTENT_TYPE_DOCUMENT,
|
|||
|
|
source_file: str = "",
|
|||
|
|
on_progress: Optional[Callable[[int, int], None]] = None,
|
|||
|
|
) -> List[Chunk]:
|
|||
|
|
"""
|
|||
|
|
对文本进行 AI 语义分块。
|
|||
|
|
|
|||
|
|
图片类内容直接作为单个分块返回(通常很短,不需要再分块)。
|
|||
|
|
若文本超过 PRE_SPLIT_SIZE,先按段落边界预切分,再逐段调用 API。
|
|||
|
|
on_progress 回调用于报告进度,签名: (current: int, total: int) -> None
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
text: 待分块的文本
|
|||
|
|
content_type: 内容类型(document/table/qa/image)
|
|||
|
|
source_file: 源文件名(帮助 AI 判断业务标签)
|
|||
|
|
on_progress: 进度回调
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
Chunk 列表
|
|||
|
|
"""
|
|||
|
|
# 图片类内容通常很短,直接作为单个分块,标签从文件名推断
|
|||
|
|
if content_type == CONTENT_TYPE_IMAGE:
|
|||
|
|
tag = self._infer_image_tag(source_file)
|
|||
|
|
return [Chunk(title="图片内容识别", content=text.strip(), tag=tag)]
|
|||
|
|
|
|||
|
|
# 小文本(< 800 字符)直接作为单个分块,避免 AI 过度拆分
|
|||
|
|
if len(text.strip()) < 800:
|
|||
|
|
return self._call_api(text, content_type, source_file)
|
|||
|
|
|
|||
|
|
if len(text) <= self.PRE_SPLIT_SIZE:
|
|||
|
|
return self._call_api(text, content_type, source_file)
|
|||
|
|
|
|||
|
|
segments = self._pre_split(text)
|
|||
|
|
total = len(segments)
|
|||
|
|
all_chunks: List[Chunk] = []
|
|||
|
|
|
|||
|
|
for i, segment in enumerate(segments, start=1):
|
|||
|
|
chunks = self._call_api(segment, content_type, source_file)
|
|||
|
|
all_chunks.extend(chunks)
|
|||
|
|
if on_progress is not None:
|
|||
|
|
on_progress(i, total)
|
|||
|
|
|
|||
|
|
return all_chunks
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _infer_image_tag(source_file: str) -> str:
|
|||
|
|
"""从图片文件名推断业务标签。"""
|
|||
|
|
if not source_file:
|
|||
|
|
return "产品图片"
|
|||
|
|
# 去掉扩展名和数字后缀,作为标签参考
|
|||
|
|
import os
|
|||
|
|
name = os.path.splitext(source_file)[0]
|
|||
|
|
# 去掉末尾数字(如 "CC套装2" → "CC套装")
|
|||
|
|
name = re.sub(r"\d+$", "", name).strip()
|
|||
|
|
return "产品图片" if not name else "产品图片"
|
|||
|
|
|
|||
|
|
def _pre_split(self, text: str) -> List[str]:
|
|||
|
|
"""按段落边界预切分文本,每段不超过 PRE_SPLIT_SIZE。
|
|||
|
|
|
|||
|
|
策略:
|
|||
|
|
1. 按双换行符分割为段落列表
|
|||
|
|
2. 贪心合并段落,使每段不超过 PRE_SPLIT_SIZE
|
|||
|
|
3. 若单个段落超过 PRE_SPLIT_SIZE,按单换行符进一步切分
|
|||
|
|
4. 若单行仍超限,按字符数硬切分(保留最后一个完整句子)
|
|||
|
|
"""
|
|||
|
|
paragraphs = text.split("\n\n")
|
|||
|
|
# 展开超长段落:按单换行符进一步切分
|
|||
|
|
lines: List[str] = []
|
|||
|
|
for para in paragraphs:
|
|||
|
|
if len(para) <= self.PRE_SPLIT_SIZE:
|
|||
|
|
lines.append(para)
|
|||
|
|
else:
|
|||
|
|
# 按单换行符切分
|
|||
|
|
sub_lines = para.split("\n")
|
|||
|
|
for line in sub_lines:
|
|||
|
|
if len(line) <= self.PRE_SPLIT_SIZE:
|
|||
|
|
lines.append(line)
|
|||
|
|
else:
|
|||
|
|
# 硬切分超长单行
|
|||
|
|
lines.extend(self._hard_split(line))
|
|||
|
|
|
|||
|
|
# 贪心合并,用双换行符重新连接段落
|
|||
|
|
segments: List[str] = []
|
|||
|
|
current = ""
|
|||
|
|
|
|||
|
|
for line in lines:
|
|||
|
|
candidate = f"{current}\n\n{line}" if current else line
|
|||
|
|
if len(candidate) <= self.PRE_SPLIT_SIZE:
|
|||
|
|
current = candidate
|
|||
|
|
else:
|
|||
|
|
if current:
|
|||
|
|
segments.append(current)
|
|||
|
|
current = line
|
|||
|
|
|
|||
|
|
if current:
|
|||
|
|
segments.append(current)
|
|||
|
|
|
|||
|
|
return segments
|
|||
|
|
|
|||
|
|
def _hard_split(self, text: str) -> List[str]:
|
|||
|
|
"""按字符数硬切分超长文本,尽量在句子边界切分。"""
|
|||
|
|
result: List[str] = []
|
|||
|
|
remaining = text
|
|||
|
|
|
|||
|
|
while len(remaining) > self.PRE_SPLIT_SIZE:
|
|||
|
|
chunk = remaining[: self.PRE_SPLIT_SIZE]
|
|||
|
|
# 尝试在句子边界切分(从后往前找句号等标点)
|
|||
|
|
cut = self._find_sentence_boundary(chunk)
|
|||
|
|
result.append(remaining[:cut])
|
|||
|
|
remaining = remaining[cut:]
|
|||
|
|
|
|||
|
|
if remaining:
|
|||
|
|
result.append(remaining)
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _find_sentence_boundary(text: str) -> int:
|
|||
|
|
"""在文本中从后往前查找最后一个完整句子的边界。"""
|
|||
|
|
sentence_endings = ("。", "!", "?", ".", "!", "?", "\n")
|
|||
|
|
# 从末尾往前搜索,至少保留一半内容
|
|||
|
|
search_start = len(text) // 2
|
|||
|
|
for i in range(len(text) - 1, search_start - 1, -1):
|
|||
|
|
if text[i] in sentence_endings:
|
|||
|
|
return i + 1
|
|||
|
|
# 找不到句子边界,直接在 PRE_SPLIT_SIZE 处切分
|
|||
|
|
return len(text)
|
|||
|
|
|
|||
|
|
def _call_api(self, text_segment: str, content_type: str = CONTENT_TYPE_DOCUMENT, source_file: str = "") -> List[Chunk]:
|
|||
|
|
"""调用 DeepSeek API 对单段文本进行语义分块。"""
|
|||
|
|
system_prompt = get_system_prompt(self._delimiter, content_type=content_type)
|
|||
|
|
user_prompt = get_user_prompt(text_segment, source_file=source_file)
|
|||
|
|
response = self._api_client.chat(system_prompt, user_prompt)
|
|||
|
|
return self._parse_response(response)
|
|||
|
|
|
|||
|
|
def _parse_response(self, response: str) -> List[Chunk]:
|
|||
|
|
"""解析 API 返回的分块结果。
|
|||
|
|
|
|||
|
|
按 delimiter 独占一行 分割响应文本,提取业务标签、摘要标题和分块内容。
|
|||
|
|
标题格式:[标签名] 摘要标题
|
|||
|
|
解析失败时抛出 ApiError。
|
|||
|
|
|
|||
|
|
注意:使用正则匹配 delimiter 独占一行的情况,避免与 Markdown 表格
|
|||
|
|
语法 `| --- | --- |` 中的 `---` 冲突。
|
|||
|
|
"""
|
|||
|
|
if not response or not response.strip():
|
|||
|
|
raise ApiError("API 返回空响应")
|
|||
|
|
|
|||
|
|
# 只匹配 delimiter 独占一行的情况,避免与 Markdown 表格 | --- | 冲突
|
|||
|
|
padded = f"\n{response}\n"
|
|||
|
|
parts = padded.split(f"\n{self._delimiter}\n")
|
|||
|
|
chunks: List[Chunk] = []
|
|||
|
|
|
|||
|
|
for part in parts:
|
|||
|
|
part = part.strip()
|
|||
|
|
if not part:
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
lines = part.split("\n")
|
|||
|
|
# 找到第一个非空行作为标题
|
|||
|
|
title = ""
|
|||
|
|
content_start = 0
|
|||
|
|
for j, line in enumerate(lines):
|
|||
|
|
if line.strip():
|
|||
|
|
title = line.strip()
|
|||
|
|
content_start = j + 1
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
if not title:
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
# 提取 [标签] 前缀
|
|||
|
|
tag = ""
|
|||
|
|
match = _TAG_PATTERN.match(title)
|
|||
|
|
if match:
|
|||
|
|
tag = match.group(1)
|
|||
|
|
title = match.group(2)
|
|||
|
|
|
|||
|
|
# 剩余内容作为分块正文
|
|||
|
|
content = "\n".join(lines[content_start:]).strip()
|
|||
|
|
chunks.append(Chunk(title=title, content=content, tag=tag))
|
|||
|
|
|
|||
|
|
if not chunks:
|
|||
|
|
raise ApiError(f"无法解析 API 响应为有效分块: {response[:200]}")
|
|||
|
|
|
|||
|
|
return chunks
|