Files
bigwo/chunker.py
2026-03-02 17:38:28 +08:00

219 lines
7.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""AI 分块器,通过 DeepSeek API 进行语义级智能分块"""
import os
import re
from typing import Callable, List, Optional

from api_client import ApiClient
from exceptions import ApiError
from models import Chunk
from prompts import get_system_prompt, get_user_prompt, CONTENT_TYPE_DOCUMENT, CONTENT_TYPE_IMAGE
# Matches the "[tag] title" heading format emitted by the model:
# group(1) is the business tag, group(2) is the summary title.
_TAG_PATTERN = re.compile(r"^\[(.+?)\]\s*(.+)$")
class AIChunker:
    """Semantic-level intelligent chunking via the DeepSeek API."""

    # Default pre-split size in characters, tuned for Chinese documents.
    DEFAULT_PRE_SPLIT_SIZE = 12000

    def __init__(
        self,
        api_client: ApiClient,
        delimiter: str = "---",
        pre_split_size: Optional[int] = None,
    ):
        """
        Args:
            api_client: Client used to call the chat completion API.
            delimiter: Line that separates chunks in the model's response.
            pre_split_size: Max characters per API call; falls back to
                DEFAULT_PRE_SPLIT_SIZE when None (or 0).
        """
        self._api_client = api_client
        self._delimiter = delimiter
        self.PRE_SPLIT_SIZE = pre_split_size or self.DEFAULT_PRE_SPLIT_SIZE

    def chunk(
        self,
        text: str,
        content_type: str = CONTENT_TYPE_DOCUMENT,
        source_file: str = "",
        on_progress: Optional[Callable[[int, int], None]] = None,
    ) -> List[Chunk]:
        """Semantically chunk *text* with the AI.

        Image-type content is returned directly as a single chunk (usually
        short, no further splitting needed). Text longer than PRE_SPLIT_SIZE
        is first pre-split on paragraph boundaries, then each segment is sent
        to the API separately.

        Args:
            text: Text to chunk.
            content_type: One of document/table/qa/image.
            source_file: Source file name (helps the AI pick a business tag).
            on_progress: Progress callback, invoked as (current, total).

        Returns:
            List of Chunk objects.
        """
        # Image content is usually short: one chunk, tag inferred from name.
        if content_type == CONTENT_TYPE_IMAGE:
            tag = self._infer_image_tag(source_file)
            return [Chunk(title="图片内容识别", content=text.strip(), tag=tag)]
        # Small texts (< 800 chars) go through as a single API call to avoid
        # over-splitting by the model; same for anything within the pre-split
        # budget. (The original had these as two identical branches.)
        if len(text.strip()) < 800 or len(text) <= self.PRE_SPLIT_SIZE:
            return self._call_api(text, content_type, source_file)
        segments = self._pre_split(text)
        total = len(segments)
        all_chunks: List[Chunk] = []
        for i, segment in enumerate(segments, start=1):
            all_chunks.extend(self._call_api(segment, content_type, source_file))
            if on_progress is not None:
                on_progress(i, total)
        return all_chunks

    @staticmethod
    def _infer_image_tag(source_file: str) -> str:
        """Infer a business tag from an image file name.

        Strips the extension and any trailing digits (e.g. "CC套装2" ->
        "CC套装") and uses the remainder as the tag; falls back to the
        generic tag when nothing usable remains.
        """
        if not source_file:
            return "产品图片"
        name = os.path.splitext(source_file)[0]
        name = re.sub(r"\d+$", "", name).strip()
        # BUG FIX: the original returned "产品图片" on both branches,
        # discarding the cleaned name it had just computed.
        return name if name else "产品图片"

    def _pre_split(self, text: str) -> List[str]:
        """Pre-split text on paragraph boundaries, each piece <= PRE_SPLIT_SIZE.

        Strategy:
            1. Split into paragraphs on blank lines.
            2. Greedily merge paragraphs up to PRE_SPLIT_SIZE.
            3. Paragraphs over the limit are split again on single newlines.
            4. A single line still over the limit is hard-split by character
               count (keeping the last complete sentence where possible).
        """
        paragraphs = text.split("\n\n")
        # Expand oversized paragraphs into lines (hard-splitting as needed).
        lines: List[str] = []
        for para in paragraphs:
            if len(para) <= self.PRE_SPLIT_SIZE:
                lines.append(para)
                continue
            for line in para.split("\n"):
                if len(line) <= self.PRE_SPLIT_SIZE:
                    lines.append(line)
                else:
                    lines.extend(self._hard_split(line))
        # Greedily merge, re-joining with blank lines.
        segments: List[str] = []
        current = ""
        for line in lines:
            candidate = f"{current}\n\n{line}" if current else line
            if len(candidate) <= self.PRE_SPLIT_SIZE:
                current = candidate
            else:
                if current:
                    segments.append(current)
                current = line
        if current:
            segments.append(current)
        return segments

    def _hard_split(self, text: str) -> List[str]:
        """Hard-split overlong text by character count, preferring sentence
        boundaries over mid-sentence cuts."""
        result: List[str] = []
        remaining = text
        while len(remaining) > self.PRE_SPLIT_SIZE:
            window = remaining[: self.PRE_SPLIT_SIZE]
            cut = self._find_sentence_boundary(window)
            result.append(remaining[:cut])
            remaining = remaining[cut:]
        if remaining:
            result.append(remaining)
        return result

    @staticmethod
    def _find_sentence_boundary(text: str) -> int:
        """Search backwards for the end of the last complete sentence.

        Returns the index just past the final sentence terminator, or
        len(text) when none is found in the trailing half (never discards
        more than half the window).
        """
        # BUG FIX: the CJK terminators 。!? were lost to a Unicode-mangling
        # edit, leaving empty strings that can never match a single character;
        # restored here alongside the ASCII set.
        sentence_endings = ("。", "!", "?", ".", "!", "?", "\n")
        search_start = len(text) // 2
        for i in range(len(text) - 1, search_start - 1, -1):
            if text[i] in sentence_endings:
                return i + 1
        # No boundary found: cut at the window edge.
        return len(text)

    def _call_api(
        self,
        text_segment: str,
        content_type: str = CONTENT_TYPE_DOCUMENT,
        source_file: str = "",
    ) -> List[Chunk]:
        """Call the DeepSeek API to semantically chunk one text segment."""
        system_prompt = get_system_prompt(self._delimiter, content_type=content_type)
        user_prompt = get_user_prompt(text_segment, source_file=source_file)
        response = self._api_client.chat(system_prompt, user_prompt)
        return self._parse_response(response)

    def _parse_response(self, response: str) -> List[Chunk]:
        """Parse the API response into chunks.

        Splits on the delimiter when it occupies a line of its own, then
        extracts the business tag, summary title, and body from each part.
        Title format: "[tag] summary title".

        Raises:
            ApiError: If the response is empty or yields no valid chunks.

        Note: matching the delimiter only as a standalone line avoids
        colliding with the "---" inside Markdown table syntax `| --- | --- |`.
        """
        if not response or not response.strip():
            raise ApiError("API 返回空响应")
        # Pad with newlines so a delimiter at the very start/end still
        # matches the "standalone line" pattern.
        padded = f"\n{response}\n"
        parts = padded.split(f"\n{self._delimiter}\n")
        chunks: List[Chunk] = []
        for part in parts:
            part = part.strip()
            if not part:
                continue
            lines = part.split("\n")
            # First non-empty line is the title; the rest is the body.
            title = ""
            content_start = 0
            for j, line in enumerate(lines):
                if line.strip():
                    title = line.strip()
                    content_start = j + 1
                    break
            if not title:
                continue
            # Extract the "[tag]" prefix, if present.
            tag = ""
            match = _TAG_PATTERN.match(title)
            if match:
                tag = match.group(1)
                title = match.group(2)
            content = "\n".join(lines[content_start:]).strip()
            chunks.append(Chunk(title=title, content=content, tag=tag))
        if not chunks:
            raise ApiError(f"无法解析 API 响应为有效分块: {response[:200]}")
        return chunks