# Listing metadata from the source viewer: 318 lines, 11 KiB, Python.
"""AIChunker 单元测试"""
|
||
|
||
import pytest
|
||
from unittest.mock import MagicMock, call
|
||
|
||
from chunker import AIChunker
|
||
from exceptions import ApiError
|
||
from models import Chunk
|
||
|
||
|
||
def _make_chunker(api_response="标题1\n\n内容1", delimiter="---"):
    """Build an AIChunker whose API client is a MagicMock.

    Args:
        api_response: Value the mocked client's ``chat`` method returns.
        delimiter: Delimiter forwarded to the AIChunker constructor.

    Returns:
        A ``(chunker, mock_api)`` tuple so tests can both drive the chunker
        and inspect the calls made on the mocked client.
    """
    api_stub = MagicMock()
    api_stub.chat.return_value = api_response
    return AIChunker(api_client=api_stub, delimiter=delimiter), api_stub
|
||
|
||
|
||
class TestParseResponse:
    """Tests for AIChunker._parse_response()."""

    def test_single_chunk(self):
        """A response with one title/content pair parses to one chunk."""
        parser, _ = _make_chunker()
        chunks = parser._parse_response("摘要标题\n\n这是分块内容")
        parsed = [(item.title, item.content) for item in chunks]
        assert parsed == [("摘要标题", "这是分块内容")]

    def test_multiple_chunks(self):
        """Chunks separated by the delimiter are parsed individually."""
        parser, _ = _make_chunker()
        raw = "标题一\n\n内容一\n\n---\n标题二\n\n内容二"
        chunks = parser._parse_response(raw)
        parsed = [(item.title, item.content) for item in chunks]
        assert parsed == [("标题一", "内容一"), ("标题二", "内容二")]

    def test_skip_empty_parts(self):
        """Empty fragments between delimiters produce no chunks."""
        parser, _ = _make_chunker()
        raw = "标题\n\n内容\n\n---\n\n---\n"
        chunks = parser._parse_response(raw)
        assert [item.title for item in chunks] == ["标题"]

    def test_title_only_no_content(self):
        """A fragment holding only a title gets an empty content string."""
        parser, _ = _make_chunker()
        chunks = parser._parse_response("仅标题")
        parsed = [(item.title, item.content) for item in chunks]
        assert parsed == [("仅标题", "")]

    def test_empty_response_raises_error(self):
        """An empty response raises ApiError."""
        parser, _ = _make_chunker()
        with pytest.raises(ApiError, match="API 返回空响应"):
            parser._parse_response("")

    def test_whitespace_only_response_raises_error(self):
        """A whitespace-only response raises ApiError."""
        parser, _ = _make_chunker()
        with pytest.raises(ApiError, match="API 返回空响应"):
            parser._parse_response(" \n\n ")

    def test_custom_delimiter(self):
        """Parsing honours a custom delimiter given at construction."""
        parser, _ = _make_chunker(delimiter="===")
        raw = "标题A\n\n内容A\n\n===\n标题B\n\n内容B"
        chunks = parser._parse_response(raw)
        assert [item.title for item in chunks] == ["标题A", "标题B"]

    def test_strips_whitespace_from_parts(self):
        """Leading and trailing whitespace around a fragment is removed."""
        parser, _ = _make_chunker()
        raw = " \n标题\n\n内容\n "
        chunks = parser._parse_response(raw)
        parsed = [(item.title, item.content) for item in chunks]
        assert parsed == [("标题", "内容")]
|
||
|
||
|
||
class TestPreSplit:
    """Tests for AIChunker._pre_split()."""

    def test_short_text_single_segment(self):
        """Short text is returned unchanged as a single segment."""
        splitter, _ = _make_chunker()
        source = "短文本内容"
        segments = splitter._pre_split(source)
        assert list(segments) == [source]

    def test_split_on_paragraph_boundary(self):
        """Text is split at a paragraph boundary (blank line)."""
        splitter, _ = _make_chunker()
        first = "a" * 7000
        second = "b" * 7000
        segments = splitter._pre_split(f"{first}\n\n{second}")
        assert list(segments) == [first, second]

    def test_greedy_merge_paragraphs(self):
        """Paragraphs merge greedily while staying within PRE_SPLIT_SIZE."""
        splitter, _ = _make_chunker()
        first = "a" * 4000
        second = "b" * 4000
        third = "c" * 5000
        segments = splitter._pre_split(f"{first}\n\n{second}\n\n{third}")
        # first + "\n\n" + second is 8002 chars (<= 12000), so they merge;
        # appending third would give 8002 + 2 + 5000 = 13004 (> 12000).
        assert list(segments) == [f"{first}\n\n{second}", third]

    def test_single_paragraph_exceeds_limit_split_by_newline(self):
        """An oversized paragraph is split on single newlines."""
        splitter, _ = _make_chunker()
        upper = "x" * 7000
        lower = "y" * 7000
        # One paragraph: no blank line, only a single newline inside it.
        segments = splitter._pre_split(f"{upper}\n{lower}")
        assert list(segments) == [upper, lower]

    def test_hard_split_very_long_line(self):
        """A very long single line is hard-split."""
        splitter, _ = _make_chunker()
        # A single line longer than PRE_SPLIT_SIZE with no paragraph or
        # newline separators at all.
        source = "a" * 30000
        segments = splitter._pre_split(source)
        assert len(segments) >= 2
        assert all(len(piece) <= splitter.PRE_SPLIT_SIZE for piece in segments)
        # Concatenating the pieces must reproduce the input exactly.
        assert "".join(segments) == source

    def test_no_content_loss(self):
        """Every paragraph is still present after pre-splitting."""
        splitter, _ = _make_chunker()
        first = "a" * 5000
        second = "b" * 5000
        third = "c" * 5000
        segments = splitter._pre_split(f"{first}\n\n{second}\n\n{third}")
        rejoined = "\n\n".join(segments)
        assert all(para in rejoined for para in (first, second, third))

    def test_each_segment_within_limit(self):
        """No segment is longer than PRE_SPLIT_SIZE."""
        splitter, _ = _make_chunker()
        source = "\n\n".join(["p" * 5000] * 10)
        segments = splitter._pre_split(source)
        assert all(len(piece) <= splitter.PRE_SPLIT_SIZE for piece in segments)
|
||
|
||
|
||
class TestCallApi:
    """Tests for AIChunker._call_api()."""

    def test_calls_api_with_correct_prompts(self):
        """The API is called once with the expected prompts."""
        sut, api = _make_chunker(api_response="标题\n\n内容")
        sut._call_api("测试文本")

        api.chat.assert_called_once()
        # call_args is a (positional, keyword) pair; take each prompt from
        # the positional arguments when present, else from the keywords.
        positional, keyword = api.chat.call_args
        if positional:
            system_prompt = positional[0]
        else:
            system_prompt = keyword.get("system_prompt")
        if len(positional) > 1:
            user_content = positional[1]
        else:
            user_content = keyword.get("user_content")

        # The system prompt must mention the delimiter.
        assert "---" in system_prompt
        # The user prompt must carry the text being chunked.
        assert "测试文本" in user_content

    def test_returns_parsed_chunks(self):
        """The return value is a list of parsed Chunk objects."""
        sut, _ = _make_chunker(api_response="标题\n\n内容")
        chunks = sut._call_api("文本")
        assert len(chunks) == 1
        head = chunks[0]
        assert isinstance(head, Chunk)
        assert head.title == "标题"

    def test_api_error_propagates(self):
        """An ApiError raised by the client propagates to the caller."""
        sut, api = _make_chunker()
        api.chat.side_effect = ApiError("调用失败")
        with pytest.raises(ApiError, match="调用失败"):
            sut._call_api("文本")
|
||
|
||
|
||
class TestChunk:
    """Tests for AIChunker.chunk()."""

    def test_short_text_single_api_call(self):
        """Short text (<= PRE_SPLIT_SIZE) triggers exactly one API call."""
        sut, api = _make_chunker(api_response="标题\n\n内容")
        source = "短文本" * 100  # well under 12000 characters
        chunks = sut.chunk(source)

        assert len(chunks) == 1
        assert api.chat.call_count == 1

    def test_long_text_multiple_api_calls(self):
        """Long text (> PRE_SPLIT_SIZE) is pre-split into several calls."""
        sut, api = _make_chunker(api_response="标题\n\n内容")
        # Build text that exceeds PRE_SPLIT_SIZE.
        source = ("段落内容" * 2000 + "\n\n") * 5
        chunks = sut.chunk(source)

        assert api.chat.call_count > 1
        assert len(chunks) >= 1

    def test_progress_callback_called(self):
        """For long text, on_progress gets sequential steps, fixed total."""
        sut, _ = _make_chunker(api_response="标题\n\n内容")
        # Long enough that pre-splitting is required.
        source = "\n\n".join(["a" * 7000, "b" * 7000, "c" * 7000])

        seen = []

        def on_progress(current, total):
            seen.append((current, total))

        sut.chunk(source, on_progress=on_progress)

        # The callback must fire, count up from 1, and report one total.
        assert seen
        expected_total = seen[0][1]
        for step, (current, reported_total) in enumerate(seen, start=1):
            assert current == step
            assert reported_total == expected_total

    def test_no_progress_callback_no_error(self):
        """Omitting on_progress does not raise."""
        sut, _ = _make_chunker(api_response="标题\n\n内容")
        chunks = sut.chunk("短文本")
        assert len(chunks) == 1

    def test_short_text_no_progress_callback(self):
        """Short text never triggers the on_progress callback."""
        sut, _ = _make_chunker(api_response="标题\n\n内容")
        seen = []
        sut.chunk("短文本", on_progress=lambda c, t: seen.append((c, t)))
        assert not seen

    def test_chunks_aggregated_from_segments(self):
        """Chunks from all segments are aggregated in order."""
        api = MagicMock()
        # Each API call returns a different response.
        api.chat.side_effect = ["标题A\n\n内容A", "标题B\n\n内容B"]
        sut = AIChunker(api_client=api, delimiter="---")

        # Two paragraphs, each over 12000 / 2 characters, so merging them
        # would exceed the limit and two segments are needed.
        source = "a" * 7000 + "\n\n" + "b" * 7000
        chunks = sut.chunk(source)

        assert [item.title for item in chunks] == ["标题A", "标题B"]
|
||
|
||
|
||
class TestHardSplit:
    """Tests for _hard_split() and _find_sentence_boundary()."""

    def test_hard_split_preserves_content(self):
        """Hard-split pieces concatenate back to the original text."""
        sut, _ = _make_chunker()
        source = "x" * 30000
        pieces = sut._hard_split(source)
        assert "".join(pieces) == source

    def test_hard_split_respects_limit(self):
        """No hard-split piece is longer than PRE_SPLIT_SIZE."""
        sut, _ = _make_chunker()
        pieces = sut._hard_split("x" * 30000)
        assert all(len(piece) <= sut.PRE_SPLIT_SIZE for piece in pieces)

    def test_sentence_boundary_split(self):
        """Hard splitting prefers to cut at a sentence boundary."""
        sut, _ = _make_chunker()
        # Over 12000 characters in total, with a full stop in the middle.
        source = "a" * 9000 + "。" + "b" * 5000
        pieces = sut._hard_split(source)
        assert len(pieces) == 2
        # The first piece must end right after the full stop.
        assert pieces[0].endswith("。")

    def test_find_sentence_boundary_chinese(self):
        """A Chinese full stop counts as a sentence boundary."""
        position = AIChunker._find_sentence_boundary("你好世界。再见")
        assert position == 5  # length of "你好世界。"

    def test_find_sentence_boundary_english(self):
        """An English period counts as a sentence boundary."""
        position = AIChunker._find_sentence_boundary("Hello world. Bye")
        assert position == 12  # '.' is at index 11; the result is 11 + 1
|