Files
bigwo/splitter.py
2026-03-02 17:38:28 +08:00

128 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Splitter 协调器编排文件解析、AI 分块和输出的完整流程"""
import os
from typing import Optional
from api_client import ApiClient
from chunker import AIChunker
from exceptions import ApiError, ParseError, UnsupportedFormatError
from parsers.base import ParserRegistry
from parsers.csv_parser import CsvParser
from parsers.doc_parser import DocParser
from parsers.html_parser import HtmlParser
from parsers.image_parser import ImageParser
from parsers.legacy_doc_parser import LegacyDocParser
from parsers.pdf_parser import PdfParser
from parsers.text_parser import TextParser
from parsers.xls_parser import XlsParser
from parsers.xlsx_parser import XlsxParser
from prompts import detect_content_type, CONTENT_TYPE_IMAGE, CONTENT_TYPE_TABLE
from writer import JsonWriter, MarkdownWriter
# Spreadsheet-style extensions — Coze ships a dedicated table knowledge
# base for these, so they bypass AI chunking entirely.
TABLE_EXTENSIONS = {".csv", ".xls", ".xlsx"}
class Splitter:
    """Main coordinator: orchestrates file parsing, AI chunking, and output writing."""

    def __init__(
        self,
        api_key: str,
        delimiter: str = "---",
        pre_split_size: Optional[int] = None,
        vision_prompt: Optional[str] = None,
        output_format: str = "markdown",
    ):
        """
        Args:
            api_key: credential forwarded to the shared ApiClient.
            delimiter: chunk separator passed to the chunker and the writer.
            pre_split_size: optional pre-split threshold forwarded to AIChunker.
            vision_prompt: optional prompt override for the image parser.
            output_format: "json" selects JsonWriter; any other value falls
                back to MarkdownWriter.
        """
        self._api_client = ApiClient(api_key=api_key)
        self._delimiter = delimiter

        # Create and register every parser; the registry dispatches by file type.
        self._registry = ParserRegistry()
        for parser in (
            TextParser(),
            CsvParser(),
            HtmlParser(),
            PdfParser(),
            DocParser(),
            LegacyDocParser(),
            XlsxParser(),
            XlsParser(),
            # Image parsing needs the API client (vision model) to extract text.
            ImageParser(self._api_client, vision_prompt=vision_prompt),
        ):
            self._registry.register(parser)

        self._chunker = AIChunker(self._api_client, delimiter, pre_split_size=pre_split_size)

        # Pick the writer by output format; unknown values default to markdown.
        self._writer = JsonWriter() if output_format == "json" else MarkdownWriter()

    def process(self, input_path: str, output_path: str) -> None:
        """
        Run the full pipeline:
            1. validate the input file exists
            2. pick a parser by file type
            3. parse the file into text
            4. AI semantic chunking
            5. write the output
            6. print a processing summary

        Raises:
            FileNotFoundError: if input_path does not exist.
            (Parsing/chunking may additionally raise ParseError,
            UnsupportedFormatError, or ApiError from the imported modules.)
        """
        # 1. Validate the input before doing any work.
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"输入文件不存在: {input_path}")

        file_name = os.path.basename(input_path)
        file_size = os.path.getsize(input_path)
        size_str = (
            f"{file_size / 1024:.1f}KB"
            if file_size < 1024 * 1024
            else f"{file_size / 1024 / 1024:.1f}MB"
        )

        # Spreadsheet files: Coze has a dedicated table knowledge base, so AI
        # chunking is pointless — tell the user and bail out early.
        ext = os.path.splitext(input_path)[1].lower()
        if ext in TABLE_EXTENSIONS:
            print(f" ⚠️ 表格文件 {file_name} 请直接上传到 Coze「表格知识库」无需分块处理")
            print(f" 批量模式(-b会自动将表格文件复制到 tables/ 文件夹")
            return

        # 2. Select a parser for this file type.
        parser = self._registry.get_parser(input_path)

        # 3. Parse the file into plain text.
        print(f" [1/4] 解析文件: {file_name} ({size_str})")
        text = parser.parse(input_path)
        print(f" 提取文本: {len(text)} 字符")

        # 3.5 Detect content type (document / table / qa / image) to select a
        # chunking strategy. Reuses `ext` computed above (was computed twice).
        content_type = detect_content_type(ext, text)
        type_labels = {
            "document": "文档", "table": "表格", "qa": "问答", "image": "图片",
        }
        print(f" 内容类型: {type_labels.get(content_type, content_type)}")

        # 4. AI chunking. Only the progress message varies by content type; the
        # chunker is always invoked so `chunks` is always bound (previously the
        # image path could leave it undefined before use below).
        # NOTE(review): for image content the chunker is expected to skip the
        # remote API call — confirm against chunker.py.
        if content_type == CONTENT_TYPE_IMAGE:
            print(f" [2/4] 图片内容,跳过 AI 分块")
        else:
            print(f" [2/4] AI 语义分块中({type_labels.get(content_type, '通用')}策略)...")

        def progress_callback(current: int, total: int) -> None:
            # Surface chunking progress to the console.
            print(f" 分块进度: {current}/{total}")

        chunks = self._chunker.chunk(
            text,
            content_type=content_type,
            source_file=file_name,
            on_progress=progress_callback,
        )

        # 5. Write the output file (file_name already holds the basename).
        print(f" [3/4] 写入输出: {output_path}")
        self._writer.write(chunks, output_path, file_name, self._delimiter)

        # 6. Summary: chunk count plus average/max/min sizes in characters.
        chunk_sizes = [len(c.content) for c in chunks]
        avg_size = sum(chunk_sizes) // max(len(chunk_sizes), 1)
        max_size = max(chunk_sizes, default=0)
        min_size = min(chunk_sizes, default=0)
        print(f" [4/4] 完成! 共 {len(chunks)} 个分块 (平均 {avg_size} 字, 最大 {max_size} 字, 最小 {min_size} 字)")

        # Warn about oversized chunks (1-based indices) that may exceed
        # knowledge-base platform limits.
        oversized = [i + 1 for i, s in enumerate(chunk_sizes) if s > 800]
        if oversized:
            print(f" ⚠️ 分块 {oversized} 超过 800 字,上传知识库时建议检查")