Initial commit: AI 知识库文档智能分块工具
This commit is contained in:
164
main.py
Normal file
164
main.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""CLI 入口,argparse 参数解析"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from exceptions import ApiError, ParseError, UnsupportedFormatError
|
||||
from splitter import Splitter
|
||||
|
||||
|
||||
def derive_output_path(input_file: str, output_format: str = "markdown") -> str:
|
||||
"""根据输入文件路径推导默认输出路径"""
|
||||
root, _ = os.path.splitext(input_file)
|
||||
ext = ".json" if output_format == "json" else ".md"
|
||||
return root + ext
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
"""构建命令行参数解析器"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="AI 知识库文档智能分块工具 - 将多种格式文档解析并通过 DeepSeek API 进行语义级智能分块"
|
||||
)
|
||||
|
||||
# 输入源(单文件或批量文件夹,二选一)
|
||||
input_group = parser.add_mutually_exclusive_group(required=True)
|
||||
input_group.add_argument(
|
||||
"input_file",
|
||||
nargs="?",
|
||||
default=None,
|
||||
help="输入文件路径(支持 PDF、Word、Excel、CSV、HTML、TXT、图片等格式)",
|
||||
)
|
||||
input_group.add_argument(
|
||||
"-b", "--batch",
|
||||
default=None,
|
||||
help="批量处理模式:指定输入文件夹路径,递归扫描所有支持的文件",
|
||||
)
|
||||
|
||||
# API Key(支持环境变量 DEEPSEEK_API_KEY)
|
||||
parser.add_argument(
|
||||
"-k", "--api-key",
|
||||
default=os.environ.get("DEEPSEEK_API_KEY"),
|
||||
help="DeepSeek API Key(也可通过环境变量 DEEPSEEK_API_KEY 设置)",
|
||||
)
|
||||
|
||||
# 输出相关
|
||||
parser.add_argument(
|
||||
"-o", "--output",
|
||||
default=None,
|
||||
help="输出文件路径(单文件模式,默认为输入文件同目录同名 .md/.json 文件)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
default=None,
|
||||
help="批量模式的输出目录(默认:输入文件夹下的 output/ 子目录)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f", "--format",
|
||||
choices=["markdown", "json"],
|
||||
default="markdown",
|
||||
help="输出格式(默认: markdown)",
|
||||
)
|
||||
|
||||
# 处理参数
|
||||
parser.add_argument(
|
||||
"-d", "--delimiter",
|
||||
default="---",
|
||||
help="分块分隔符(默认: ---)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="预切分大小(字符数),默认 12000。中文文档建议 10000-15000",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vision-prompt",
|
||||
default=None,
|
||||
help="自定义图片识别的 system prompt",
|
||||
)
|
||||
|
||||
# 批量处理选项
|
||||
parser.add_argument(
|
||||
"--skip-existing",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="跳过已存在的输出文件(避免重复处理和 API 费用)",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if not args.api_key:
|
||||
print("错误: 未提供 API Key。请通过 -k 参数或环境变量 DEEPSEEK_API_KEY 设置", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
splitter = Splitter(
|
||||
api_key=args.api_key,
|
||||
delimiter=args.delimiter,
|
||||
pre_split_size=args.chunk_size,
|
||||
vision_prompt=args.vision_prompt,
|
||||
output_format=args.format,
|
||||
)
|
||||
|
||||
if args.batch:
|
||||
# 批量处理模式
|
||||
from batch import batch_process, print_summary
|
||||
|
||||
input_dir = args.batch
|
||||
if not os.path.isdir(input_dir):
|
||||
print(f"错误: 批量处理路径不是文件夹: {input_dir}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
output_dir = args.output_dir or os.path.join(input_dir, "output")
|
||||
|
||||
result = batch_process(
|
||||
splitter=splitter,
|
||||
input_dir=input_dir,
|
||||
output_dir=output_dir,
|
||||
skip_existing=args.skip_existing,
|
||||
output_format=args.format,
|
||||
)
|
||||
print_summary(result)
|
||||
|
||||
if result.failed:
|
||||
sys.exit(1)
|
||||
else:
|
||||
# 单文件处理模式
|
||||
input_file = args.input_file
|
||||
output_path = args.output or derive_output_path(input_file, args.format)
|
||||
|
||||
if args.skip_existing and os.path.exists(output_path):
|
||||
print(f"输出文件已存在,跳过: {output_path}")
|
||||
return
|
||||
|
||||
splitter.process(input_file, output_path)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
print(f"错误: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except UnsupportedFormatError as e:
|
||||
print(f"错误: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except ParseError as e:
|
||||
print(f"错误: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except ApiError as e:
|
||||
print(f"错误: API 调用失败 - {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"错误: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user