Initial commit: AI 知识库文档智能分块工具

This commit is contained in:
AI Knowledge Splitter
2026-03-02 17:38:28 +08:00
commit 92e7fc5bda
160 changed files with 9577 additions and 0 deletions

359
tests/test_splitter.py Normal file
View File

@@ -0,0 +1,359 @@
"""Splitter 协调器单元测试"""
import os
import pytest
from unittest.mock import MagicMock, patch, call
from exceptions import ApiError, ParseError, UnsupportedFormatError
from models import Chunk
from splitter import Splitter
@pytest.fixture
def mock_deps():
"""Patch all external dependencies and return their mocks."""
with (
patch("splitter.ApiClient") as mock_api_cls,
patch("splitter.AIChunker") as mock_chunker_cls,
patch("splitter.MarkdownWriter") as mock_writer_cls,
patch("splitter.JsonWriter"),
patch("splitter.ParserRegistry") as mock_registry_cls,
patch("splitter.TextParser"),
patch("splitter.CsvParser"),
patch("splitter.HtmlParser"),
patch("splitter.PdfParser"),
patch("splitter.DocParser"),
patch("splitter.LegacyDocParser"),
patch("splitter.XlsxParser"),
patch("splitter.XlsParser"),
patch("splitter.ImageParser"),
):
api_client = mock_api_cls.return_value
chunker = mock_chunker_cls.return_value
writer = mock_writer_cls.return_value
registry = mock_registry_cls.return_value
splitter = Splitter(api_key="test-key", delimiter="---")
yield {
"splitter": splitter,
"api_client": api_client,
"chunker": chunker,
"writer": writer,
"registry": registry,
}
class TestInit:
"""初始化测试"""
def test_registers_all_parsers(self):
"""验证所有解析器都被注册"""
with (
patch("splitter.ApiClient"),
patch("splitter.AIChunker"),
patch("splitter.MarkdownWriter"),
patch("splitter.JsonWriter"),
patch("splitter.ParserRegistry") as mock_registry_cls,
patch("splitter.TextParser"),
patch("splitter.CsvParser"),
patch("splitter.HtmlParser"),
patch("splitter.PdfParser"),
patch("splitter.DocParser"),
patch("splitter.LegacyDocParser"),
patch("splitter.XlsxParser"),
patch("splitter.XlsParser"),
patch("splitter.ImageParser"),
):
registry = mock_registry_cls.return_value
Splitter(api_key="test-key")
# 9 parsers: Text, Csv, Html, Pdf, Doc, LegacyDoc, Xlsx, Xls, Image
assert registry.register.call_count == 9
def test_creates_api_client_with_key(self):
"""验证 ApiClient 使用正确的 api_key 创建"""
with (
patch("splitter.ApiClient") as mock_api_cls,
patch("splitter.AIChunker"),
patch("splitter.MarkdownWriter"),
patch("splitter.JsonWriter"),
patch("splitter.ParserRegistry"),
patch("splitter.TextParser"),
patch("splitter.CsvParser"),
patch("splitter.HtmlParser"),
patch("splitter.PdfParser"),
patch("splitter.DocParser"),
patch("splitter.LegacyDocParser"),
patch("splitter.XlsxParser"),
patch("splitter.XlsParser"),
patch("splitter.ImageParser"),
):
Splitter(api_key="my-secret-key")
mock_api_cls.assert_called_once_with(api_key="my-secret-key")
def test_creates_chunker_with_delimiter(self):
"""验证 AIChunker 使用正确的 delimiter 创建"""
with (
patch("splitter.ApiClient") as mock_api_cls,
patch("splitter.AIChunker") as mock_chunker_cls,
patch("splitter.MarkdownWriter"),
patch("splitter.JsonWriter"),
patch("splitter.ParserRegistry"),
patch("splitter.TextParser"),
patch("splitter.CsvParser"),
patch("splitter.HtmlParser"),
patch("splitter.PdfParser"),
patch("splitter.DocParser"),
patch("splitter.LegacyDocParser"),
patch("splitter.XlsxParser"),
patch("splitter.XlsParser"),
patch("splitter.ImageParser"),
):
Splitter(api_key="key", delimiter="===")
mock_chunker_cls.assert_called_once_with(
mock_api_cls.return_value, "===", pre_split_size=None
)
class TestProcessSuccess:
"""成功处理流程测试"""
def test_full_flow(self, mock_deps, tmp_path, capsys):
"""验证完整的成功处理流程"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
writer = mock_deps["writer"]
# Setup
input_file = tmp_path / "test.txt"
input_file.write_text("hello")
output_file = str(tmp_path / "output.md")
mock_parser = MagicMock()
mock_parser.parse.return_value = "parsed text"
registry.get_parser.return_value = mock_parser
chunks = [Chunk(title="标题", content="内容")]
chunker.chunk.return_value = chunks
# Execute
splitter.process(str(input_file), output_file)
# Verify call chain
registry.get_parser.assert_called_once_with(str(input_file))
mock_parser.parse.assert_called_once_with(str(input_file))
chunker.chunk.assert_called_once()
assert chunker.chunk.call_args[0][0] == "parsed text"
writer.write.assert_called_once_with(
chunks, output_file, "test.txt", "---"
)
def test_logs_parsing_stage(self, mock_deps, tmp_path, capsys):
"""验证输出文件解析日志"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.pdf"
input_file.write_text("data")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
chunker.chunk.return_value = [Chunk(title="t", content="c")]
splitter.process(str(input_file), str(tmp_path / "out.md"))
output = capsys.readouterr().out
assert "解析文件: doc.pdf" in output
def test_logs_chunking_stage(self, mock_deps, tmp_path, capsys):
"""验证输出 AI 分块日志"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.txt"
input_file.write_text("data")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
chunker.chunk.return_value = [Chunk(title="t", content="c")]
splitter.process(str(input_file), str(tmp_path / "out.md"))
output = capsys.readouterr().out
assert "AI 语义分块" in output
def test_logs_writing_stage(self, mock_deps, tmp_path, capsys):
"""验证输出写入日志"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.txt"
input_file.write_text("data")
output_path = str(tmp_path / "out.md")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
chunker.chunk.return_value = [Chunk(title="t", content="c")]
splitter.process(str(input_file), output_path)
output = capsys.readouterr().out
assert "写入输出" in output
def test_logs_summary(self, mock_deps, tmp_path, capsys):
"""验证输出处理摘要"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.txt"
input_file.write_text("data")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
chunker.chunk.return_value = [
Chunk(title="t1", content="c1"),
Chunk(title="t2", content="c2"),
Chunk(title="t3", content="c3"),
]
splitter.process(str(input_file), str(tmp_path / "out.md"))
output = capsys.readouterr().out
assert "3 个分块" in output
def test_progress_callback_passed_to_chunker(self, mock_deps, tmp_path, capsys):
"""验证进度回调被传递给 chunker 并正确输出"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.txt"
input_file.write_text("data")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
# Simulate chunker calling the progress callback
def fake_chunk(text, content_type=None, source_file="", on_progress=None):
if on_progress:
on_progress(1, 3)
on_progress(2, 3)
on_progress(3, 3)
return [Chunk(title="t", content="c")]
chunker.chunk.side_effect = fake_chunk
splitter.process(str(input_file), str(tmp_path / "out.md"))
output = capsys.readouterr().out
assert "分块进度: 1/3" in output
assert "分块进度: 2/3" in output
assert "分块进度: 3/3" in output
class TestProcessErrors:
"""错误处理测试"""
def test_file_not_found(self, mock_deps):
"""验证文件不存在时抛出 FileNotFoundError"""
splitter = mock_deps["splitter"]
with pytest.raises(FileNotFoundError, match="输入文件不存在"):
splitter.process("/nonexistent/path/file.txt", "output.md")
def test_unsupported_format(self, mock_deps, tmp_path):
"""验证不支持的格式时抛出 UnsupportedFormatError"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
input_file = tmp_path / "file.xyz"
input_file.write_text("data")
registry.get_parser.side_effect = UnsupportedFormatError("file.xyz", ".xyz")
with pytest.raises(UnsupportedFormatError):
splitter.process(str(input_file), str(tmp_path / "out.md"))
def test_parse_error(self, mock_deps, tmp_path):
"""验证解析错误时抛出 ParseError"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
input_file = tmp_path / "bad.pdf"
input_file.write_bytes(b"\x00\x01\x02")
mock_parser = MagicMock()
mock_parser.parse.side_effect = ParseError("bad.pdf", "文件损坏")
registry.get_parser.return_value = mock_parser
with pytest.raises(ParseError, match="bad.pdf"):
splitter.process(str(input_file), str(tmp_path / "out.md"))
def test_api_error(self, mock_deps, tmp_path):
"""验证 API 错误时抛出 ApiError"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.txt"
input_file.write_text("data")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
chunker.chunk.side_effect = ApiError("API 调用失败")
with pytest.raises(ApiError, match="API 调用失败"):
splitter.process(str(input_file), str(tmp_path / "out.md"))
class TestCustomDelimiter:
"""自定义分隔符测试"""
def test_delimiter_passed_to_writer(self, tmp_path):
"""验证自定义分隔符传递给 writer"""
with (
patch("splitter.ApiClient"),
patch("splitter.AIChunker") as mock_chunker_cls,
patch("splitter.MarkdownWriter") as mock_writer_cls,
patch("splitter.JsonWriter"),
patch("splitter.ParserRegistry") as mock_registry_cls,
patch("splitter.TextParser"),
patch("splitter.CsvParser"),
patch("splitter.HtmlParser"),
patch("splitter.PdfParser"),
patch("splitter.DocParser"),
patch("splitter.LegacyDocParser"),
patch("splitter.XlsxParser"),
patch("splitter.XlsParser"),
patch("splitter.ImageParser"),
):
splitter = Splitter(api_key="key", delimiter="===")
input_file = tmp_path / "test.txt"
input_file.write_text("hello")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
mock_registry_cls.return_value.get_parser.return_value = mock_parser
chunks = [Chunk(title="t", content="c")]
mock_chunker_cls.return_value.chunk.return_value = chunks
writer = mock_writer_cls.return_value
splitter.process(str(input_file), str(tmp_path / "out.md"))
writer.write.assert_called_once_with(
chunks, str(tmp_path / "out.md"), "test.txt", "==="
)