Initial commit: AI 知识库文档智能分块工具
This commit is contained in:
141
batch.py
Normal file
141
batch.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""批量处理模块,递归扫描文件夹并逐个处理,含容错和汇总报告
|
||||
|
||||
Coze 知识库适配:
|
||||
- 文本类文件(docx/doc/pdf/txt/html)→ AI 分块后输出到 output/
|
||||
- 表格类文件(xlsx/xls/csv)→ 直接复制到 tables/,上传 Coze 表格知识库
|
||||
- 图片类文件 → 正常走 AI 流程(后续可能调整)
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
from splitter import Splitter, TABLE_EXTENSIONS
|
||||
|
||||
# 所有支持的扩展名
|
||||
SUPPORTED_EXTENSIONS: Set[str] = {
|
||||
".txt", ".md", ".csv", ".html", ".htm",
|
||||
".pdf", ".docx", ".doc",
|
||||
".xlsx", ".xls",
|
||||
".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchResult:
|
||||
"""批量处理结果"""
|
||||
success: List[str] = field(default_factory=list)
|
||||
failed: List[Tuple[str, str]] = field(default_factory=list) # (file_path, error_msg)
|
||||
skipped: List[str] = field(default_factory=list)
|
||||
tables: List[str] = field(default_factory=list) # 直接复制的表格文件
|
||||
|
||||
|
||||
def scan_files(input_dir: str) -> List[str]:
|
||||
"""递归扫描文件夹,返回所有支持格式的文件路径列表(按名称排序)"""
|
||||
files = []
|
||||
for root, _, filenames in os.walk(input_dir):
|
||||
for filename in sorted(filenames):
|
||||
if filename.startswith("."):
|
||||
continue
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext in SUPPORTED_EXTENSIONS:
|
||||
files.append(os.path.join(root, filename))
|
||||
return files
|
||||
|
||||
|
||||
def batch_process(
|
||||
splitter: Splitter,
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
skip_existing: bool = False,
|
||||
output_format: str = "markdown",
|
||||
) -> BatchResult:
|
||||
"""
|
||||
批量处理文件夹中的所有支持格式的文件。
|
||||
|
||||
表格类文件(xlsx/xls/csv)直接复制到 output_dir/tables/ 子文件夹,
|
||||
不经过 AI 分块,用户可直接上传到 Coze 表格知识库。
|
||||
|
||||
Args:
|
||||
splitter: Splitter 实例
|
||||
input_dir: 输入文件夹路径
|
||||
output_dir: 输出文件夹路径
|
||||
skip_existing: 是否跳过已存在的输出文件
|
||||
output_format: 输出格式 ("markdown" 或 "json")
|
||||
|
||||
Returns:
|
||||
BatchResult 包含成功/失败/跳过/表格的文件列表
|
||||
"""
|
||||
result = BatchResult()
|
||||
files = scan_files(input_dir)
|
||||
total = len(files)
|
||||
|
||||
if total == 0:
|
||||
print(f"未在 {input_dir} 中找到支持的文件")
|
||||
return result
|
||||
|
||||
print(f"共扫描到 {total} 个文件待处理\n")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
tables_dir = os.path.join(output_dir, "tables")
|
||||
|
||||
for i, file_path in enumerate(files, start=1):
|
||||
rel_path = os.path.relpath(file_path, input_dir)
|
||||
file_ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
# 表格类文件:直接复制到 tables/ 子文件夹
|
||||
if file_ext in TABLE_EXTENSIONS:
|
||||
os.makedirs(tables_dir, exist_ok=True)
|
||||
dest = os.path.join(tables_dir, os.path.basename(file_path))
|
||||
if skip_existing and os.path.exists(dest):
|
||||
result.skipped.append(file_path)
|
||||
print(f"[{i}/{total}] 跳过(已存在): {rel_path}")
|
||||
continue
|
||||
shutil.copy2(file_path, dest)
|
||||
result.tables.append(file_path)
|
||||
print(f"[{i}/{total}] 表格文件,直接复制: {rel_path} → tables/")
|
||||
continue
|
||||
|
||||
# 文本/图片类文件:走 AI 分块流程
|
||||
ext = ".json" if output_format == "json" else ".md"
|
||||
output_path = os.path.join(
|
||||
output_dir,
|
||||
os.path.splitext(rel_path)[0] + ext,
|
||||
)
|
||||
|
||||
if skip_existing and os.path.exists(output_path):
|
||||
result.skipped.append(file_path)
|
||||
print(f"[{i}/{total}] 跳过(已存在): {rel_path}")
|
||||
continue
|
||||
|
||||
print(f"[{i}/{total}] 正在处理: {rel_path}")
|
||||
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
splitter.process(file_path, output_path)
|
||||
result.success.append(file_path)
|
||||
print(f" ✓ 完成")
|
||||
except Exception as e:
|
||||
result.failed.append((file_path, str(e)))
|
||||
print(f" ✗ 失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def print_summary(result: BatchResult) -> None:
|
||||
"""打印批量处理汇总报告"""
|
||||
total = len(result.success) + len(result.failed) + len(result.skipped) + len(result.tables)
|
||||
print(f"\n{'=' * 50}")
|
||||
print(f"批量处理完成! 共 {total} 个文件")
|
||||
print(f" ✓ 成功: {len(result.success)}")
|
||||
print(f" ✗ 失败: {len(result.failed)}")
|
||||
print(f" ⊘ 跳过: {len(result.skipped)}")
|
||||
if result.tables:
|
||||
print(f" 📊 表格(直接复制): {len(result.tables)}")
|
||||
print(f" → 请上传 tables/ 文件夹内的文件到 Coze「表格知识库」")
|
||||
|
||||
if result.failed:
|
||||
print(f"\n失败文件列表:")
|
||||
for path, err in result.failed:
|
||||
print(f" - {os.path.basename(path)}: {err}")
|
||||
Reference in New Issue
Block a user