183 lines
6.4 KiB
Python
183 lines
6.4 KiB
Python
"""CLI 入口 main.py 单元测试"""
|
||
|
||
import os
|
||
|
||
import pytest
|
||
from unittest.mock import patch, MagicMock
|
||
|
||
from main import derive_output_path, build_parser, main
|
||
from exceptions import ParseError, UnsupportedFormatError, ApiError
|
||
|
||
|
||
class TestDeriveOutputPath:
    """Tests for deriving the default output path from an input path.

    The cases expect the input's last extension (if any) to become ``.md``,
    with the directory part and any earlier dots in the name left untouched.
    """

    def test_pdf_to_md(self):
        source, expected = "report.pdf", "report.md"
        assert derive_output_path(source) == expected

    def test_xlsx_to_md(self):
        source, expected = "data.xlsx", "data.md"
        assert derive_output_path(source) == expected

    def test_with_directory(self):
        source = "/home/user/docs/file.docx"
        expected = "/home/user/docs/file.md"
        assert derive_output_path(source) == expected

    def test_txt_to_md(self):
        source, expected = "notes.txt", "notes.md"
        assert derive_output_path(source) == expected

    def test_no_extension(self):
        # A name without any extension simply gets ".md" appended.
        source, expected = "README", "README.md"
        assert derive_output_path(source) == expected

    def test_multiple_dots(self):
        # Only the final extension is replaced; inner dots are preserved.
        source, expected = "my.report.v2.pdf", "my.report.v2.md"
        assert derive_output_path(source) == expected
|
||
|
||
|
||
class TestBuildParser:
    """Tests for the argparse parser produced by ``build_parser()``."""

    @staticmethod
    def _parse(argv):
        """Build a fresh parser and parse the given argument list with it."""
        return build_parser().parse_args(argv)

    def test_all_args(self):
        parsed = self._parse(
            ["input.pdf", "-k", "sk-abc", "-o", "out.md", "-d", "==="]
        )
        assert parsed.input_file == "input.pdf"
        assert parsed.api_key == "sk-abc"
        assert parsed.output == "out.md"
        assert parsed.delimiter == "==="

    def test_required_args_only(self):
        parsed = self._parse(["input.pdf", "-k", "sk-abc"])
        assert parsed.input_file == "input.pdf"
        assert parsed.api_key == "sk-abc"
        # Optional arguments fall back to their defaults.
        assert parsed.output is None
        assert parsed.delimiter == "---"

    def test_long_option_names(self):
        parsed = self._parse([
            "input.pdf",
            "--api-key", "sk-abc",
            "--output", "out.md",
            "--delimiter", "***",
        ])
        assert parsed.api_key == "sk-abc"
        assert parsed.output == "out.md"
        assert parsed.delimiter == "***"

    def test_missing_input_file(self):
        # Build the parser outside the raises-block so only parsing may exit.
        parser = build_parser()
        with pytest.raises(SystemExit) as excinfo:
            parser.parse_args(["-k", "sk-abc"])
        assert excinfo.value.code != 0

    @patch.dict(os.environ, {}, clear=True)
    def test_missing_api_key(self):
        """With neither ``-k`` nor an environment variable, api_key is None."""
        # The parser is built inside the test, i.e. while os.environ is cleared.
        assert self._parse(["input.pdf"]).api_key is None
|
||
|
||
|
||
class TestMainFunction:
    """Integration tests for ``main()`` with ``main.Splitter`` mocked out.

    Each test patches ``main.Splitter`` so no real document processing or API
    call happens, drives ``main()`` through ``sys.argv``, and then checks how
    the mocked splitter was constructed/called, or how a failure raised by
    ``Splitter.process`` is reported (exit status 1 plus a stderr message).
    """

    @staticmethod
    def _run_main(argv):
        """Invoke ``main()`` as if the program were run as ``main.py *argv``."""
        with patch("sys.argv", ["main.py", *argv]):
            main()

    @classmethod
    def _run_main_expecting_exit(cls, argv, capsys):
        """Run ``main()``, require ``SystemExit`` with code 1, return stderr.

        The exit-code assertion is deliberately placed after the
        ``pytest.raises`` block: any statement inside the block that follows
        the raising call would never execute.
        """
        with pytest.raises(SystemExit) as exc_info:
            cls._run_main(argv)
        assert exc_info.value.code == 1
        return capsys.readouterr().err

    @patch("main.Splitter")
    def test_success(self, mock_splitter_cls):
        self._run_main(["input.pdf", "-k", "sk-abc"])

        mock_splitter_cls.assert_called_once_with(
            api_key="sk-abc", delimiter="---",
            pre_split_size=None, vision_prompt=None, output_format="markdown",
        )
        # Without -o, the output path is derived from the input path.
        mock_splitter_cls.return_value.process.assert_called_once_with(
            "input.pdf", "input.md",
        )

    @patch("main.Splitter")
    def test_custom_output(self, mock_splitter_cls):
        self._run_main(["input.pdf", "-k", "sk-abc", "-o", "custom.md"])

        mock_splitter_cls.return_value.process.assert_called_once_with(
            "input.pdf", "custom.md",
        )

    @patch("main.Splitter")
    def test_custom_delimiter(self, mock_splitter_cls):
        self._run_main(["input.pdf", "-k", "sk-abc", "-d", "==="])

        mock_splitter_cls.assert_called_once_with(
            api_key="sk-abc", delimiter="===",
            pre_split_size=None, vision_prompt=None, output_format="markdown",
        )

    @patch("main.Splitter")
    def test_file_not_found_error(self, mock_splitter_cls, capsys):
        mock_splitter_cls.return_value.process.side_effect = FileNotFoundError(
            "输入文件不存在: missing.pdf"
        )
        err = self._run_main_expecting_exit(
            ["missing.pdf", "-k", "sk-abc"], capsys,
        )
        assert "missing.pdf" in err

    @patch("main.Splitter")
    def test_unsupported_format_error(self, mock_splitter_cls, capsys):
        mock_splitter_cls.return_value.process.side_effect = (
            UnsupportedFormatError("file.xyz", ".xyz")
        )
        err = self._run_main_expecting_exit(
            ["file.xyz", "-k", "sk-abc"], capsys,
        )
        assert ".xyz" in err

    @patch("main.Splitter")
    def test_parse_error(self, mock_splitter_cls, capsys):
        mock_splitter_cls.return_value.process.side_effect = ParseError(
            "bad.pdf", "文件损坏"
        )
        err = self._run_main_expecting_exit(
            ["bad.pdf", "-k", "sk-abc"], capsys,
        )
        assert "bad.pdf" in err

    @patch("main.Splitter")
    def test_api_error(self, mock_splitter_cls, capsys):
        mock_splitter_cls.return_value.process.side_effect = ApiError(
            "认证失败", status_code=401
        )
        err = self._run_main_expecting_exit(
            ["input.pdf", "-k", "bad-key"], capsys,
        )
        assert "API" in err

    @patch("main.Splitter")
    def test_generic_exception(self, mock_splitter_cls, capsys):
        # An exception type not handled specifically must still exit with 1.
        mock_splitter_cls.return_value.process.side_effect = RuntimeError(
            "意外错误"
        )
        err = self._run_main_expecting_exit(
            ["input.pdf", "-k", "sk-abc"], capsys,
        )
        assert "意外错误" in err
|