128 lines
5.2 KiB
Python
128 lines
5.2 KiB
Python
"""Splitter 协调器,编排文件解析、AI 分块和输出的完整流程"""
|
||
|
||
import os
|
||
from typing import Optional
|
||
|
||
from api_client import ApiClient
|
||
from chunker import AIChunker
|
||
from exceptions import ApiError, ParseError, UnsupportedFormatError
|
||
from parsers.base import ParserRegistry
|
||
from parsers.csv_parser import CsvParser
|
||
from parsers.doc_parser import DocParser
|
||
from parsers.html_parser import HtmlParser
|
||
from parsers.image_parser import ImageParser
|
||
from parsers.legacy_doc_parser import LegacyDocParser
|
||
from parsers.pdf_parser import PdfParser
|
||
from parsers.text_parser import TextParser
|
||
from parsers.xls_parser import XlsParser
|
||
from parsers.xlsx_parser import XlsxParser
|
||
from prompts import detect_content_type, CONTENT_TYPE_IMAGE, CONTENT_TYPE_TABLE
|
||
from writer import JsonWriter, MarkdownWriter
|
||
|
||
# Table-format extensions — Coze has a dedicated table knowledge base, so these
# files are uploaded directly and never go through AI chunking (see Splitter.process).
TABLE_EXTENSIONS = {".xlsx", ".xls", ".csv"}
|
||
|
||
|
||
class Splitter:
    """Main orchestrator: parses a file, runs AI semantic chunking, writes output.

    Wires together the parser registry, the AIChunker and a writer
    (Markdown or JSON), and exposes a single ``process`` entry point.
    """

    # Human-readable labels for detected content types, used in progress logs.
    # Built once at class level instead of per call.
    _TYPE_LABELS = {
        "document": "文档", "table": "表格", "qa": "问答", "image": "图片",
    }

    def __init__(
        self,
        api_key: str,
        delimiter: str = "---",
        pre_split_size: Optional[int] = None,
        vision_prompt: Optional[str] = None,
        output_format: str = "markdown",
        oversize_warn_limit: int = 800,
    ):
        """Create the pipeline.

        Args:
            api_key: API key handed to the ApiClient.
            delimiter: chunk delimiter passed to the chunker and writer.
            pre_split_size: optional pre-split size forwarded to AIChunker.
            vision_prompt: optional custom prompt for the image parser.
            output_format: ``"json"`` selects JsonWriter; anything else
                selects MarkdownWriter (previous behavior preserved).
            oversize_warn_limit: chunk size in characters above which a
                warning is printed after processing (default 800, matching
                the previous hard-coded threshold).
        """
        self._api_client = ApiClient(api_key=api_key)
        self._delimiter = delimiter
        self._oversize_warn_limit = oversize_warn_limit

        # Create and register every supported parser; the registry selects
        # one per input file.
        self._registry = ParserRegistry()
        for parser in (
            TextParser(),
            CsvParser(),
            HtmlParser(),
            PdfParser(),
            DocParser(),
            LegacyDocParser(),
            XlsxParser(),
            XlsParser(),
            ImageParser(self._api_client, vision_prompt=vision_prompt),
        ):
            self._registry.register(parser)

        self._chunker = AIChunker(self._api_client, delimiter, pre_split_size=pre_split_size)

        # Pick the writer according to the requested output format.
        self._writer = JsonWriter() if output_format == "json" else MarkdownWriter()

    @staticmethod
    def _format_size(n_bytes: int) -> str:
        """Render a byte count as ``X.XKB`` below 1 MiB, otherwise ``X.XMB``."""
        if n_bytes < 1024 * 1024:
            return f"{n_bytes / 1024:.1f}KB"
        return f"{n_bytes / 1024 / 1024:.1f}MB"

    def process(self, input_path: str, output_path: str) -> None:
        """
        Run the full processing pipeline:
        1. Validate that the input file exists
        2. Pick a parser based on file type
        3. Parse the file into text
        4. AI semantic chunking
        5. Write the output
        6. Print a processing summary

        Table files (see TABLE_EXTENSIONS) are not processed: Coze has a
        dedicated table knowledge base, so a hint is printed and the method
        returns early.

        Raises:
            FileNotFoundError: if ``input_path`` does not exist.
        """
        # 1. Validate the file exists.
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"输入文件不存在: {input_path}")

        file_name = os.path.basename(input_path)
        size_str = self._format_size(os.path.getsize(input_path))

        # Table files: Coze has a dedicated table knowledge base — no AI chunking.
        # Extension is computed once here and reused for content-type detection below.
        ext = os.path.splitext(input_path)[1].lower()
        if ext in TABLE_EXTENSIONS:
            print(f"  ⚠️ 表格文件 {file_name} 请直接上传到 Coze「表格知识库」,无需分块处理")
            print("  批量模式(-b)会自动将表格文件复制到 tables/ 文件夹")
            return

        # 2. Select a parser for this file type.
        parser = self._registry.get_parser(input_path)

        # 3. Parse the file into plain text.
        print(f"  [1/4] 解析文件: {file_name} ({size_str})")
        text = parser.parse(input_path)
        print(f"  提取文本: {len(text)} 字符")

        # 3.5 Detect the content type (document / table / qa / image).
        content_type = detect_content_type(ext, text)
        type_labels = self._TYPE_LABELS
        print(f"  内容类型: {type_labels.get(content_type, content_type)}")

        # 4. AI chunking. For image content the API call is skipped inside
        #    the chunker; only the progress message differs here.
        if content_type == CONTENT_TYPE_IMAGE:
            print("  [2/4] 图片内容,跳过 AI 分块")
        else:
            print(f"  [2/4] AI 语义分块中({type_labels.get(content_type, '通用')}策略)...")

        def progress_callback(current: int, total: int) -> None:
            print(f"  分块进度: {current}/{total}")

        chunks = self._chunker.chunk(text, content_type=content_type, source_file=file_name, on_progress=progress_callback)

        # 5. Write the output (file_name is already the basename of input_path).
        print(f"  [3/4] 写入输出: {output_path}")
        self._writer.write(chunks, output_path, file_name, self._delimiter)

        # 6. Print a summary of chunk-size statistics; guards keep this
        #    safe when no chunks were produced.
        chunk_sizes = [len(c.content) for c in chunks]
        avg_size = sum(chunk_sizes) // max(len(chunk_sizes), 1)
        max_size = max(chunk_sizes, default=0)
        min_size = min(chunk_sizes, default=0)
        print(f"  [4/4] 完成! 共 {len(chunks)} 个分块 (平均 {avg_size} 字, 最大 {max_size} 字, 最小 {min_size} 字)")

        # Flag chunks above the configured limit — they may exceed what the
        # knowledge-base platform accepts per chunk.
        oversized = [i + 1 for i, s in enumerate(chunk_sizes) if s > self._oversize_warn_limit]
        if oversized:
            print(f"  ⚠️ 分块 {oversized} 超过 {self._oversize_warn_limit} 字,上传知识库时建议检查")