Initial commit: AI 知识库文档智能分块工具
This commit is contained in:
359
tests/test_splitter.py
Normal file
359
tests/test_splitter.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""Splitter 协调器单元测试"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
from exceptions import ApiError, ParseError, UnsupportedFormatError
|
||||
from models import Chunk
|
||||
from splitter import Splitter
|
||||
|
||||
|
||||
@pytest.fixture
def mock_deps():
    """Build a ``Splitter`` on top of mocked collaborators.

    Every collaborator class is patched by name inside the ``splitter``
    module before the ``Splitter`` is constructed, and the patches stay
    active until the requesting test has finished.

    Yields:
        dict: The ``Splitter`` under test (key ``"splitter"``) together with
        the ``return_value`` of the patched ``ApiClient``, ``AIChunker``,
        ``MarkdownWriter`` and ``ParserRegistry`` classes (keys
        ``"api_client"``, ``"chunker"``, ``"writer"`` and ``"registry"``).
    """
    with (
        # Collaborators whose instances are handed to the tests.
        patch("splitter.ApiClient") as api_cls,
        patch("splitter.AIChunker") as chunker_cls,
        patch("splitter.MarkdownWriter") as writer_cls,
        patch("splitter.ParserRegistry") as registry_cls,
        # The remaining names are only replaced; their mocks are not exposed.
        patch("splitter.JsonWriter"),
        patch("splitter.TextParser"),
        patch("splitter.CsvParser"),
        patch("splitter.HtmlParser"),
        patch("splitter.PdfParser"),
        patch("splitter.DocParser"),
        patch("splitter.LegacyDocParser"),
        patch("splitter.XlsxParser"),
        patch("splitter.XlsParser"),
        patch("splitter.ImageParser"),
    ):
        subject = Splitter(api_key="test-key", delimiter="---")

        yield dict(
            splitter=subject,
            api_client=api_cls.return_value,
            chunker=chunker_cls.return_value,
            writer=writer_cls.return_value,
            registry=registry_cls.return_value,
        )
|
||||
|
||||
|
||||
class TestInit:
    """Tests for how ``Splitter.__init__`` wires up its collaborators."""

    @pytest.fixture
    def patched_classes(self):
        """Patch every collaborator class in ``splitter`` for one test.

        The patch list is declared once here so that all tests in this class
        share it; adding a collaborator only requires touching this fixture.

        Yields:
            dict: The class mocks the tests assert on, keyed ``"api_cls"``,
            ``"chunker_cls"`` and ``"registry_cls"``.
        """
        with (
            patch("splitter.ApiClient") as mock_api_cls,
            patch("splitter.AIChunker") as mock_chunker_cls,
            patch("splitter.MarkdownWriter"),
            patch("splitter.JsonWriter"),
            patch("splitter.ParserRegistry") as mock_registry_cls,
            patch("splitter.TextParser"),
            patch("splitter.CsvParser"),
            patch("splitter.HtmlParser"),
            patch("splitter.PdfParser"),
            patch("splitter.DocParser"),
            patch("splitter.LegacyDocParser"),
            patch("splitter.XlsxParser"),
            patch("splitter.XlsParser"),
            patch("splitter.ImageParser"),
        ):
            yield {
                "api_cls": mock_api_cls,
                "chunker_cls": mock_chunker_cls,
                "registry_cls": mock_registry_cls,
            }

    def test_registers_all_parsers(self, patched_classes):
        """Every parser is registered with the registry."""
        registry = patched_classes["registry_cls"].return_value

        Splitter(api_key="test-key")

        # 9 parsers: Text, Csv, Html, Pdf, Doc, LegacyDoc, Xlsx, Xls, Image
        assert registry.register.call_count == 9

    def test_creates_api_client_with_key(self, patched_classes):
        """``ApiClient`` is created with the api_key given to ``Splitter``."""
        Splitter(api_key="my-secret-key")

        patched_classes["api_cls"].assert_called_once_with(api_key="my-secret-key")

    def test_creates_chunker_with_delimiter(self, patched_classes):
        """``AIChunker`` is created with the client instance and the delimiter."""
        mock_api_cls = patched_classes["api_cls"]

        Splitter(api_key="key", delimiter="===")

        patched_classes["chunker_cls"].assert_called_once_with(
            mock_api_cls.return_value, "===", pre_split_size=None
        )
|
||||
|
||||
|
||||
class TestProcessSuccess:
    """Happy-path tests for ``Splitter.process``."""

    @staticmethod
    def _arrange_pipeline(
        mock_deps, tmp_path, filename, *, content="data", parsed="text", chunks=None
    ):
        """Create the input file and stub the parser (optionally the chunker).

        Args:
            mock_deps: Mapping yielded by the ``mock_deps`` fixture.
            tmp_path: pytest temporary directory that receives the input file.
            filename: Name of the input file created inside ``tmp_path``.
            content: Text written to the input file.
            parsed: Value returned by the stub parser's ``parse``.
            chunks: When given, the value ``chunker.chunk`` returns. When
                ``None`` the chunker mock is left untouched so the caller
                can install a ``side_effect`` instead.

        Returns:
            tuple: ``(input_path, parser)`` - the input path as ``str`` and
            the ``MagicMock`` parser handed out by the registry mock.
        """
        input_file = tmp_path / filename
        input_file.write_text(content)

        parser = MagicMock()
        parser.parse.return_value = parsed
        mock_deps["registry"].get_parser.return_value = parser

        if chunks is not None:
            mock_deps["chunker"].chunk.return_value = chunks

        return str(input_file), parser

    def test_full_flow(self, mock_deps, tmp_path):
        """The full successful flow calls parser, chunker and writer in turn."""
        chunker = mock_deps["chunker"]
        chunks = [Chunk(title="标题", content="内容")]
        input_path, parser = self._arrange_pipeline(
            mock_deps,
            tmp_path,
            "test.txt",
            content="hello",
            parsed="parsed text",
            chunks=chunks,
        )
        output_file = str(tmp_path / "output.md")

        mock_deps["splitter"].process(input_path, output_file)

        # Verify the call chain: registry -> parser -> chunker -> writer.
        mock_deps["registry"].get_parser.assert_called_once_with(input_path)
        parser.parse.assert_called_once_with(input_path)
        chunker.chunk.assert_called_once()
        assert chunker.chunk.call_args[0][0] == "parsed text"
        mock_deps["writer"].write.assert_called_once_with(
            chunks, output_file, "test.txt", "---"
        )

    def test_logs_parsing_stage(self, mock_deps, tmp_path, capsys):
        """The file-parsing log line is printed to stdout."""
        input_path, _ = self._arrange_pipeline(
            mock_deps, tmp_path, "doc.pdf", chunks=[Chunk(title="t", content="c")]
        )

        mock_deps["splitter"].process(input_path, str(tmp_path / "out.md"))

        output = capsys.readouterr().out
        assert "解析文件: doc.pdf" in output

    def test_logs_chunking_stage(self, mock_deps, tmp_path, capsys):
        """The AI-chunking log line is printed to stdout."""
        input_path, _ = self._arrange_pipeline(
            mock_deps, tmp_path, "doc.txt", chunks=[Chunk(title="t", content="c")]
        )

        mock_deps["splitter"].process(input_path, str(tmp_path / "out.md"))

        output = capsys.readouterr().out
        assert "AI 语义分块" in output

    def test_logs_writing_stage(self, mock_deps, tmp_path, capsys):
        """The output-writing log line is printed to stdout."""
        input_path, _ = self._arrange_pipeline(
            mock_deps, tmp_path, "doc.txt", chunks=[Chunk(title="t", content="c")]
        )
        output_path = str(tmp_path / "out.md")

        mock_deps["splitter"].process(input_path, output_path)

        output = capsys.readouterr().out
        assert "写入输出" in output

    def test_logs_summary(self, mock_deps, tmp_path, capsys):
        """The summary line reports the number of chunks produced."""
        input_path, _ = self._arrange_pipeline(
            mock_deps,
            tmp_path,
            "doc.txt",
            chunks=[
                Chunk(title="t1", content="c1"),
                Chunk(title="t2", content="c2"),
                Chunk(title="t3", content="c3"),
            ],
        )

        mock_deps["splitter"].process(input_path, str(tmp_path / "out.md"))

        output = capsys.readouterr().out
        assert "3 个分块" in output

    def test_progress_callback_passed_to_chunker(self, mock_deps, tmp_path, capsys):
        """A progress callback reaches the chunker and its output is printed."""
        # No ``chunks`` here: the chunker gets a side_effect below instead.
        input_path, _ = self._arrange_pipeline(mock_deps, tmp_path, "doc.txt")

        # Simulate the chunker invoking the progress callback three times.
        def fake_chunk(text, content_type=None, source_file="", on_progress=None):
            if on_progress:
                on_progress(1, 3)
                on_progress(2, 3)
                on_progress(3, 3)
            return [Chunk(title="t", content="c")]

        mock_deps["chunker"].chunk.side_effect = fake_chunk

        mock_deps["splitter"].process(input_path, str(tmp_path / "out.md"))

        output = capsys.readouterr().out
        assert "分块进度: 1/3" in output
        assert "分块进度: 2/3" in output
        assert "分块进度: 3/3" in output
|
||||
|
||||
|
||||
class TestProcessErrors:
    """Error-path tests for ``Splitter.process``."""

    def test_file_not_found(self, mock_deps, tmp_path):
        """A missing input file raises ``FileNotFoundError``."""
        splitter = mock_deps["splitter"]
        # A never-created path inside tmp_path is guaranteed not to exist,
        # unlike a hard-coded absolute path on the host filesystem.
        missing_input = tmp_path / "file.txt"

        with pytest.raises(FileNotFoundError, match="输入文件不存在"):
            splitter.process(str(missing_input), str(tmp_path / "output.md"))

    def test_unsupported_format(self, mock_deps, tmp_path):
        """An unsupported format propagates ``UnsupportedFormatError``."""
        splitter = mock_deps["splitter"]
        registry = mock_deps["registry"]

        input_file = tmp_path / "file.xyz"
        input_file.write_text("data")

        registry.get_parser.side_effect = UnsupportedFormatError("file.xyz", ".xyz")

        with pytest.raises(UnsupportedFormatError):
            splitter.process(str(input_file), str(tmp_path / "out.md"))

    def test_parse_error(self, mock_deps, tmp_path):
        """A parser failure propagates ``ParseError``."""
        splitter = mock_deps["splitter"]
        registry = mock_deps["registry"]

        input_file = tmp_path / "bad.pdf"
        input_file.write_bytes(b"\x00\x01\x02")

        mock_parser = MagicMock()
        mock_parser.parse.side_effect = ParseError("bad.pdf", "文件损坏")
        registry.get_parser.return_value = mock_parser

        # ``match`` is a regular expression; escape the dot so that only the
        # literal file name satisfies the assertion.
        with pytest.raises(ParseError, match=r"bad\.pdf"):
            splitter.process(str(input_file), str(tmp_path / "out.md"))

    def test_api_error(self, mock_deps, tmp_path):
        """A chunker API failure propagates ``ApiError``."""
        splitter = mock_deps["splitter"]
        registry = mock_deps["registry"]
        chunker = mock_deps["chunker"]

        input_file = tmp_path / "doc.txt"
        input_file.write_text("data")

        mock_parser = MagicMock()
        mock_parser.parse.return_value = "text"
        registry.get_parser.return_value = mock_parser
        chunker.chunk.side_effect = ApiError("API 调用失败")

        with pytest.raises(ApiError, match="API 调用失败"):
            splitter.process(str(input_file), str(tmp_path / "out.md"))
|
||||
|
||||
|
||||
class TestCustomDelimiter:
    """Tests for a caller-supplied delimiter."""

    def test_delimiter_passed_to_writer(self, tmp_path):
        """The delimiter given to ``Splitter`` is forwarded to ``writer.write``."""
        # File system and value setup do not depend on the patches below.
        source = tmp_path / "test.txt"
        source.write_text("hello")
        destination = str(tmp_path / "out.md")

        expected_chunks = [Chunk(title="t", content="c")]
        stub_parser = MagicMock()
        stub_parser.parse.return_value = "text"

        with (
            patch("splitter.ApiClient"),
            patch("splitter.AIChunker") as chunker_cls,
            patch("splitter.MarkdownWriter") as writer_cls,
            patch("splitter.JsonWriter"),
            patch("splitter.ParserRegistry") as registry_cls,
            patch("splitter.TextParser"),
            patch("splitter.CsvParser"),
            patch("splitter.HtmlParser"),
            patch("splitter.PdfParser"),
            patch("splitter.DocParser"),
            patch("splitter.LegacyDocParser"),
            patch("splitter.XlsxParser"),
            patch("splitter.XlsParser"),
            patch("splitter.ImageParser"),
        ):
            subject = Splitter(api_key="key", delimiter="===")

            registry_cls.return_value.get_parser.return_value = stub_parser
            chunker_cls.return_value.chunk.return_value = expected_chunks

            subject.process(str(source), destination)

            writer_cls.return_value.write.assert_called_once_with(
                expected_chunks, destination, "test.txt", "==="
            )
|
||||
Reference in New Issue
Block a user