Files
bigwo/main.py
2026-03-02 17:38:28 +08:00

165 lines
4.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""CLI 入口argparse 参数解析"""
import argparse
import os
import sys
from dotenv import load_dotenv
load_dotenv()
from exceptions import ApiError, ParseError, UnsupportedFormatError
from splitter import Splitter
def derive_output_path(input_file: str, output_format: str = "markdown") -> str:
    """Derive the default output path from the input file path.

    The output sits next to the input file and shares its base name. The
    extension is ``.json`` when ``output_format`` is ``"json"`` and ``.md``
    for any other value.

    If the input file already carries the target extension (for example a
    ``.md`` input with markdown output), ``_chunked`` is appended to the base
    name so the default output never points at the input file itself.

    Args:
        input_file: Path of the document to be processed.
        output_format: Output format name; ``"json"`` selects ``.json``,
            anything else selects ``.md``.

    Returns:
        The derived output file path.
    """
    root, current_ext = os.path.splitext(input_file)
    ext = ".json" if output_format == "json" else ".md"
    # Compare case-insensitively: on case-insensitive filesystems "NOTES.MD"
    # and "NOTES.md" name the same file, so both count as a collision.
    if current_ext.lower() == ext:
        root += "_chunked"
    return root + ext
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser for the chunking tool.

    The parser requires exactly one input source (a positional input file or
    ``-b/--batch`` with a folder path) and accepts optional settings for the
    API key, output location and format, chunking parameters and batch
    behaviour.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(
        description="AI 知识库文档智能分块工具 - 将多种格式文档解析并通过 DeepSeek API 进行语义级智能分块"
    )

    # Input source: a single file or a batch folder, exactly one of the two.
    source = cli.add_mutually_exclusive_group(required=True)
    source.add_argument(
        "input_file",
        nargs="?",
        default=None,
        help="输入文件路径(支持 PDF、Word、Excel、CSV、HTML、TXT、图片等格式",
    )
    source.add_argument(
        "-b", "--batch",
        default=None,
        help="批量处理模式:指定输入文件夹路径,递归扫描所有支持的文件",
    )

    # Remaining options as (flags, keyword arguments) pairs, registered in
    # this order so the generated help lists them as declared.
    option_specs = (
        # API key; falls back to the DEEPSEEK_API_KEY environment variable.
        (
            ("-k", "--api-key"),
            {
                "default": os.environ.get("DEEPSEEK_API_KEY"),
                "help": "DeepSeek API Key也可通过环境变量 DEEPSEEK_API_KEY 设置)",
            },
        ),
        # Output options.
        (
            ("-o", "--output"),
            {
                "default": None,
                "help": "输出文件路径(单文件模式,默认为输入文件同目录同名 .md/.json 文件)",
            },
        ),
        (
            ("--output-dir",),
            {
                "default": None,
                "help": "批量模式的输出目录(默认:输入文件夹下的 output/ 子目录)",
            },
        ),
        (
            ("-f", "--format"),
            {
                "choices": ["markdown", "json"],
                "default": "markdown",
                "help": "输出格式(默认: markdown",
            },
        ),
        # Processing parameters.
        (
            ("-d", "--delimiter"),
            {
                "default": "---",
                "help": "分块分隔符(默认: ---",
            },
        ),
        (
            ("--chunk-size",),
            {
                "type": int,
                "default": None,
                "help": "预切分大小(字符数),默认 12000。中文文档建议 10000-15000",
            },
        ),
        (
            ("--vision-prompt",),
            {
                "default": None,
                "help": "自定义图片识别的 system prompt",
            },
        ),
        # Batch processing options.
        (
            ("--skip-existing",),
            {
                "action": "store_true",
                "default": False,
                "help": "跳过已存在的输出文件(避免重复处理和 API 费用)",
            },
        ),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    return cli
def main() -> None:
    """Run the command-line tool.

    Parses the command-line arguments, builds a ``Splitter`` from them and
    then either batch-processes a folder (``-b/--batch``) or processes a
    single input file. Every failure is reported on stderr and terminates
    the process with exit status 1.
    """
    args = build_parser().parse_args()

    def abort(message: str) -> None:
        # Report the problem on stderr and stop with a failing exit status.
        print(message, file=sys.stderr)
        sys.exit(1)

    try:
        if not args.api_key:
            abort("错误: 未提供 API Key。请通过 -k 参数或环境变量 DEEPSEEK_API_KEY 设置")

        splitter = Splitter(
            api_key=args.api_key,
            delimiter=args.delimiter,
            pre_split_size=args.chunk_size,
            vision_prompt=args.vision_prompt,
            output_format=args.format,
        )

        if not args.batch:
            # Single-file mode.
            source_file = args.input_file
            output_path = args.output or derive_output_path(source_file, args.format)
            if args.skip_existing and os.path.exists(output_path):
                print(f"输出文件已存在,跳过: {output_path}")
                return
            splitter.process(source_file, output_path)
            return

        # Batch mode: the batch helpers are imported only when needed.
        from batch import batch_process, print_summary

        input_dir = args.batch
        if not os.path.isdir(input_dir):
            abort(f"错误: 批量处理路径不是文件夹: {input_dir}")

        summary = batch_process(
            splitter=splitter,
            input_dir=input_dir,
            output_dir=args.output_dir or os.path.join(input_dir, "output"),
            skip_existing=args.skip_existing,
            output_format=args.format,
        )
        print_summary(summary)
        # A truthy `failed` on the batch result yields a failing exit status.
        if summary.failed:
            sys.exit(1)
    except (FileNotFoundError, UnsupportedFormatError, ParseError) as e:
        abort(f"错误: {e}")
    except ApiError as e:
        abort(f"错误: API 调用失败 - {e}")
    except Exception as e:
        # Top-level boundary: report any other error instead of a traceback.
        abort(f"错误: {e}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()