Files
bigwo/main.py

165 lines
4.9 KiB
Python
Raw Normal View History

"""CLI 入口argparse 参数解析"""
import argparse
import os
import sys
from dotenv import load_dotenv
load_dotenv()
from exceptions import ApiError, ParseError, UnsupportedFormatError
from splitter import Splitter
def derive_output_path(input_file: str, output_format: str = "markdown") -> str:
"""根据输入文件路径推导默认输出路径"""
root, _ = os.path.splitext(input_file)
ext = ".json" if output_format == "json" else ".md"
return root + ext
def build_parser() -> argparse.ArgumentParser:
"""构建命令行参数解析器"""
parser = argparse.ArgumentParser(
description="AI 知识库文档智能分块工具 - 将多种格式文档解析并通过 DeepSeek API 进行语义级智能分块"
)
# 输入源(单文件或批量文件夹,二选一)
input_group = parser.add_mutually_exclusive_group(required=True)
input_group.add_argument(
"input_file",
nargs="?",
default=None,
help="输入文件路径(支持 PDF、Word、Excel、CSV、HTML、TXT、图片等格式",
)
input_group.add_argument(
"-b", "--batch",
default=None,
help="批量处理模式:指定输入文件夹路径,递归扫描所有支持的文件",
)
# API Key支持环境变量 DEEPSEEK_API_KEY
parser.add_argument(
"-k", "--api-key",
default=os.environ.get("DEEPSEEK_API_KEY"),
help="DeepSeek API Key也可通过环境变量 DEEPSEEK_API_KEY 设置)",
)
# 输出相关
parser.add_argument(
"-o", "--output",
default=None,
help="输出文件路径(单文件模式,默认为输入文件同目录同名 .md/.json 文件)",
)
parser.add_argument(
"--output-dir",
default=None,
help="批量模式的输出目录(默认:输入文件夹下的 output/ 子目录)",
)
parser.add_argument(
"-f", "--format",
choices=["markdown", "json"],
default="markdown",
help="输出格式(默认: markdown",
)
# 处理参数
parser.add_argument(
"-d", "--delimiter",
default="---",
help="分块分隔符(默认: ---",
)
parser.add_argument(
"--chunk-size",
type=int,
default=None,
help="预切分大小(字符数),默认 12000。中文文档建议 10000-15000",
)
parser.add_argument(
"--vision-prompt",
default=None,
help="自定义图片识别的 system prompt",
)
# 批量处理选项
parser.add_argument(
"--skip-existing",
action="store_true",
default=False,
help="跳过已存在的输出文件(避免重复处理和 API 费用)",
)
return parser
def main() -> None:
parser = build_parser()
args = parser.parse_args()
try:
if not args.api_key:
print("错误: 未提供 API Key。请通过 -k 参数或环境变量 DEEPSEEK_API_KEY 设置", file=sys.stderr)
sys.exit(1)
splitter = Splitter(
api_key=args.api_key,
delimiter=args.delimiter,
pre_split_size=args.chunk_size,
vision_prompt=args.vision_prompt,
output_format=args.format,
)
if args.batch:
# 批量处理模式
from batch import batch_process, print_summary
input_dir = args.batch
if not os.path.isdir(input_dir):
print(f"错误: 批量处理路径不是文件夹: {input_dir}", file=sys.stderr)
sys.exit(1)
output_dir = args.output_dir or os.path.join(input_dir, "output")
result = batch_process(
splitter=splitter,
input_dir=input_dir,
output_dir=output_dir,
skip_existing=args.skip_existing,
output_format=args.format,
)
print_summary(result)
if result.failed:
sys.exit(1)
else:
# 单文件处理模式
input_file = args.input_file
output_path = args.output or derive_output_path(input_file, args.format)
if args.skip_existing and os.path.exists(output_path):
print(f"输出文件已存在,跳过: {output_path}")
return
splitter.process(input_file, output_path)
except FileNotFoundError as e:
print(f"错误: {e}", file=sys.stderr)
sys.exit(1)
except UnsupportedFormatError as e:
print(f"错误: {e}", file=sys.stderr)
sys.exit(1)
except ParseError as e:
print(f"错误: {e}", file=sys.stderr)
sys.exit(1)
except ApiError as e:
print(f"错误: API 调用失败 - {e}", file=sys.stderr)
sys.exit(1)
except Exception as e:
print(f"错误: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()