"""CLI 入口,argparse 参数解析""" import argparse import os import sys from dotenv import load_dotenv load_dotenv() from exceptions import ApiError, ParseError, UnsupportedFormatError from splitter import Splitter def derive_output_path(input_file: str, output_format: str = "markdown") -> str: """根据输入文件路径推导默认输出路径""" root, _ = os.path.splitext(input_file) ext = ".json" if output_format == "json" else ".md" return root + ext def build_parser() -> argparse.ArgumentParser: """构建命令行参数解析器""" parser = argparse.ArgumentParser( description="AI 知识库文档智能分块工具 - 将多种格式文档解析并通过 DeepSeek API 进行语义级智能分块" ) # 输入源(单文件或批量文件夹,二选一) input_group = parser.add_mutually_exclusive_group(required=True) input_group.add_argument( "input_file", nargs="?", default=None, help="输入文件路径(支持 PDF、Word、Excel、CSV、HTML、TXT、图片等格式)", ) input_group.add_argument( "-b", "--batch", default=None, help="批量处理模式:指定输入文件夹路径,递归扫描所有支持的文件", ) # API Key(支持环境变量 DEEPSEEK_API_KEY) parser.add_argument( "-k", "--api-key", default=os.environ.get("DEEPSEEK_API_KEY"), help="DeepSeek API Key(也可通过环境变量 DEEPSEEK_API_KEY 设置)", ) # 输出相关 parser.add_argument( "-o", "--output", default=None, help="输出文件路径(单文件模式,默认为输入文件同目录同名 .md/.json 文件)", ) parser.add_argument( "--output-dir", default=None, help="批量模式的输出目录(默认:输入文件夹下的 output/ 子目录)", ) parser.add_argument( "-f", "--format", choices=["markdown", "json"], default="markdown", help="输出格式(默认: markdown)", ) # 处理参数 parser.add_argument( "-d", "--delimiter", default="---", help="分块分隔符(默认: ---)", ) parser.add_argument( "--chunk-size", type=int, default=None, help="预切分大小(字符数),默认 12000。中文文档建议 10000-15000", ) parser.add_argument( "--vision-prompt", default=None, help="自定义图片识别的 system prompt", ) # 批量处理选项 parser.add_argument( "--skip-existing", action="store_true", default=False, help="跳过已存在的输出文件(避免重复处理和 API 费用)", ) return parser def main() -> None: parser = build_parser() args = parser.parse_args() try: if not args.api_key: print("错误: 未提供 API Key。请通过 -k 参数或环境变量 DEEPSEEK_API_KEY 设置", file=sys.stderr) sys.exit(1) splitter = Splitter( api_key=args.api_key, delimiter=args.delimiter, pre_split_size=args.chunk_size, vision_prompt=args.vision_prompt, output_format=args.format, ) if args.batch: # 批量处理模式 from batch import batch_process, print_summary input_dir = args.batch if not os.path.isdir(input_dir): print(f"错误: 批量处理路径不是文件夹: {input_dir}", file=sys.stderr) sys.exit(1) output_dir = args.output_dir or os.path.join(input_dir, "output") result = batch_process( splitter=splitter, input_dir=input_dir, output_dir=output_dir, skip_existing=args.skip_existing, output_format=args.format, ) print_summary(result) if result.failed: sys.exit(1) else: # 单文件处理模式 input_file = args.input_file output_path = args.output or derive_output_path(input_file, args.format) if args.skip_existing and os.path.exists(output_path): print(f"输出文件已存在,跳过: {output_path}") return splitter.process(input_file, output_path) except FileNotFoundError as e: print(f"错误: {e}", file=sys.stderr) sys.exit(1) except UnsupportedFormatError as e: print(f"错误: {e}", file=sys.stderr) sys.exit(1) except ParseError as e: print(f"错误: {e}", file=sys.stderr) sys.exit(1) except ApiError as e: print(f"错误: API 调用失败 - {e}", file=sys.stderr) sys.exit(1) except Exception as e: print(f"错误: {e}", file=sys.stderr) sys.exit(1) if __name__ == "__main__": main()