128 lines
5.2 KiB
Python
128 lines
5.2 KiB
Python
|
|
"""Splitter 协调器,编排文件解析、AI 分块和输出的完整流程"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
from api_client import ApiClient
|
|||
|
|
from chunker import AIChunker
|
|||
|
|
from exceptions import ApiError, ParseError, UnsupportedFormatError
|
|||
|
|
from parsers.base import ParserRegistry
|
|||
|
|
from parsers.csv_parser import CsvParser
|
|||
|
|
from parsers.doc_parser import DocParser
|
|||
|
|
from parsers.html_parser import HtmlParser
|
|||
|
|
from parsers.image_parser import ImageParser
|
|||
|
|
from parsers.legacy_doc_parser import LegacyDocParser
|
|||
|
|
from parsers.pdf_parser import PdfParser
|
|||
|
|
from parsers.text_parser import TextParser
|
|||
|
|
from parsers.xls_parser import XlsParser
|
|||
|
|
from parsers.xlsx_parser import XlsxParser
|
|||
|
|
from prompts import detect_content_type, CONTENT_TYPE_IMAGE, CONTENT_TYPE_TABLE
|
|||
|
|
from writer import JsonWriter, MarkdownWriter
|
|||
|
|
|
|||
|
|
# Table-format extensions — Coze has a dedicated table knowledge base, so these skip AI chunking
|
|||
|
|
TABLE_EXTENSIONS = {".xlsx", ".xls", ".csv"}  # spreadsheet/CSV formats routed to Coze's table knowledge base, not the chunking pipeline
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Splitter:
    """Main coordinator: orchestrates the full parse → AI-chunk → write pipeline."""

    def __init__(
        self,
        api_key: str,
        delimiter: str = "---",
        pre_split_size: Optional[int] = None,
        vision_prompt: Optional[str] = None,
        output_format: str = "markdown",
    ):
        """Wire up the pipeline components.

        Args:
            api_key: Credential for the AI ApiClient (used for chunking and vision parsing).
            delimiter: Separator written between chunks in the output.
            pre_split_size: Optional pre-split size forwarded to AIChunker.
            vision_prompt: Optional custom prompt forwarded to the image parser.
            output_format: "json" selects JsonWriter; any other value falls back to MarkdownWriter.
        """
        self._api_client = ApiClient(api_key=api_key)
        self._delimiter = delimiter

        # Create and register all supported parsers.
        self._registry = ParserRegistry()
        self._registry.register(TextParser())
        self._registry.register(CsvParser())
        self._registry.register(HtmlParser())
        self._registry.register(PdfParser())
        self._registry.register(DocParser())
        self._registry.register(LegacyDocParser())
        self._registry.register(XlsxParser())
        self._registry.register(XlsParser())
        self._registry.register(ImageParser(self._api_client, vision_prompt=vision_prompt))

        self._chunker = AIChunker(self._api_client, delimiter, pre_split_size=pre_split_size)

        # Choose the writer matching the requested output format.
        if output_format == "json":
            self._writer = JsonWriter()
        else:
            self._writer = MarkdownWriter()

    def process(self, input_path: str, output_path: str) -> None:
        """Run the complete processing pipeline.

        1. Validate the input file exists.
        2. Select a parser by file type.
        3. Parse the file into text.
        4. AI semantic chunking.
        5. Write the output.
        6. Print a processing summary.

        Args:
            input_path: Path of the file to process.
            output_path: Path the chunked output is written to.

        Raises:
            FileNotFoundError: If ``input_path`` does not exist.
        """
        # 1. Validate the input file exists.
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"输入文件不存在: {input_path}")

        file_name = os.path.basename(input_path)
        file_size = os.path.getsize(input_path)
        size_str = f"{file_size / 1024:.1f}KB" if file_size < 1024 * 1024 else f"{file_size / 1024 / 1024:.1f}MB"

        # Table-format files: Coze has a dedicated table knowledge base — no AI chunking needed.
        ext = os.path.splitext(input_path)[1].lower()
        if ext in TABLE_EXTENSIONS:
            print(f"  ⚠️ 表格文件 {file_name} 请直接上传到 Coze「表格知识库」,无需分块处理")
            # Plain literal (no placeholders) — extraneous f-prefix removed.
            print("  批量模式(-b)会自动将表格文件复制到 tables/ 文件夹")
            return

        # 2. Select a parser for this file type.
        parser = self._registry.get_parser(input_path)

        # 3. Parse the file into plain text.
        print(f"  [1/4] 解析文件: {file_name} ({size_str})")
        text = parser.parse(input_path)
        print(f"  提取文本: {len(text)} 字符")

        # 3.5 Detect the content type (reuses `ext` computed above; previously recomputed redundantly).
        content_type = detect_content_type(ext, text)
        type_labels = {
            "document": "文档", "table": "表格", "qa": "问答", "image": "图片",
        }
        print(f"  内容类型: {type_labels.get(content_type, content_type)}")

        # 4. AI chunking. For image content the chunker is still invoked but skips
        #    the API call internally (hence the different status message).
        if content_type == CONTENT_TYPE_IMAGE:
            # Plain literal (no placeholders) — extraneous f-prefix removed.
            print("  [2/4] 图片内容,跳过 AI 分块")
        else:
            print(f"  [2/4] AI 语义分块中({type_labels.get(content_type, '通用')}策略)...")

        def progress_callback(current: int, total: int) -> None:
            # Relay chunking progress to the console.
            print(f"  分块进度: {current}/{total}")

        chunks = self._chunker.chunk(text, content_type=content_type, source_file=file_name, on_progress=progress_callback)

        # 5. Write the output.
        print(f"  [3/4] 写入输出: {output_path}")
        source_file = os.path.basename(input_path)
        self._writer.write(chunks, output_path, source_file, self._delimiter)

        # 6. Summary statistics over chunk sizes (guards keep the math safe for zero chunks).
        chunk_sizes = [len(c.content) for c in chunks]
        avg_size = sum(chunk_sizes) // max(len(chunk_sizes), 1)
        max_size = max(chunk_sizes) if chunk_sizes else 0
        min_size = min(chunk_sizes) if chunk_sizes else 0
        print(f"  [4/4] 完成! 共 {len(chunks)} 个分块 (平均 {avg_size} 字, 最大 {max_size} 字, 最小 {min_size} 字)")

        # Flag oversized chunks (> 800 chars) — may exceed knowledge-base platform limits.
        oversized = [i + 1 for i, s in enumerate(chunk_sizes) if s > 800]
        if oversized:
            print(f"  ⚠️ 分块 {oversized} 超过 800 字,上传知识库时建议检查")
|