Files
bigwo/tests/test_main.py

183 lines
6.4 KiB
Python
Raw Permalink Normal View History

"""CLI 入口 main.py 单元测试"""
import os
import pytest
from unittest.mock import patch, MagicMock
from main import derive_output_path, build_parser, main
from exceptions import ParseError, UnsupportedFormatError, ApiError
class TestDeriveOutputPath:
    """Tests for deriving the default output path from an input path.

    The derived path keeps the directory and base name, replaces only the
    final extension with ``.md``, and appends ``.md`` when there is no
    extension at all.
    """

    def test_pdf_to_md(self):
        expected = "report.md"
        actual = derive_output_path("report.pdf")
        assert actual == expected

    def test_xlsx_to_md(self):
        expected = "data.md"
        actual = derive_output_path("data.xlsx")
        assert actual == expected

    def test_with_directory(self):
        # The directory part must survive unchanged.
        expected = "/home/user/docs/file.md"
        actual = derive_output_path("/home/user/docs/file.docx")
        assert actual == expected

    def test_txt_to_md(self):
        expected = "notes.md"
        actual = derive_output_path("notes.txt")
        assert actual == expected

    def test_no_extension(self):
        # No extension to replace: ".md" is simply appended.
        expected = "README.md"
        actual = derive_output_path("README")
        assert actual == expected

    def test_multiple_dots(self):
        # Only the last dot-suffix counts as the extension.
        expected = "my.report.v2.md"
        actual = derive_output_path("my.report.v2.pdf")
        assert actual == expected
class TestBuildParser:
    """Tests for the argparse parser returned by ``build_parser()``."""

    @staticmethod
    def _parse(argv):
        """Build a fresh parser and parse *argv* with it."""
        return build_parser().parse_args(argv)

    def test_all_args(self):
        parsed = self._parse(["input.pdf", "-k", "sk-abc", "-o", "out.md", "-d", "==="])
        observed = (parsed.input_file, parsed.api_key, parsed.output, parsed.delimiter)
        assert observed == ("input.pdf", "sk-abc", "out.md", "===")

    def test_required_args_only(self):
        parsed = self._parse(["input.pdf", "-k", "sk-abc"])
        assert (parsed.input_file, parsed.api_key) == ("input.pdf", "sk-abc")
        # Optional flags fall back to their defaults.
        assert parsed.output is None
        assert parsed.delimiter == "---"

    def test_long_option_names(self):
        parsed = self._parse(
            ["input.pdf", "--api-key", "sk-abc", "--output", "out.md", "--delimiter", "***"]
        )
        observed = (parsed.api_key, parsed.output, parsed.delimiter)
        assert observed == ("sk-abc", "out.md", "***")

    def test_missing_input_file(self):
        # The positional input file is mandatory; argparse exits non-zero.
        cli_parser = build_parser()
        with pytest.raises(SystemExit) as raised:
            cli_parser.parse_args(["-k", "sk-abc"])
        assert raised.value.code != 0

    @patch.dict(os.environ, {}, clear=True)
    def test_missing_api_key(self):
        """With no -k flag and the environment cleared, api_key is None."""
        parsed = self._parse(["input.pdf"])
        assert parsed.api_key is None
class TestMainFunction:
    """Integration tests for the ``main()`` CLI entry point.

    Every test patches ``main.Splitter`` so no real document processing
    happens, and patches ``sys.argv`` to simulate a command line.
    Error-path tests make ``Splitter.process`` raise and check that
    ``main()`` exits with status 1 and mentions the problem on stderr.
    """

    @staticmethod
    def _install_splitter(mock_splitter_cls, side_effect=None):
        """Make the patched ``Splitter`` class return a fresh mock instance.

        Args:
            mock_splitter_cls: The mock that replaces ``main.Splitter``.
            side_effect: Optional exception for ``process`` to raise.

        Returns:
            The mock instance that ``Splitter(...)`` will return.
        """
        mock_splitter = MagicMock()
        if side_effect is not None:
            mock_splitter.process.side_effect = side_effect
        mock_splitter_cls.return_value = mock_splitter
        return mock_splitter

    @staticmethod
    def _run_main(cli_args):
        """Run ``main()`` as if invoked as ``main.py <cli_args...>``."""
        with patch("sys.argv", ["main.py", *cli_args]):
            main()

    def _assert_exits_with_error(self, mock_splitter_cls, capsys, cli_args,
                                 error, expected_in_err):
        """Assert that *error* raised by ``process`` is reported as a CLI failure.

        Args:
            mock_splitter_cls: The mock that replaces ``main.Splitter``.
            capsys: pytest's capture fixture, used to read stderr.
            cli_args: Command-line arguments (without the program name).
            error: Exception instance that ``Splitter.process`` raises.
            expected_in_err: Substring that must appear on stderr.
        """
        mock_splitter = self._install_splitter(mock_splitter_cls, side_effect=error)
        with pytest.raises(SystemExit) as exc_info:
            self._run_main(cli_args)
        assert exc_info.value.code == 1
        # Make sure the exit was triggered by the injected error, not by
        # something that happened before processing started.
        mock_splitter.process.assert_called()
        assert expected_in_err in capsys.readouterr().err

    @patch("main.Splitter")
    def test_success(self, mock_splitter_cls):
        mock_splitter = self._install_splitter(mock_splitter_cls)
        self._run_main(["input.pdf", "-k", "sk-abc"])
        mock_splitter_cls.assert_called_once_with(
            api_key="sk-abc", delimiter="---",
            pre_split_size=None, vision_prompt=None, output_format="markdown",
        )
        # Without -o the output path is derived from the input path.
        mock_splitter.process.assert_called_once_with("input.pdf", "input.md")

    @patch("main.Splitter")
    def test_custom_output(self, mock_splitter_cls):
        mock_splitter = self._install_splitter(mock_splitter_cls)
        self._run_main(["input.pdf", "-k", "sk-abc", "-o", "custom.md"])
        mock_splitter.process.assert_called_once_with("input.pdf", "custom.md")

    @patch("main.Splitter")
    def test_custom_delimiter(self, mock_splitter_cls):
        self._install_splitter(mock_splitter_cls)
        self._run_main(["input.pdf", "-k", "sk-abc", "-d", "==="])
        mock_splitter_cls.assert_called_once_with(
            api_key="sk-abc", delimiter="===",
            pre_split_size=None, vision_prompt=None, output_format="markdown",
        )

    @patch("main.Splitter")
    def test_file_not_found_error(self, mock_splitter_cls, capsys):
        self._assert_exits_with_error(
            mock_splitter_cls, capsys,
            ["missing.pdf", "-k", "sk-abc"],
            FileNotFoundError("输入文件不存在: missing.pdf"),
            "missing.pdf",
        )

    @patch("main.Splitter")
    def test_unsupported_format_error(self, mock_splitter_cls, capsys):
        self._assert_exits_with_error(
            mock_splitter_cls, capsys,
            ["file.xyz", "-k", "sk-abc"],
            UnsupportedFormatError("file.xyz", ".xyz"),
            ".xyz",
        )

    @patch("main.Splitter")
    def test_parse_error(self, mock_splitter_cls, capsys):
        self._assert_exits_with_error(
            mock_splitter_cls, capsys,
            ["bad.pdf", "-k", "sk-abc"],
            ParseError("bad.pdf", "文件损坏"),
            "bad.pdf",
        )

    @patch("main.Splitter")
    def test_api_error(self, mock_splitter_cls, capsys):
        self._assert_exits_with_error(
            mock_splitter_cls, capsys,
            ["input.pdf", "-k", "bad-key"],
            ApiError("认证失败", status_code=401),
            "API",
        )

    @patch("main.Splitter")
    def test_generic_exception(self, mock_splitter_cls, capsys):
        # Unexpected exceptions must also become a clean exit code 1.
        self._assert_exits_with_error(
            mock_splitter_cls, capsys,
            ["input.pdf", "-k", "sk-abc"],
            RuntimeError("意外错误"),
            "意外错误",
        )