# Source: bigwo/tests/test_splitter.py (360 lines, 13 KiB, Python)

"""Splitter 协调器单元测试"""
import os
import pytest
from unittest.mock import MagicMock, patch, call
from exceptions import ApiError, ParseError, UnsupportedFormatError
from models import Chunk
from splitter import Splitter
@pytest.fixture
def mock_deps():
    """Yield a ``Splitter`` built while its collaborators are patched.

    The API client, chunker, writers, parser registry and every parser
    class are replaced on the ``splitter`` module for the lifetime of the
    fixture, so the ``Splitter`` under test only ever talks to mocks.

    Yields:
        dict: ``"splitter"`` (the instance under test) plus the mock
        instances it was wired with: ``"api_client"``, ``"chunker"``,
        ``"writer"`` and ``"registry"``.
    """
    with (
        patch("splitter.ApiClient") as api_cls,
        patch("splitter.AIChunker") as chunker_cls,
        patch("splitter.MarkdownWriter") as writer_cls,
        patch("splitter.JsonWriter"),
        patch("splitter.ParserRegistry") as registry_cls,
        patch("splitter.TextParser"),
        patch("splitter.CsvParser"),
        patch("splitter.HtmlParser"),
        patch("splitter.PdfParser"),
        patch("splitter.DocParser"),
        patch("splitter.LegacyDocParser"),
        patch("splitter.XlsxParser"),
        patch("splitter.XlsParser"),
        patch("splitter.ImageParser"),
    ):
        # Construct inside the ``with`` so the constructor sees the mocks;
        # each ``return_value`` is what the patched class hands back.
        subject = Splitter(api_key="test-key", delimiter="---")
        yield dict(
            splitter=subject,
            api_client=api_cls.return_value,
            chunker=chunker_cls.return_value,
            writer=writer_cls.return_value,
            registry=registry_cls.return_value,
        )
class TestInit:
    """Tests for how ``Splitter.__init__`` wires up its collaborators."""

    @pytest.fixture
    def patched_classes(self):
        """Patch every collaborator on ``splitter`` and yield the class mocks.

        Shared by all tests in this class so the patch list is declared
        once instead of being repeated in each test body.

        Yields:
            dict: the patched *class* mocks under the keys ``"api_cls"``,
            ``"chunker_cls"`` and ``"registry_cls"``.
        """
        with (
            patch("splitter.ApiClient") as mock_api_cls,
            patch("splitter.AIChunker") as mock_chunker_cls,
            patch("splitter.MarkdownWriter"),
            patch("splitter.JsonWriter"),
            patch("splitter.ParserRegistry") as mock_registry_cls,
            patch("splitter.TextParser"),
            patch("splitter.CsvParser"),
            patch("splitter.HtmlParser"),
            patch("splitter.PdfParser"),
            patch("splitter.DocParser"),
            patch("splitter.LegacyDocParser"),
            patch("splitter.XlsxParser"),
            patch("splitter.XlsParser"),
            patch("splitter.ImageParser"),
        ):
            yield {
                "api_cls": mock_api_cls,
                "chunker_cls": mock_chunker_cls,
                "registry_cls": mock_registry_cls,
            }

    def test_registers_all_parsers(self, patched_classes):
        """Every parser is registered with the registry."""
        registry = patched_classes["registry_cls"].return_value
        Splitter(api_key="test-key")
        # 9 parsers: Text, Csv, Html, Pdf, Doc, LegacyDoc, Xlsx, Xls, Image
        assert registry.register.call_count == 9

    def test_creates_api_client_with_key(self, patched_classes):
        """``ApiClient`` is constructed with the supplied ``api_key``."""
        Splitter(api_key="my-secret-key")
        patched_classes["api_cls"].assert_called_once_with(
            api_key="my-secret-key"
        )

    def test_creates_chunker_with_delimiter(self, patched_classes):
        """``AIChunker`` receives the API client and the custom delimiter."""
        Splitter(api_key="key", delimiter="===")
        patched_classes["chunker_cls"].assert_called_once_with(
            patched_classes["api_cls"].return_value,
            "===",
            pre_split_size=None,
        )
class TestProcessSuccess:
    """Tests for the successful ``Splitter.process`` path."""

    @staticmethod
    def _arrange(mock_deps, tmp_path, filename, chunks=None,
                 file_content="data", parsed_text="text"):
        """Create an input file and wire the parser/chunker mocks for a run.

        Args:
            mock_deps: The dict yielded by the ``mock_deps`` fixture.
            tmp_path: pytest's per-test temporary directory.
            filename: Name of the input file created under ``tmp_path``.
            chunks: Return value for ``chunker.chunk``; when ``None`` the
                chunker mock is left untouched (e.g. the test installs its
                own ``side_effect``).
            file_content: Text written to the input file.
            parsed_text: Text the mock parser returns from ``parse``.

        Returns:
            tuple: ``(input_path, parser)`` where ``input_path`` is the
            input file path as ``str`` and ``parser`` is the mock parser
            handed out by the registry.
        """
        input_file = tmp_path / filename
        input_file.write_text(file_content)
        parser = MagicMock()
        parser.parse.return_value = parsed_text
        mock_deps["registry"].get_parser.return_value = parser
        if chunks is not None:
            mock_deps["chunker"].chunk.return_value = chunks
        return str(input_file), parser

    def test_full_flow(self, mock_deps, tmp_path):
        """The parse -> chunk -> write chain is invoked with the right args."""
        chunks = [Chunk(title="标题", content="内容")]
        input_path, parser = self._arrange(
            mock_deps, tmp_path, "test.txt", chunks,
            file_content="hello", parsed_text="parsed text",
        )
        output_file = str(tmp_path / "output.md")

        mock_deps["splitter"].process(input_path, output_file)

        # Verify the call chain end to end.
        mock_deps["registry"].get_parser.assert_called_once_with(input_path)
        parser.parse.assert_called_once_with(input_path)
        chunker = mock_deps["chunker"]
        chunker.chunk.assert_called_once()
        assert chunker.chunk.call_args[0][0] == "parsed text"
        mock_deps["writer"].write.assert_called_once_with(
            chunks, output_file, "test.txt", "---"
        )

    def test_logs_parsing_stage(self, mock_deps, tmp_path, capsys):
        """The file-parsing stage is reported on stdout."""
        input_path, _ = self._arrange(
            mock_deps, tmp_path, "doc.pdf", [Chunk(title="t", content="c")]
        )
        mock_deps["splitter"].process(input_path, str(tmp_path / "out.md"))
        output = capsys.readouterr().out
        assert "解析文件: doc.pdf" in output

    def test_logs_chunking_stage(self, mock_deps, tmp_path, capsys):
        """The AI chunking stage is reported on stdout."""
        input_path, _ = self._arrange(
            mock_deps, tmp_path, "doc.txt", [Chunk(title="t", content="c")]
        )
        mock_deps["splitter"].process(input_path, str(tmp_path / "out.md"))
        output = capsys.readouterr().out
        assert "AI 语义分块" in output

    def test_logs_writing_stage(self, mock_deps, tmp_path, capsys):
        """The output-writing stage is reported on stdout."""
        input_path, _ = self._arrange(
            mock_deps, tmp_path, "doc.txt", [Chunk(title="t", content="c")]
        )
        mock_deps["splitter"].process(input_path, str(tmp_path / "out.md"))
        output = capsys.readouterr().out
        assert "写入输出" in output

    def test_logs_summary(self, mock_deps, tmp_path, capsys):
        """The final summary reports how many chunks were produced."""
        three_chunks = [
            Chunk(title="t1", content="c1"),
            Chunk(title="t2", content="c2"),
            Chunk(title="t3", content="c3"),
        ]
        input_path, _ = self._arrange(
            mock_deps, tmp_path, "doc.txt", three_chunks
        )
        mock_deps["splitter"].process(input_path, str(tmp_path / "out.md"))
        output = capsys.readouterr().out
        assert "3 个分块" in output

    def test_progress_callback_passed_to_chunker(self, mock_deps, tmp_path, capsys):
        """A progress callback reaches the chunker and prints each step."""
        input_path, _ = self._arrange(mock_deps, tmp_path, "doc.txt")

        # Simulate the chunker invoking the progress callback it was given.
        def fake_chunk(text, content_type=None, source_file="", on_progress=None):
            if on_progress:
                on_progress(1, 3)
                on_progress(2, 3)
                on_progress(3, 3)
            return [Chunk(title="t", content="c")]

        mock_deps["chunker"].chunk.side_effect = fake_chunk
        mock_deps["splitter"].process(input_path, str(tmp_path / "out.md"))
        output = capsys.readouterr().out
        assert "分块进度: 1/3" in output
        assert "分块进度: 2/3" in output
        assert "分块进度: 3/3" in output
class TestProcessErrors:
    """Tests for the error paths of ``Splitter.process``."""

    def test_file_not_found(self, mock_deps, tmp_path):
        """A missing input file raises ``FileNotFoundError``."""
        splitter = mock_deps["splitter"]
        # Both paths live under tmp_path: the input is guaranteed not to
        # exist, and a regression cannot drop "output.md" into the cwd.
        missing_input = tmp_path / "nonexistent" / "file.txt"
        with pytest.raises(FileNotFoundError, match="输入文件不存在"):
            splitter.process(str(missing_input), str(tmp_path / "output.md"))

    def test_unsupported_format(self, mock_deps, tmp_path):
        """``UnsupportedFormatError`` from the registry propagates."""
        splitter = mock_deps["splitter"]
        registry = mock_deps["registry"]
        input_file = tmp_path / "file.xyz"
        input_file.write_text("data")
        registry.get_parser.side_effect = UnsupportedFormatError("file.xyz", ".xyz")
        with pytest.raises(UnsupportedFormatError):
            splitter.process(str(input_file), str(tmp_path / "out.md"))

    def test_parse_error(self, mock_deps, tmp_path):
        """``ParseError`` raised by the parser propagates."""
        splitter = mock_deps["splitter"]
        registry = mock_deps["registry"]
        input_file = tmp_path / "bad.pdf"
        input_file.write_bytes(b"\x00\x01\x02")
        mock_parser = MagicMock()
        mock_parser.parse.side_effect = ParseError("bad.pdf", "文件损坏")
        registry.get_parser.return_value = mock_parser
        # ``match`` is a regex (re.search): escape the dot so only the
        # literal file name "bad.pdf" satisfies the assertion.
        with pytest.raises(ParseError, match=r"bad\.pdf"):
            splitter.process(str(input_file), str(tmp_path / "out.md"))

    def test_api_error(self, mock_deps, tmp_path):
        """``ApiError`` raised by the chunker propagates."""
        splitter = mock_deps["splitter"]
        registry = mock_deps["registry"]
        chunker = mock_deps["chunker"]
        input_file = tmp_path / "doc.txt"
        input_file.write_text("data")
        mock_parser = MagicMock()
        mock_parser.parse.return_value = "text"
        registry.get_parser.return_value = mock_parser
        chunker.chunk.side_effect = ApiError("API 调用失败")
        with pytest.raises(ApiError, match="API 调用失败"):
            splitter.process(str(input_file), str(tmp_path / "out.md"))
class TestCustomDelimiter:
    """Tests for a caller-supplied chunk delimiter."""

    def test_delimiter_passed_to_writer(self, tmp_path):
        """The delimiter given to ``Splitter`` is forwarded to the writer."""
        with (
            patch("splitter.ApiClient"),
            patch("splitter.AIChunker") as chunker_cls,
            patch("splitter.MarkdownWriter") as writer_cls,
            patch("splitter.JsonWriter"),
            patch("splitter.ParserRegistry") as registry_cls,
            patch("splitter.TextParser"),
            patch("splitter.CsvParser"),
            patch("splitter.HtmlParser"),
            patch("splitter.PdfParser"),
            patch("splitter.DocParser"),
            patch("splitter.LegacyDocParser"),
            patch("splitter.XlsxParser"),
            patch("splitter.XlsParser"),
            patch("splitter.ImageParser"),
        ):
            subject = Splitter(api_key="key", delimiter="===")

            source = tmp_path / "test.txt"
            source.write_text("hello")
            destination = str(tmp_path / "out.md")

            # Registry hands out a parser that yields fixed text.
            stub_parser = MagicMock()
            stub_parser.parse.return_value = "text"
            registry_cls.return_value.get_parser.return_value = stub_parser

            expected_chunks = [Chunk(title="t", content="c")]
            chunker_cls.return_value.chunk.return_value = expected_chunks

            subject.process(str(source), destination)

            writer_cls.return_value.write.assert_called_once_with(
                expected_chunks, destination, "test.txt", "==="
            )