Initial commit: AI 知识库文档智能分块工具

This commit is contained in:
AI Knowledge Splitter
2026-03-02 17:38:28 +08:00
commit 92e7fc5bda
160 changed files with 9577 additions and 0 deletions

317
tests/test_chunker.py Normal file
View File

@@ -0,0 +1,317 @@
"""AIChunker 单元测试"""
import pytest
from unittest.mock import MagicMock, call
from chunker import AIChunker
from exceptions import ApiError
from models import Chunk
def _make_chunker(api_response="标题1\n\n内容1", delimiter="---"):
"""创建注入 mock ApiClient 的 AIChunker"""
mock_api = MagicMock()
mock_api.chat.return_value = api_response
chunker = AIChunker(api_client=mock_api, delimiter=delimiter)
return chunker, mock_api
class TestParseResponse:
"""_parse_response() 方法测试"""
def test_single_chunk(self):
"""解析单个分块"""
chunker, _ = _make_chunker()
result = chunker._parse_response("摘要标题\n\n这是分块内容")
assert len(result) == 1
assert result[0].title == "摘要标题"
assert result[0].content == "这是分块内容"
def test_multiple_chunks(self):
"""解析多个分块(用 delimiter 分隔)"""
chunker, _ = _make_chunker()
response = "标题一\n\n内容一\n\n---\n标题二\n\n内容二"
result = chunker._parse_response(response)
assert len(result) == 2
assert result[0].title == "标题一"
assert result[0].content == "内容一"
assert result[1].title == "标题二"
assert result[1].content == "内容二"
def test_skip_empty_parts(self):
"""跳过空片段"""
chunker, _ = _make_chunker()
response = "标题\n\n内容\n\n---\n\n---\n"
result = chunker._parse_response(response)
assert len(result) == 1
assert result[0].title == "标题"
def test_title_only_no_content(self):
"""只有标题没有内容的分块"""
chunker, _ = _make_chunker()
result = chunker._parse_response("仅标题")
assert len(result) == 1
assert result[0].title == "仅标题"
assert result[0].content == ""
def test_empty_response_raises_error(self):
"""空响应抛出 ApiError"""
chunker, _ = _make_chunker()
with pytest.raises(ApiError, match="API 返回空响应"):
chunker._parse_response("")
def test_whitespace_only_response_raises_error(self):
"""纯空白响应抛出 ApiError"""
chunker, _ = _make_chunker()
with pytest.raises(ApiError, match="API 返回空响应"):
chunker._parse_response(" \n\n ")
def test_custom_delimiter(self):
"""使用自定义分隔符解析"""
chunker, _ = _make_chunker(delimiter="===")
response = "标题A\n\n内容A\n\n===\n标题B\n\n内容B"
result = chunker._parse_response(response)
assert len(result) == 2
assert result[0].title == "标题A"
assert result[1].title == "标题B"
def test_strips_whitespace_from_parts(self):
"""去除片段首尾空白"""
chunker, _ = _make_chunker()
response = " \n标题\n\n内容\n "
result = chunker._parse_response(response)
assert len(result) == 1
assert result[0].title == "标题"
assert result[0].content == "内容"
class TestPreSplit:
"""_pre_split() 方法测试"""
def test_short_text_single_segment(self):
"""短文本不需要切分,返回单段"""
chunker, _ = _make_chunker()
text = "短文本内容"
result = chunker._pre_split(text)
assert len(result) == 1
assert result[0] == text
def test_split_on_paragraph_boundary(self):
"""在段落边界(双换行)处切分"""
chunker, _ = _make_chunker()
para1 = "a" * 7000
para2 = "b" * 7000
text = f"{para1}\n\n{para2}"
result = chunker._pre_split(text)
assert len(result) == 2
assert result[0] == para1
assert result[1] == para2
def test_greedy_merge_paragraphs(self):
"""贪心合并段落,不超过 PRE_SPLIT_SIZE"""
chunker, _ = _make_chunker()
para1 = "a" * 4000
para2 = "b" * 4000
para3 = "c" * 5000
text = f"{para1}\n\n{para2}\n\n{para3}"
result = chunker._pre_split(text)
# para1 + \n\n + para2 = 8002 <= 12000, so they merge
# adding para3 would be 8002 + 2 + 5000 = 13004 > 12000
assert len(result) == 2
assert result[0] == f"{para1}\n\n{para2}"
assert result[1] == para3
def test_single_paragraph_exceeds_limit_split_by_newline(self):
"""单段落超限时按单换行符切分"""
chunker, _ = _make_chunker()
line1 = "x" * 7000
line2 = "y" * 7000
# 单个段落(无双换行),但有单换行
text = f"{line1}\n{line2}"
result = chunker._pre_split(text)
assert len(result) == 2
assert result[0] == line1
assert result[1] == line2
def test_hard_split_very_long_line(self):
"""超长单行硬切分"""
chunker, _ = _make_chunker()
# 一行超过 PRE_SPLIT_SIZE 且无段落/换行分隔
text = "a" * 30000
result = chunker._pre_split(text)
assert len(result) >= 2
for seg in result:
assert len(seg) <= chunker.PRE_SPLIT_SIZE
# 拼接后内容不丢失
assert "".join(result) == text
def test_no_content_loss(self):
"""预切分后拼接不丢失内容"""
chunker, _ = _make_chunker()
para1 = "a" * 5000
para2 = "b" * 5000
para3 = "c" * 5000
text = f"{para1}\n\n{para2}\n\n{para3}"
result = chunker._pre_split(text)
joined = "\n\n".join(result)
assert para1 in joined
assert para2 in joined
assert para3 in joined
def test_each_segment_within_limit(self):
"""每段不超过 PRE_SPLIT_SIZE"""
chunker, _ = _make_chunker()
paragraphs = ["p" * 5000 for _ in range(10)]
text = "\n\n".join(paragraphs)
result = chunker._pre_split(text)
for seg in result:
assert len(seg) <= chunker.PRE_SPLIT_SIZE
class TestCallApi:
"""_call_api() 方法测试"""
def test_calls_api_with_correct_prompts(self):
"""调用 API 时使用正确的提示词"""
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
chunker._call_api("测试文本")
mock_api.chat.assert_called_once()
args = mock_api.chat.call_args
system_prompt = args[0][0] if args[0] else args[1].get("system_prompt")
user_content = args[0][1] if len(args[0]) > 1 else args[1].get("user_content")
# 系统提示词应包含 delimiter
assert "---" in system_prompt
# 用户提示词应包含文本内容
assert "测试文本" in user_content
def test_returns_parsed_chunks(self):
"""返回解析后的 Chunk 列表"""
chunker, _ = _make_chunker(api_response="标题\n\n内容")
result = chunker._call_api("文本")
assert len(result) == 1
assert isinstance(result[0], Chunk)
assert result[0].title == "标题"
def test_api_error_propagates(self):
"""API 错误向上传播"""
chunker, mock_api = _make_chunker()
mock_api.chat.side_effect = ApiError("调用失败")
with pytest.raises(ApiError, match="调用失败"):
chunker._call_api("文本")
class TestChunk:
"""chunk() 方法测试"""
def test_short_text_single_api_call(self):
"""短文本(≤ PRE_SPLIT_SIZE只调用一次 API"""
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
text = "短文本" * 100 # well under 12000
result = chunker.chunk(text)
assert len(result) == 1
assert mock_api.chat.call_count == 1
def test_long_text_multiple_api_calls(self):
"""长文本(> PRE_SPLIT_SIZE预切分后多次调用 API"""
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
# 创建超过 PRE_SPLIT_SIZE 的文本
text = ("段落内容" * 2000 + "\n\n") * 5
result = chunker.chunk(text)
assert mock_api.chat.call_count > 1
assert len(result) >= 1
def test_progress_callback_called(self):
"""长文本时 on_progress 回调被正确调用"""
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
# 构造需要 pre_split 的长文本
para1 = "a" * 7000
para2 = "b" * 7000
para3 = "c" * 7000
text = f"{para1}\n\n{para2}\n\n{para3}"
progress_calls = []
def on_progress(current, total):
progress_calls.append((current, total))
chunker.chunk(text, on_progress=on_progress)
# 验证 progress 回调
assert len(progress_calls) > 0
total = progress_calls[0][1]
for i, (current, t) in enumerate(progress_calls):
assert current == i + 1
assert t == total
def test_no_progress_callback_no_error(self):
"""不传 on_progress 不报错"""
chunker, _ = _make_chunker(api_response="标题\n\n内容")
text = "短文本"
result = chunker.chunk(text)
assert len(result) == 1
def test_short_text_no_progress_callback(self):
"""短文本不触发 on_progress 回调"""
chunker, _ = _make_chunker(api_response="标题\n\n内容")
progress_calls = []
chunker.chunk("短文本", on_progress=lambda c, t: progress_calls.append((c, t)))
assert len(progress_calls) == 0
def test_chunks_aggregated_from_segments(self):
"""多段的 chunks 被正确聚合"""
mock_api = MagicMock()
# 每次 API 调用返回不同的响应
mock_api.chat.side_effect = [
"标题A\n\n内容A",
"标题B\n\n内容B",
]
chunker = AIChunker(api_client=mock_api, delimiter="---")
# 构造需要 2 段的文本(每段 > 12000/2 使得合并后超限)
text = "a" * 7000 + "\n\n" + "b" * 7000
result = chunker.chunk(text)
assert len(result) == 2
assert result[0].title == "标题A"
assert result[1].title == "标题B"
class TestHardSplit:
"""_hard_split() 和 _find_sentence_boundary() 测试"""
def test_hard_split_preserves_content(self):
"""硬切分不丢失内容"""
chunker, _ = _make_chunker()
text = "x" * 30000
result = chunker._hard_split(text)
assert "".join(result) == text
def test_hard_split_respects_limit(self):
"""硬切分每段不超过 PRE_SPLIT_SIZE"""
chunker, _ = _make_chunker()
text = "x" * 30000
result = chunker._hard_split(text)
for seg in result:
assert len(seg) <= chunker.PRE_SPLIT_SIZE
def test_sentence_boundary_split(self):
"""硬切分优先在句子边界切分"""
chunker, _ = _make_chunker()
# 构造在中间有句号的超长文本,总长超过 12000
text = "a" * 9000 + "" + "b" * 5000
result = chunker._hard_split(text)
assert len(result) == 2
# 第一段应在句号后切分
assert result[0].endswith("")
def test_find_sentence_boundary_chinese(self):
"""中文句号作为句子边界"""
boundary = AIChunker._find_sentence_boundary("你好世界。再见")
assert boundary == 5 # "你好世界。" 的长度
def test_find_sentence_boundary_english(self):
"""英文句号作为句子边界"""
boundary = AIChunker._find_sentence_boundary("Hello world. Bye")
assert boundary == 12 # index of '.' is 11, return 11+1=12