165 lines
4.9 KiB
Python
165 lines
4.9 KiB
Python
|
|
"""CLI 入口,argparse 参数解析"""
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
|
|||
|
|
from dotenv import load_dotenv
|
|||
|
|
|
|||
|
|
load_dotenv()
|
|||
|
|
|
|||
|
|
from exceptions import ApiError, ParseError, UnsupportedFormatError
|
|||
|
|
from splitter import Splitter
|
|||
|
|
|
|||
|
|
|
|||
|
|
def derive_output_path(input_file: str, output_format: str = "markdown") -> str:
|
|||
|
|
"""根据输入文件路径推导默认输出路径"""
|
|||
|
|
root, _ = os.path.splitext(input_file)
|
|||
|
|
ext = ".json" if output_format == "json" else ".md"
|
|||
|
|
return root + ext
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_parser() -> argparse.ArgumentParser:
|
|||
|
|
"""构建命令行参数解析器"""
|
|||
|
|
parser = argparse.ArgumentParser(
|
|||
|
|
description="AI 知识库文档智能分块工具 - 将多种格式文档解析并通过 DeepSeek API 进行语义级智能分块"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 输入源(单文件或批量文件夹,二选一)
|
|||
|
|
input_group = parser.add_mutually_exclusive_group(required=True)
|
|||
|
|
input_group.add_argument(
|
|||
|
|
"input_file",
|
|||
|
|
nargs="?",
|
|||
|
|
default=None,
|
|||
|
|
help="输入文件路径(支持 PDF、Word、Excel、CSV、HTML、TXT、图片等格式)",
|
|||
|
|
)
|
|||
|
|
input_group.add_argument(
|
|||
|
|
"-b", "--batch",
|
|||
|
|
default=None,
|
|||
|
|
help="批量处理模式:指定输入文件夹路径,递归扫描所有支持的文件",
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# API Key(支持环境变量 DEEPSEEK_API_KEY)
|
|||
|
|
parser.add_argument(
|
|||
|
|
"-k", "--api-key",
|
|||
|
|
default=os.environ.get("DEEPSEEK_API_KEY"),
|
|||
|
|
help="DeepSeek API Key(也可通过环境变量 DEEPSEEK_API_KEY 设置)",
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 输出相关
|
|||
|
|
parser.add_argument(
|
|||
|
|
"-o", "--output",
|
|||
|
|
default=None,
|
|||
|
|
help="输出文件路径(单文件模式,默认为输入文件同目录同名 .md/.json 文件)",
|
|||
|
|
)
|
|||
|
|
parser.add_argument(
|
|||
|
|
"--output-dir",
|
|||
|
|
default=None,
|
|||
|
|
help="批量模式的输出目录(默认:输入文件夹下的 output/ 子目录)",
|
|||
|
|
)
|
|||
|
|
parser.add_argument(
|
|||
|
|
"-f", "--format",
|
|||
|
|
choices=["markdown", "json"],
|
|||
|
|
default="markdown",
|
|||
|
|
help="输出格式(默认: markdown)",
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 处理参数
|
|||
|
|
parser.add_argument(
|
|||
|
|
"-d", "--delimiter",
|
|||
|
|
default="---",
|
|||
|
|
help="分块分隔符(默认: ---)",
|
|||
|
|
)
|
|||
|
|
parser.add_argument(
|
|||
|
|
"--chunk-size",
|
|||
|
|
type=int,
|
|||
|
|
default=None,
|
|||
|
|
help="预切分大小(字符数),默认 12000。中文文档建议 10000-15000",
|
|||
|
|
)
|
|||
|
|
parser.add_argument(
|
|||
|
|
"--vision-prompt",
|
|||
|
|
default=None,
|
|||
|
|
help="自定义图片识别的 system prompt",
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 批量处理选项
|
|||
|
|
parser.add_argument(
|
|||
|
|
"--skip-existing",
|
|||
|
|
action="store_true",
|
|||
|
|
default=False,
|
|||
|
|
help="跳过已存在的输出文件(避免重复处理和 API 费用)",
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return parser
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main() -> None:
|
|||
|
|
parser = build_parser()
|
|||
|
|
args = parser.parse_args()
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
if not args.api_key:
|
|||
|
|
print("错误: 未提供 API Key。请通过 -k 参数或环境变量 DEEPSEEK_API_KEY 设置", file=sys.stderr)
|
|||
|
|
sys.exit(1)
|
|||
|
|
|
|||
|
|
splitter = Splitter(
|
|||
|
|
api_key=args.api_key,
|
|||
|
|
delimiter=args.delimiter,
|
|||
|
|
pre_split_size=args.chunk_size,
|
|||
|
|
vision_prompt=args.vision_prompt,
|
|||
|
|
output_format=args.format,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if args.batch:
|
|||
|
|
# 批量处理模式
|
|||
|
|
from batch import batch_process, print_summary
|
|||
|
|
|
|||
|
|
input_dir = args.batch
|
|||
|
|
if not os.path.isdir(input_dir):
|
|||
|
|
print(f"错误: 批量处理路径不是文件夹: {input_dir}", file=sys.stderr)
|
|||
|
|
sys.exit(1)
|
|||
|
|
|
|||
|
|
output_dir = args.output_dir or os.path.join(input_dir, "output")
|
|||
|
|
|
|||
|
|
result = batch_process(
|
|||
|
|
splitter=splitter,
|
|||
|
|
input_dir=input_dir,
|
|||
|
|
output_dir=output_dir,
|
|||
|
|
skip_existing=args.skip_existing,
|
|||
|
|
output_format=args.format,
|
|||
|
|
)
|
|||
|
|
print_summary(result)
|
|||
|
|
|
|||
|
|
if result.failed:
|
|||
|
|
sys.exit(1)
|
|||
|
|
else:
|
|||
|
|
# 单文件处理模式
|
|||
|
|
input_file = args.input_file
|
|||
|
|
output_path = args.output or derive_output_path(input_file, args.format)
|
|||
|
|
|
|||
|
|
if args.skip_existing and os.path.exists(output_path):
|
|||
|
|
print(f"输出文件已存在,跳过: {output_path}")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
splitter.process(input_file, output_path)
|
|||
|
|
|
|||
|
|
except FileNotFoundError as e:
|
|||
|
|
print(f"错误: {e}", file=sys.stderr)
|
|||
|
|
sys.exit(1)
|
|||
|
|
except UnsupportedFormatError as e:
|
|||
|
|
print(f"错误: {e}", file=sys.stderr)
|
|||
|
|
sys.exit(1)
|
|||
|
|
except ParseError as e:
|
|||
|
|
print(f"错误: {e}", file=sys.stderr)
|
|||
|
|
sys.exit(1)
|
|||
|
|
except ApiError as e:
|
|||
|
|
print(f"错误: API 调用失败 - {e}", file=sys.stderr)
|
|||
|
|
sys.exit(1)
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"错误: {e}", file=sys.stderr)
|
|||
|
|
sys.exit(1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
main()
|