Files
bigwo/tests/test_chunker.py
2026-03-02 17:38:28 +08:00

318 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""AIChunker 单元测试"""
import pytest
from unittest.mock import MagicMock, call
from chunker import AIChunker
from exceptions import ApiError
from models import Chunk
def _make_chunker(api_response="标题1\n\n内容1", delimiter="---"):
    """Build an AIChunker wired to a MagicMock API client.

    Args:
        api_response: Value returned by the mock client's ``chat`` method.
        delimiter: Chunk delimiter passed through to ``AIChunker``.

    Returns:
        A ``(chunker, mock_api)`` tuple, so tests can inspect or reconfigure
        the mock after construction.
    """
    fake_client = MagicMock()
    fake_client.chat.return_value = api_response
    return AIChunker(api_client=fake_client, delimiter=delimiter), fake_client
class TestParseResponse:
    """Tests for AIChunker._parse_response()."""

    @staticmethod
    def _expect_empty_response_error(raw):
        """Assert that parsing *raw* is rejected as an empty API response."""
        chunker, _ = _make_chunker()
        with pytest.raises(ApiError, match="API 返回空响应"):
            chunker._parse_response(raw)

    def test_single_chunk(self):
        """A response with no delimiter parses into a single title/content chunk."""
        chunker, _ = _make_chunker()
        chunks = chunker._parse_response("摘要标题\n\n这是分块内容")
        assert [(c.title, c.content) for c in chunks] == [("摘要标题", "这是分块内容")]

    def test_multiple_chunks(self):
        """The delimiter separates the response into several chunks, in order."""
        chunker, _ = _make_chunker()
        raw = "标题一\n\n内容一\n\n---\n标题二\n\n内容二"
        chunks = chunker._parse_response(raw)
        assert [(c.title, c.content) for c in chunks] == [
            ("标题一", "内容一"),
            ("标题二", "内容二"),
        ]

    def test_skip_empty_parts(self):
        """Empty fragments between delimiters are discarded."""
        chunker, _ = _make_chunker()
        chunks = chunker._parse_response("标题\n\n内容\n\n---\n\n---\n")
        assert [c.title for c in chunks] == ["标题"]

    def test_title_only_no_content(self):
        """A fragment with only a title yields a chunk with empty content."""
        chunker, _ = _make_chunker()
        chunks = chunker._parse_response("仅标题")
        assert [(c.title, c.content) for c in chunks] == [("仅标题", "")]

    def test_empty_response_raises_error(self):
        """An empty response raises ApiError."""
        self._expect_empty_response_error("")

    def test_whitespace_only_response_raises_error(self):
        """A whitespace-only response is treated as empty and raises ApiError."""
        self._expect_empty_response_error(" \n\n ")

    def test_custom_delimiter(self):
        """A non-default delimiter is honoured when splitting."""
        chunker, _ = _make_chunker(delimiter="===")
        raw = "标题A\n\n内容A\n\n===\n标题B\n\n内容B"
        chunks = chunker._parse_response(raw)
        assert [c.title for c in chunks] == ["标题A", "标题B"]

    def test_strips_whitespace_from_parts(self):
        """Leading/trailing whitespace around a fragment is stripped."""
        chunker, _ = _make_chunker()
        chunks = chunker._parse_response(" \n标题\n\n内容\n ")
        assert [(c.title, c.content) for c in chunks] == [("标题", "内容")]
class TestPreSplit:
    """Tests for AIChunker._pre_split()."""

    def test_short_text_single_segment(self):
        """Text below PRE_SPLIT_SIZE comes back unchanged as a single segment."""
        chunker, _ = _make_chunker()
        text = "短文本内容"
        result = chunker._pre_split(text)
        assert len(result) == 1
        assert result[0] == text

    def test_split_on_paragraph_boundary(self):
        """Text is split at a paragraph boundary (a blank line)."""
        chunker, _ = _make_chunker()
        para1 = "a" * 7000
        para2 = "b" * 7000
        text = f"{para1}\n\n{para2}"
        result = chunker._pre_split(text)
        # The two paragraphs joined (14002 chars) exceed the limit, so each is
        # its own segment and the separating blank line is dropped.
        assert len(result) == 2
        assert result[0] == para1
        assert result[1] == para2

    def test_greedy_merge_paragraphs(self):
        """Adjacent paragraphs are merged greedily while they fit in PRE_SPLIT_SIZE."""
        chunker, _ = _make_chunker()
        para1 = "a" * 4000
        para2 = "b" * 4000
        para3 = "c" * 5000
        text = f"{para1}\n\n{para2}\n\n{para3}"
        result = chunker._pre_split(text)
        # para1 + \n\n + para2 = 8002 <= 12000, so they merge
        # adding para3 would be 8002 + 2 + 5000 = 13004 > 12000
        assert len(result) == 2
        assert result[0] == f"{para1}\n\n{para2}"
        assert result[1] == para3

    def test_single_paragraph_exceeds_limit_split_by_newline(self):
        """An over-long paragraph without blank lines is split on single newlines."""
        chunker, _ = _make_chunker()
        line1 = "x" * 7000
        line2 = "y" * 7000
        # One paragraph (no blank line) containing a single newline.
        text = f"{line1}\n{line2}"
        result = chunker._pre_split(text)
        assert len(result) == 2
        assert result[0] == line1
        assert result[1] == line2

    def test_hard_split_very_long_line(self):
        """A single line longer than PRE_SPLIT_SIZE is hard-split without losing content."""
        chunker, _ = _make_chunker()
        # Longer than PRE_SPLIT_SIZE with no paragraph or line separators to use.
        text = "a" * 30000
        result = chunker._pre_split(text)
        assert len(result) >= 2
        for seg in result:
            assert len(seg) <= chunker.PRE_SPLIT_SIZE
        # Hard splits add no separators, so plain concatenation restores the input.
        assert "".join(result) == text

    def test_no_content_loss(self):
        """Rejoining segments on the paragraph separator reproduces the input exactly."""
        chunker, _ = _make_chunker()
        para1 = "a" * 5000
        para2 = "b" * 5000
        para3 = "c" * 5000
        text = f"{para1}\n\n{para2}\n\n{para3}"
        result = chunker._pre_split(text)
        # Segments are cut only at "\n\n" boundaries (para1+para2 merge to 10002
        # chars, para3 stands alone), so joining on "\n\n" must give back the
        # original text exactly. A substring check would also pass if content
        # were duplicated or reordered.
        assert "\n\n".join(result) == text

    def test_each_segment_within_limit(self):
        """No segment exceeds PRE_SPLIT_SIZE, and no content is lost."""
        chunker, _ = _make_chunker()
        paragraphs = ["p" * 5000 for _ in range(10)]
        text = "\n\n".join(paragraphs)
        result = chunker._pre_split(text)
        for seg in result:
            assert len(seg) <= chunker.PRE_SPLIT_SIZE
        # Only paragraph boundaries are used here, so the split must be lossless.
        assert "\n\n".join(result) == text
class TestCallApi:
    """Tests for AIChunker._call_api()."""

    def test_calls_api_with_correct_prompts(self):
        """The system prompt mentions the delimiter; the user prompt carries the text."""
        chunker, client = _make_chunker(api_response="标题\n\n内容")
        chunker._call_api("测试文本")
        client.chat.assert_called_once()
        positional, keyword = client.chat.call_args
        # chat() may be invoked positionally or by keyword; accept either form.
        if positional:
            system_prompt = positional[0]
        else:
            system_prompt = keyword.get("system_prompt")
        if len(positional) > 1:
            user_content = positional[1]
        else:
            user_content = keyword.get("user_content")
        assert "---" in system_prompt
        assert "测试文本" in user_content

    def test_returns_parsed_chunks(self):
        """The raw API reply is parsed into a list of Chunk objects."""
        chunker, _ = _make_chunker(api_response="标题\n\n内容")
        chunks = chunker._call_api("文本")
        assert len(chunks) == 1
        first = chunks[0]
        assert isinstance(first, Chunk)
        assert first.title == "标题"

    def test_api_error_propagates(self):
        """An ApiError raised by the client is not swallowed."""
        chunker, client = _make_chunker()
        client.chat.side_effect = ApiError("调用失败")
        with pytest.raises(ApiError, match="调用失败"):
            chunker._call_api("文本")
class TestChunk:
    """Tests for AIChunker.chunk()."""

    def test_short_text_single_api_call(self):
        """Text no longer than PRE_SPLIT_SIZE is sent to the API in one call."""
        chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
        text = "短文本" * 100  # well under 12000
        result = chunker.chunk(text)
        assert len(result) == 1
        assert mock_api.chat.call_count == 1

    def test_long_text_multiple_api_calls(self):
        """Text longer than PRE_SPLIT_SIZE is pre-split and sent in several API calls."""
        chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
        # Five 8000-char paragraphs: far more than PRE_SPLIT_SIZE in total.
        text = ("段落内容" * 2000 + "\n\n") * 5
        result = chunker.chunk(text)
        assert mock_api.chat.call_count > 1
        assert len(result) >= 1

    def test_progress_callback_called(self):
        """on_progress counts 1..total, once per pre-split segment, and finishes at total."""
        chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
        # Three 7000-char paragraphs: any two joined (14002 chars) exceed
        # PRE_SPLIT_SIZE, so pre-splitting yields exactly three segments.
        para1 = "a" * 7000
        para2 = "b" * 7000
        para3 = "c" * 7000
        text = f"{para1}\n\n{para2}\n\n{para3}"
        progress_calls = []

        def on_progress(current, total):
            progress_calls.append((current, total))

        chunker.chunk(text, on_progress=on_progress)

        # One API call per segment.
        assert mock_api.chat.call_count == 3
        # Progress must count up without gaps, keep a constant total, and reach
        # it. The earlier per-item loop also passed for a single (1, 1) call or
        # for a run that stopped early.
        assert progress_calls == [(1, 3), (2, 3), (3, 3)]

    def test_no_progress_callback_no_error(self):
        """Omitting on_progress works."""
        chunker, _ = _make_chunker(api_response="标题\n\n内容")
        text = "短文本"
        result = chunker.chunk(text)
        assert len(result) == 1

    def test_short_text_no_progress_callback(self):
        """Short (single-segment) text never invokes on_progress."""
        chunker, _ = _make_chunker(api_response="标题\n\n内容")
        progress_calls = []
        chunker.chunk("短文本", on_progress=lambda c, t: progress_calls.append((c, t)))
        assert len(progress_calls) == 0

    def test_chunks_aggregated_from_segments(self):
        """Chunks from each segment's API call are concatenated in segment order."""
        mock_api = MagicMock()
        # Each API call returns a different response.
        mock_api.chat.side_effect = [
            "标题A\n\n内容A",
            "标题B\n\n内容B",
        ]
        chunker = AIChunker(api_client=mock_api, delimiter="---")
        # Two 7000-char paragraphs: merged they would exceed 12000, so two segments.
        text = "a" * 7000 + "\n\n" + "b" * 7000
        result = chunker.chunk(text)
        assert len(result) == 2
        assert result[0].title == "标题A"
        assert result[1].title == "标题B"
class TestHardSplit:
    """Tests for _hard_split() and _find_sentence_boundary()."""

    def test_hard_split_preserves_content(self):
        """Hard splitting loses no characters."""
        chunker, _ = _make_chunker()
        text = "x" * 30000
        result = chunker._hard_split(text)
        assert "".join(result) == text

    def test_hard_split_respects_limit(self):
        """Every hard-split segment fits within PRE_SPLIT_SIZE."""
        chunker, _ = _make_chunker()
        text = "x" * 30000
        result = chunker._hard_split(text)
        for seg in result:
            assert len(seg) <= chunker.PRE_SPLIT_SIZE

    def test_sentence_boundary_split(self):
        """Hard splitting prefers to cut right after a sentence terminator."""
        chunker, _ = _make_chunker()
        # 14001 chars (> 12000) with exactly one Chinese full stop, at index 9000.
        # The separator must be a real boundary character. An empty separator
        # leaves no boundary at all, and endswith("") is always true, so the
        # test could never fail.
        text = "a" * 9000 + "。" + "b" * 5000
        result = chunker._hard_split(text)
        assert len(result) == 2
        # The first segment must end on the full stop...
        assert result[0].endswith("。")
        # ...and together with the remainder reproduce the input exactly.
        assert "".join(result) == text

    def test_find_sentence_boundary_chinese(self):
        """A Chinese full stop counts as a sentence boundary."""
        boundary = AIChunker._find_sentence_boundary("你好世界。再见")
        assert boundary == 5  # length of "你好世界。"

    def test_find_sentence_boundary_english(self):
        """An ASCII period counts as a sentence boundary."""
        boundary = AIChunker._find_sentence_boundary("Hello world. Bye")
        assert boundary == 12  # index of '.' is 11, return 11+1=12