"""AI 分块器,通过 DeepSeek API 进行语义级智能分块"""
|
||
|
||
import re
|
||
from typing import Callable, List, Optional
|
||
|
||
from api_client import ApiClient
|
||
from exceptions import ApiError
|
||
from models import Chunk
|
||
from prompts import get_system_prompt, get_user_prompt, CONTENT_TYPE_DOCUMENT, CONTENT_TYPE_IMAGE
|
||
|
||
# Matches the "[tag name] summary title" heading format emitted by the model.
_TAG_PATTERN = re.compile(r"^\[(.+?)\]\s*(.+)$")
|
||
|
||
|
||
class AIChunker:
    """Semantic-level intelligent chunking via the DeepSeek API.

    Long documents are pre-split on paragraph boundaries so each API call
    stays under ``PRE_SPLIT_SIZE`` characters; API responses are parsed
    back into :class:`Chunk` objects (title, business tag, content).
    """

    # Default pre-split size in characters, suitable for Chinese documents.
    DEFAULT_PRE_SPLIT_SIZE = 12000
    # Texts shorter than this are sent to the API as a single piece to
    # avoid the AI over-splitting tiny inputs.
    SMALL_TEXT_THRESHOLD = 800

    def __init__(
        self,
        api_client: ApiClient,
        delimiter: str = "---",
        pre_split_size: Optional[int] = None,
    ):
        """
        Args:
            api_client: Client used to call the DeepSeek chat API.
            delimiter: Line that separates chunks in the API response.
            pre_split_size: Max characters per API call; falls back to
                DEFAULT_PRE_SPLIT_SIZE when None (or 0).
        """
        self._api_client = api_client
        self._delimiter = delimiter
        # NOTE: upper-case instance attribute kept for backward
        # compatibility with existing callers that read PRE_SPLIT_SIZE.
        self.PRE_SPLIT_SIZE = pre_split_size or self.DEFAULT_PRE_SPLIT_SIZE

    def chunk(
        self,
        text: str,
        content_type: str = CONTENT_TYPE_DOCUMENT,
        source_file: str = "",
        on_progress: Optional[Callable[[int, int], None]] = None,
    ) -> List[Chunk]:
        """Chunk *text* semantically through the AI.

        Image content is returned directly as a single chunk (it is
        usually short and needs no further splitting).  Text longer than
        ``PRE_SPLIT_SIZE`` is pre-split on paragraph boundaries and each
        segment is sent to the API separately.

        Args:
            text: Text to chunk.
            content_type: Content type (document/table/qa/image).
            source_file: Source file name (helps the AI pick a business tag).
            on_progress: Optional progress callback, invoked as
                ``on_progress(current, total)`` after each segment.

        Returns:
            List of Chunk objects.

        Raises:
            ApiError: Propagated from response parsing when the API
                output cannot be turned into chunks.
        """
        # Image content: single chunk, tag inferred from the file name
        # instead of spending an API round-trip.
        if content_type == CONTENT_TYPE_IMAGE:
            tag = self._infer_image_tag(source_file)
            return [Chunk(title="图片内容识别", content=text.strip(), tag=tag)]

        # Small text: one API call, no pre-splitting needed.
        if len(text.strip()) < self.SMALL_TEXT_THRESHOLD:
            return self._call_api(text, content_type, source_file)

        if len(text) <= self.PRE_SPLIT_SIZE:
            return self._call_api(text, content_type, source_file)

        segments = self._pre_split(text)
        total = len(segments)
        all_chunks: List[Chunk] = []

        for i, segment in enumerate(segments, start=1):
            all_chunks.extend(self._call_api(segment, content_type, source_file))
            if on_progress is not None:
                on_progress(i, total)

        return all_chunks

    @staticmethod
    def _infer_image_tag(source_file: str) -> str:
        """Infer a business tag from an image file name.

        Strips the extension and a trailing digit run (e.g. "CC套装2" →
        "CC套装") and uses the remainder as the tag, falling back to the
        generic "产品图片" when nothing usable is left.
        """
        if not source_file:
            return "产品图片"
        # Local import: keeps the module-level import list unchanged.
        import os

        name = os.path.splitext(source_file)[0]
        # Drop a trailing digit run (e.g. "CC套装2" → "CC套装").
        name = re.sub(r"\d+$", "", name).strip()
        # BUG FIX: the original returned "产品图片" on BOTH branches of the
        # ternary, silently discarding the computed name.  Use the cleaned
        # file name as the tag when it is non-empty.
        return name if name else "产品图片"

    def _pre_split(self, text: str) -> List[str]:
        """Pre-split text on paragraph boundaries, each segment at most
        PRE_SPLIT_SIZE characters.

        Strategy:
        1. Split into paragraphs on blank lines ("\\n\\n").
        2. Greedily merge paragraphs up to PRE_SPLIT_SIZE per segment.
        3. A single over-long paragraph is further split on "\\n".
        4. A single over-long line is hard-split on character count,
           preferring a sentence boundary.
        """
        paragraphs = text.split("\n\n")
        # Expand over-long paragraphs: split them on single newlines.
        lines: List[str] = []
        for para in paragraphs:
            if len(para) <= self.PRE_SPLIT_SIZE:
                lines.append(para)
            else:
                for line in para.split("\n"):
                    if len(line) <= self.PRE_SPLIT_SIZE:
                        lines.append(line)
                    else:
                        # Hard-split an over-long single line.
                        lines.extend(self._hard_split(line))

        # Greedy merge; paragraphs are re-joined with blank lines.
        segments: List[str] = []
        current = ""

        for line in lines:
            candidate = f"{current}\n\n{line}" if current else line
            if len(candidate) <= self.PRE_SPLIT_SIZE:
                current = candidate
            else:
                if current:
                    segments.append(current)
                current = line

        if current:
            segments.append(current)

        return segments

    def _hard_split(self, text: str) -> List[str]:
        """Hard-split over-long text by character count, preferring to cut
        at a sentence boundary."""
        result: List[str] = []
        remaining = text

        while len(remaining) > self.PRE_SPLIT_SIZE:
            window = remaining[: self.PRE_SPLIT_SIZE]
            # Search backwards for sentence-ending punctuation.
            cut = self._find_sentence_boundary(window)
            result.append(remaining[:cut])
            remaining = remaining[cut:]

        if remaining:
            result.append(remaining)

        return result

    @staticmethod
    def _find_sentence_boundary(text: str) -> int:
        """Find the end of the last complete sentence, searching backwards.

        Returns the index just past the sentence-ending character, or
        ``len(text)`` when no boundary is found in the back half.
        """
        sentence_endings = ("。", "!", "?", ".", "!", "?", "\n")
        # Search back from the end, but keep at least half the content.
        search_start = len(text) // 2
        for i in range(len(text) - 1, search_start - 1, -1):
            if text[i] in sentence_endings:
                return i + 1
        # No boundary found: cut at the full window size.
        return len(text)

    def _call_api(self, text_segment: str, content_type: str = CONTENT_TYPE_DOCUMENT, source_file: str = "") -> List[Chunk]:
        """Call the DeepSeek API to semantically chunk a single segment."""
        system_prompt = get_system_prompt(self._delimiter, content_type=content_type)
        user_prompt = get_user_prompt(text_segment, source_file=source_file)
        response = self._api_client.chat(system_prompt, user_prompt)
        return self._parse_response(response)

    def _parse_response(self, response: str) -> List[Chunk]:
        """Parse the API response into chunks.

        The response is split wherever the delimiter occupies a line of
        its own; each part yields a business tag, a summary title
        ("[tag] title" on the first non-empty line) and the chunk body.

        Note: only a delimiter alone on its own line separates chunks, so
        the "---" inside Markdown table syntax (`| --- | --- |`) is never
        treated as a separator.

        Raises:
            ApiError: On an empty response or when no chunk can be parsed.
        """
        if not response or not response.strip():
            raise ApiError("API 返回空响应")

        # Pad with newlines so a delimiter at the very start/end of the
        # response still matches the "\n<delimiter>\n" separator.
        padded = f"\n{response}\n"
        parts = padded.split(f"\n{self._delimiter}\n")
        chunks: List[Chunk] = []

        for part in parts:
            part = part.strip()
            if not part:
                continue

            lines = part.split("\n")
            # First non-empty line is the title.
            title = ""
            content_start = 0
            for j, line in enumerate(lines):
                if line.strip():
                    title = line.strip()
                    content_start = j + 1
                    break

            if not title:
                continue

            # Extract the "[tag]" prefix, if present.
            tag = ""
            match = _TAG_PATTERN.match(title)
            if match:
                tag = match.group(1)
                title = match.group(2)

            # Remaining lines are the chunk body.
            content = "\n".join(lines[content_start:]).strip()
            chunks.append(Chunk(title=title, content=content, tag=tag))

        if not chunks:
            raise ApiError(f"无法解析 API 响应为有效分块: {response[:200]}")

        return chunks
|