165 lines
4.9 KiB
Python
165 lines
4.9 KiB
Python
"""CLI 入口,argparse 参数解析"""
|
||
|
||
import argparse
|
||
import os
|
||
import sys
|
||
|
||
from dotenv import load_dotenv
|
||
|
||
load_dotenv()
|
||
|
||
from exceptions import ApiError, ParseError, UnsupportedFormatError
|
||
from splitter import Splitter
|
||
|
||
|
||
def derive_output_path(input_file: str, output_format: str = "markdown") -> str:
|
||
"""根据输入文件路径推导默认输出路径"""
|
||
root, _ = os.path.splitext(input_file)
|
||
ext = ".json" if output_format == "json" else ".md"
|
||
return root + ext
|
||
|
||
|
||
def build_parser() -> argparse.ArgumentParser:
|
||
"""构建命令行参数解析器"""
|
||
parser = argparse.ArgumentParser(
|
||
description="AI 知识库文档智能分块工具 - 将多种格式文档解析并通过 DeepSeek API 进行语义级智能分块"
|
||
)
|
||
|
||
# 输入源(单文件或批量文件夹,二选一)
|
||
input_group = parser.add_mutually_exclusive_group(required=True)
|
||
input_group.add_argument(
|
||
"input_file",
|
||
nargs="?",
|
||
default=None,
|
||
help="输入文件路径(支持 PDF、Word、Excel、CSV、HTML、TXT、图片等格式)",
|
||
)
|
||
input_group.add_argument(
|
||
"-b", "--batch",
|
||
default=None,
|
||
help="批量处理模式:指定输入文件夹路径,递归扫描所有支持的文件",
|
||
)
|
||
|
||
# API Key(支持环境变量 DEEPSEEK_API_KEY)
|
||
parser.add_argument(
|
||
"-k", "--api-key",
|
||
default=os.environ.get("DEEPSEEK_API_KEY"),
|
||
help="DeepSeek API Key(也可通过环境变量 DEEPSEEK_API_KEY 设置)",
|
||
)
|
||
|
||
# 输出相关
|
||
parser.add_argument(
|
||
"-o", "--output",
|
||
default=None,
|
||
help="输出文件路径(单文件模式,默认为输入文件同目录同名 .md/.json 文件)",
|
||
)
|
||
parser.add_argument(
|
||
"--output-dir",
|
||
default=None,
|
||
help="批量模式的输出目录(默认:输入文件夹下的 output/ 子目录)",
|
||
)
|
||
parser.add_argument(
|
||
"-f", "--format",
|
||
choices=["markdown", "json"],
|
||
default="markdown",
|
||
help="输出格式(默认: markdown)",
|
||
)
|
||
|
||
# 处理参数
|
||
parser.add_argument(
|
||
"-d", "--delimiter",
|
||
default="---",
|
||
help="分块分隔符(默认: ---)",
|
||
)
|
||
parser.add_argument(
|
||
"--chunk-size",
|
||
type=int,
|
||
default=None,
|
||
help="预切分大小(字符数),默认 12000。中文文档建议 10000-15000",
|
||
)
|
||
parser.add_argument(
|
||
"--vision-prompt",
|
||
default=None,
|
||
help="自定义图片识别的 system prompt",
|
||
)
|
||
|
||
# 批量处理选项
|
||
parser.add_argument(
|
||
"--skip-existing",
|
||
action="store_true",
|
||
default=False,
|
||
help="跳过已存在的输出文件(避免重复处理和 API 费用)",
|
||
)
|
||
|
||
return parser
|
||
|
||
|
||
def main() -> None:
|
||
parser = build_parser()
|
||
args = parser.parse_args()
|
||
|
||
try:
|
||
if not args.api_key:
|
||
print("错误: 未提供 API Key。请通过 -k 参数或环境变量 DEEPSEEK_API_KEY 设置", file=sys.stderr)
|
||
sys.exit(1)
|
||
|
||
splitter = Splitter(
|
||
api_key=args.api_key,
|
||
delimiter=args.delimiter,
|
||
pre_split_size=args.chunk_size,
|
||
vision_prompt=args.vision_prompt,
|
||
output_format=args.format,
|
||
)
|
||
|
||
if args.batch:
|
||
# 批量处理模式
|
||
from batch import batch_process, print_summary
|
||
|
||
input_dir = args.batch
|
||
if not os.path.isdir(input_dir):
|
||
print(f"错误: 批量处理路径不是文件夹: {input_dir}", file=sys.stderr)
|
||
sys.exit(1)
|
||
|
||
output_dir = args.output_dir or os.path.join(input_dir, "output")
|
||
|
||
result = batch_process(
|
||
splitter=splitter,
|
||
input_dir=input_dir,
|
||
output_dir=output_dir,
|
||
skip_existing=args.skip_existing,
|
||
output_format=args.format,
|
||
)
|
||
print_summary(result)
|
||
|
||
if result.failed:
|
||
sys.exit(1)
|
||
else:
|
||
# 单文件处理模式
|
||
input_file = args.input_file
|
||
output_path = args.output or derive_output_path(input_file, args.format)
|
||
|
||
if args.skip_existing and os.path.exists(output_path):
|
||
print(f"输出文件已存在,跳过: {output_path}")
|
||
return
|
||
|
||
splitter.process(input_file, output_path)
|
||
|
||
except FileNotFoundError as e:
|
||
print(f"错误: {e}", file=sys.stderr)
|
||
sys.exit(1)
|
||
except UnsupportedFormatError as e:
|
||
print(f"错误: {e}", file=sys.stderr)
|
||
sys.exit(1)
|
||
except ParseError as e:
|
||
print(f"错误: {e}", file=sys.stderr)
|
||
sys.exit(1)
|
||
except ApiError as e:
|
||
print(f"错误: API 调用失败 - {e}", file=sys.stderr)
|
||
sys.exit(1)
|
||
except Exception as e:
|
||
print(f"错误: {e}", file=sys.stderr)
|
||
sys.exit(1)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|