Initial commit: AI 知识库文档智能分块工具
This commit is contained in:
317
tests/test_chunker.py
Normal file
317
tests/test_chunker.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""AIChunker 单元测试"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, call
|
||||
|
||||
from chunker import AIChunker
|
||||
from exceptions import ApiError
|
||||
from models import Chunk
|
||||
|
||||
|
||||
def _make_chunker(api_response="标题1\n\n内容1", delimiter="---"):
|
||||
"""创建注入 mock ApiClient 的 AIChunker"""
|
||||
mock_api = MagicMock()
|
||||
mock_api.chat.return_value = api_response
|
||||
chunker = AIChunker(api_client=mock_api, delimiter=delimiter)
|
||||
return chunker, mock_api
|
||||
|
||||
|
||||
class TestParseResponse:
|
||||
"""_parse_response() 方法测试"""
|
||||
|
||||
def test_single_chunk(self):
|
||||
"""解析单个分块"""
|
||||
chunker, _ = _make_chunker()
|
||||
result = chunker._parse_response("摘要标题\n\n这是分块内容")
|
||||
assert len(result) == 1
|
||||
assert result[0].title == "摘要标题"
|
||||
assert result[0].content == "这是分块内容"
|
||||
|
||||
def test_multiple_chunks(self):
|
||||
"""解析多个分块(用 delimiter 分隔)"""
|
||||
chunker, _ = _make_chunker()
|
||||
response = "标题一\n\n内容一\n\n---\n标题二\n\n内容二"
|
||||
result = chunker._parse_response(response)
|
||||
assert len(result) == 2
|
||||
assert result[0].title == "标题一"
|
||||
assert result[0].content == "内容一"
|
||||
assert result[1].title == "标题二"
|
||||
assert result[1].content == "内容二"
|
||||
|
||||
def test_skip_empty_parts(self):
|
||||
"""跳过空片段"""
|
||||
chunker, _ = _make_chunker()
|
||||
response = "标题\n\n内容\n\n---\n\n---\n"
|
||||
result = chunker._parse_response(response)
|
||||
assert len(result) == 1
|
||||
assert result[0].title == "标题"
|
||||
|
||||
def test_title_only_no_content(self):
|
||||
"""只有标题没有内容的分块"""
|
||||
chunker, _ = _make_chunker()
|
||||
result = chunker._parse_response("仅标题")
|
||||
assert len(result) == 1
|
||||
assert result[0].title == "仅标题"
|
||||
assert result[0].content == ""
|
||||
|
||||
def test_empty_response_raises_error(self):
|
||||
"""空响应抛出 ApiError"""
|
||||
chunker, _ = _make_chunker()
|
||||
with pytest.raises(ApiError, match="API 返回空响应"):
|
||||
chunker._parse_response("")
|
||||
|
||||
def test_whitespace_only_response_raises_error(self):
|
||||
"""纯空白响应抛出 ApiError"""
|
||||
chunker, _ = _make_chunker()
|
||||
with pytest.raises(ApiError, match="API 返回空响应"):
|
||||
chunker._parse_response(" \n\n ")
|
||||
|
||||
def test_custom_delimiter(self):
|
||||
"""使用自定义分隔符解析"""
|
||||
chunker, _ = _make_chunker(delimiter="===")
|
||||
response = "标题A\n\n内容A\n\n===\n标题B\n\n内容B"
|
||||
result = chunker._parse_response(response)
|
||||
assert len(result) == 2
|
||||
assert result[0].title == "标题A"
|
||||
assert result[1].title == "标题B"
|
||||
|
||||
def test_strips_whitespace_from_parts(self):
|
||||
"""去除片段首尾空白"""
|
||||
chunker, _ = _make_chunker()
|
||||
response = " \n标题\n\n内容\n "
|
||||
result = chunker._parse_response(response)
|
||||
assert len(result) == 1
|
||||
assert result[0].title == "标题"
|
||||
assert result[0].content == "内容"
|
||||
|
||||
|
||||
class TestPreSplit:
|
||||
"""_pre_split() 方法测试"""
|
||||
|
||||
def test_short_text_single_segment(self):
|
||||
"""短文本不需要切分,返回单段"""
|
||||
chunker, _ = _make_chunker()
|
||||
text = "短文本内容"
|
||||
result = chunker._pre_split(text)
|
||||
assert len(result) == 1
|
||||
assert result[0] == text
|
||||
|
||||
def test_split_on_paragraph_boundary(self):
|
||||
"""在段落边界(双换行)处切分"""
|
||||
chunker, _ = _make_chunker()
|
||||
para1 = "a" * 7000
|
||||
para2 = "b" * 7000
|
||||
text = f"{para1}\n\n{para2}"
|
||||
result = chunker._pre_split(text)
|
||||
assert len(result) == 2
|
||||
assert result[0] == para1
|
||||
assert result[1] == para2
|
||||
|
||||
def test_greedy_merge_paragraphs(self):
|
||||
"""贪心合并段落,不超过 PRE_SPLIT_SIZE"""
|
||||
chunker, _ = _make_chunker()
|
||||
para1 = "a" * 4000
|
||||
para2 = "b" * 4000
|
||||
para3 = "c" * 5000
|
||||
text = f"{para1}\n\n{para2}\n\n{para3}"
|
||||
result = chunker._pre_split(text)
|
||||
# para1 + \n\n + para2 = 8002 <= 12000, so they merge
|
||||
# adding para3 would be 8002 + 2 + 5000 = 13004 > 12000
|
||||
assert len(result) == 2
|
||||
assert result[0] == f"{para1}\n\n{para2}"
|
||||
assert result[1] == para3
|
||||
|
||||
def test_single_paragraph_exceeds_limit_split_by_newline(self):
|
||||
"""单段落超限时按单换行符切分"""
|
||||
chunker, _ = _make_chunker()
|
||||
line1 = "x" * 7000
|
||||
line2 = "y" * 7000
|
||||
# 单个段落(无双换行),但有单换行
|
||||
text = f"{line1}\n{line2}"
|
||||
result = chunker._pre_split(text)
|
||||
assert len(result) == 2
|
||||
assert result[0] == line1
|
||||
assert result[1] == line2
|
||||
|
||||
def test_hard_split_very_long_line(self):
|
||||
"""超长单行硬切分"""
|
||||
chunker, _ = _make_chunker()
|
||||
# 一行超过 PRE_SPLIT_SIZE 且无段落/换行分隔
|
||||
text = "a" * 30000
|
||||
result = chunker._pre_split(text)
|
||||
assert len(result) >= 2
|
||||
for seg in result:
|
||||
assert len(seg) <= chunker.PRE_SPLIT_SIZE
|
||||
# 拼接后内容不丢失
|
||||
assert "".join(result) == text
|
||||
|
||||
def test_no_content_loss(self):
|
||||
"""预切分后拼接不丢失内容"""
|
||||
chunker, _ = _make_chunker()
|
||||
para1 = "a" * 5000
|
||||
para2 = "b" * 5000
|
||||
para3 = "c" * 5000
|
||||
text = f"{para1}\n\n{para2}\n\n{para3}"
|
||||
result = chunker._pre_split(text)
|
||||
joined = "\n\n".join(result)
|
||||
assert para1 in joined
|
||||
assert para2 in joined
|
||||
assert para3 in joined
|
||||
|
||||
def test_each_segment_within_limit(self):
|
||||
"""每段不超过 PRE_SPLIT_SIZE"""
|
||||
chunker, _ = _make_chunker()
|
||||
paragraphs = ["p" * 5000 for _ in range(10)]
|
||||
text = "\n\n".join(paragraphs)
|
||||
result = chunker._pre_split(text)
|
||||
for seg in result:
|
||||
assert len(seg) <= chunker.PRE_SPLIT_SIZE
|
||||
|
||||
|
||||
class TestCallApi:
|
||||
"""_call_api() 方法测试"""
|
||||
|
||||
def test_calls_api_with_correct_prompts(self):
|
||||
"""调用 API 时使用正确的提示词"""
|
||||
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
|
||||
chunker._call_api("测试文本")
|
||||
|
||||
mock_api.chat.assert_called_once()
|
||||
args = mock_api.chat.call_args
|
||||
system_prompt = args[0][0] if args[0] else args[1].get("system_prompt")
|
||||
user_content = args[0][1] if len(args[0]) > 1 else args[1].get("user_content")
|
||||
# 系统提示词应包含 delimiter
|
||||
assert "---" in system_prompt
|
||||
# 用户提示词应包含文本内容
|
||||
assert "测试文本" in user_content
|
||||
|
||||
def test_returns_parsed_chunks(self):
|
||||
"""返回解析后的 Chunk 列表"""
|
||||
chunker, _ = _make_chunker(api_response="标题\n\n内容")
|
||||
result = chunker._call_api("文本")
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], Chunk)
|
||||
assert result[0].title == "标题"
|
||||
|
||||
def test_api_error_propagates(self):
|
||||
"""API 错误向上传播"""
|
||||
chunker, mock_api = _make_chunker()
|
||||
mock_api.chat.side_effect = ApiError("调用失败")
|
||||
with pytest.raises(ApiError, match="调用失败"):
|
||||
chunker._call_api("文本")
|
||||
|
||||
|
||||
class TestChunk:
|
||||
"""chunk() 方法测试"""
|
||||
|
||||
def test_short_text_single_api_call(self):
|
||||
"""短文本(≤ PRE_SPLIT_SIZE)只调用一次 API"""
|
||||
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
|
||||
text = "短文本" * 100 # well under 12000
|
||||
result = chunker.chunk(text)
|
||||
|
||||
assert len(result) == 1
|
||||
assert mock_api.chat.call_count == 1
|
||||
|
||||
def test_long_text_multiple_api_calls(self):
|
||||
"""长文本(> PRE_SPLIT_SIZE)预切分后多次调用 API"""
|
||||
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
|
||||
# 创建超过 PRE_SPLIT_SIZE 的文本
|
||||
text = ("段落内容" * 2000 + "\n\n") * 5
|
||||
result = chunker.chunk(text)
|
||||
|
||||
assert mock_api.chat.call_count > 1
|
||||
assert len(result) >= 1
|
||||
|
||||
def test_progress_callback_called(self):
|
||||
"""长文本时 on_progress 回调被正确调用"""
|
||||
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
|
||||
# 构造需要 pre_split 的长文本
|
||||
para1 = "a" * 7000
|
||||
para2 = "b" * 7000
|
||||
para3 = "c" * 7000
|
||||
text = f"{para1}\n\n{para2}\n\n{para3}"
|
||||
|
||||
progress_calls = []
|
||||
def on_progress(current, total):
|
||||
progress_calls.append((current, total))
|
||||
|
||||
chunker.chunk(text, on_progress=on_progress)
|
||||
|
||||
# 验证 progress 回调
|
||||
assert len(progress_calls) > 0
|
||||
total = progress_calls[0][1]
|
||||
for i, (current, t) in enumerate(progress_calls):
|
||||
assert current == i + 1
|
||||
assert t == total
|
||||
|
||||
def test_no_progress_callback_no_error(self):
|
||||
"""不传 on_progress 不报错"""
|
||||
chunker, _ = _make_chunker(api_response="标题\n\n内容")
|
||||
text = "短文本"
|
||||
result = chunker.chunk(text)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_short_text_no_progress_callback(self):
|
||||
"""短文本不触发 on_progress 回调"""
|
||||
chunker, _ = _make_chunker(api_response="标题\n\n内容")
|
||||
progress_calls = []
|
||||
chunker.chunk("短文本", on_progress=lambda c, t: progress_calls.append((c, t)))
|
||||
assert len(progress_calls) == 0
|
||||
|
||||
def test_chunks_aggregated_from_segments(self):
|
||||
"""多段的 chunks 被正确聚合"""
|
||||
mock_api = MagicMock()
|
||||
# 每次 API 调用返回不同的响应
|
||||
mock_api.chat.side_effect = [
|
||||
"标题A\n\n内容A",
|
||||
"标题B\n\n内容B",
|
||||
]
|
||||
chunker = AIChunker(api_client=mock_api, delimiter="---")
|
||||
|
||||
# 构造需要 2 段的文本(每段 > 12000/2 使得合并后超限)
|
||||
text = "a" * 7000 + "\n\n" + "b" * 7000
|
||||
result = chunker.chunk(text)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].title == "标题A"
|
||||
assert result[1].title == "标题B"
|
||||
|
||||
|
||||
class TestHardSplit:
|
||||
"""_hard_split() 和 _find_sentence_boundary() 测试"""
|
||||
|
||||
def test_hard_split_preserves_content(self):
|
||||
"""硬切分不丢失内容"""
|
||||
chunker, _ = _make_chunker()
|
||||
text = "x" * 30000
|
||||
result = chunker._hard_split(text)
|
||||
assert "".join(result) == text
|
||||
|
||||
def test_hard_split_respects_limit(self):
|
||||
"""硬切分每段不超过 PRE_SPLIT_SIZE"""
|
||||
chunker, _ = _make_chunker()
|
||||
text = "x" * 30000
|
||||
result = chunker._hard_split(text)
|
||||
for seg in result:
|
||||
assert len(seg) <= chunker.PRE_SPLIT_SIZE
|
||||
|
||||
def test_sentence_boundary_split(self):
|
||||
"""硬切分优先在句子边界切分"""
|
||||
chunker, _ = _make_chunker()
|
||||
# 构造在中间有句号的超长文本,总长超过 12000
|
||||
text = "a" * 9000 + "。" + "b" * 5000
|
||||
result = chunker._hard_split(text)
|
||||
assert len(result) == 2
|
||||
# 第一段应在句号后切分
|
||||
assert result[0].endswith("。")
|
||||
|
||||
def test_find_sentence_boundary_chinese(self):
|
||||
"""中文句号作为句子边界"""
|
||||
boundary = AIChunker._find_sentence_boundary("你好世界。再见")
|
||||
assert boundary == 5 # "你好世界。" 的长度
|
||||
|
||||
def test_find_sentence_boundary_english(self):
|
||||
"""英文句号作为句子边界"""
|
||||
boundary = AIChunker._find_sentence_boundary("Hello world. Bye")
|
||||
assert boundary == 12 # index of '.' is 11, return 11+1=12
|
||||
Reference in New Issue
Block a user