"""Unit tests for AIChunker."""

from unittest.mock import MagicMock, call

import pytest

from chunker import AIChunker
from exceptions import ApiError
from models import Chunk


def _make_chunker(api_response="标题1\n\n内容1", delimiter="---"):
    """Build an AIChunker with a mock ApiClient injected.

    Returns a ``(chunker, mock_api)`` pair; ``mock_api.chat`` is preset to
    return *api_response*.
    """
    api_stub = MagicMock()
    api_stub.chat.return_value = api_response
    return AIChunker(api_client=api_stub, delimiter=delimiter), api_stub


class TestParseResponse:
    """Tests for the ``_parse_response()`` method."""

    def test_single_chunk(self):
        """Parse a response holding a single chunk."""
        chunker, _ = _make_chunker()

        chunks = chunker._parse_response("摘要标题\n\n这是分块内容")

        assert len(chunks) == 1
        only = chunks[0]
        assert only.title == "摘要标题"
        assert only.content == "这是分块内容"

    def test_multiple_chunks(self):
        """Parse several chunks separated by the delimiter."""
        chunker, _ = _make_chunker()
        raw = "标题一\n\n内容一\n\n---\n标题二\n\n内容二"

        chunks = chunker._parse_response(raw)

        assert len(chunks) == 2
        first, second = chunks
        assert first.title == "标题一"
        assert first.content == "内容一"
        assert second.title == "标题二"
        assert second.content == "内容二"

    def test_skip_empty_parts(self):
        """Empty fragments between delimiters are skipped."""
        chunker, _ = _make_chunker()
        raw = "标题\n\n内容\n\n---\n\n---\n"

        chunks = chunker._parse_response(raw)

        assert len(chunks) == 1
        assert chunks[0].title == "标题"

    def test_title_only_no_content(self):
        """A chunk that has a title but no content."""
        chunker, _ = _make_chunker()

        chunks = chunker._parse_response("仅标题")

        assert len(chunks) == 1
        only = chunks[0]
        assert only.title == "仅标题"
        assert only.content == ""

    def test_empty_response_raises_error(self):
        """An empty response raises ApiError."""
        chunker, _ = _make_chunker()

        with pytest.raises(ApiError, match="API 返回空响应"):
            chunker._parse_response("")

    def test_whitespace_only_response_raises_error(self):
        """A whitespace-only response raises ApiError."""
        chunker, _ = _make_chunker()

        with pytest.raises(ApiError, match="API 返回空响应"):
            chunker._parse_response(" \n\n ")

    def test_custom_delimiter(self):
        """Parse using a custom delimiter."""
        chunker, _ = _make_chunker(delimiter="===")
        raw = "标题A\n\n内容A\n\n===\n标题B\n\n内容B"

        chunks = chunker._parse_response(raw)

        assert len(chunks) == 2
        assert chunks[0].title == "标题A"
        assert chunks[1].title == "标题B"

    def test_strips_whitespace_from_parts(self):
        """Leading and trailing whitespace of a fragment is stripped."""
        chunker, _ = _make_chunker()
        raw = " \n标题\n\n内容\n "

        chunks = chunker._parse_response(raw)

        assert len(chunks) == 1
        only = chunks[0]
        assert only.title == "标题"
        assert only.content == "内容"


class TestPreSplit:
    """Tests for the ``_pre_split()`` method."""

    def test_short_text_single_segment(self):
        """Short text needs no splitting and comes back as one segment."""
        chunker, _ = _make_chunker()
        text = "短文本内容"

        segments = chunker._pre_split(text)

        assert len(segments) == 1
        assert segments[0] == text

    def test_split_on_paragraph_boundary(self):
        """Split at a paragraph boundary (double newline)."""
        chunker, _ = _make_chunker()
        head = "a" * 7000
        tail = "b" * 7000

        segments = chunker._pre_split(f"{head}\n\n{tail}")

        assert len(segments) == 2
        assert segments[0] == head
        assert segments[1] == tail

    def test_greedy_merge_paragraphs(self):
        """Paragraphs are merged greedily without exceeding PRE_SPLIT_SIZE."""
        chunker, _ = _make_chunker()
        para_a = "a" * 4000
        para_b = "b" * 4000
        para_c = "c" * 5000
        text = "\n\n".join([para_a, para_b, para_c])

        segments = chunker._pre_split(text)

        # para_a + \n\n + para_b = 8002 <= 12000, so they merge;
        # adding para_c would be 8002 + 2 + 5000 = 13004 > 12000.
        assert len(segments) == 2
        assert segments[0] == f"{para_a}\n\n{para_b}"
        assert segments[1] == para_c

    def test_single_paragraph_exceeds_limit_split_by_newline(self):
        """An over-limit single paragraph is split on single newlines."""
        chunker, _ = _make_chunker()
        upper = "x" * 7000
        lower = "y" * 7000
        # One paragraph (no double newline) that still contains a single newline.
        text = f"{upper}\n{lower}"

        segments = chunker._pre_split(text)

        assert len(segments) == 2
        assert segments[0] == upper
        assert segments[1] == lower

    def test_hard_split_very_long_line(self):
        """A very long single line is hard-split."""
        chunker, _ = _make_chunker()
        # One line longer than PRE_SPLIT_SIZE with no paragraph/newline breaks.
        text = "a" * 30000

        segments = chunker._pre_split(text)

        assert len(segments) >= 2
        for segment in segments:
            assert len(segment) <= chunker.PRE_SPLIT_SIZE
        # Concatenating the segments must reproduce the input.
        assert "".join(segments) == text

    def test_no_content_loss(self):
        """Joining the pre-split segments loses no paragraph."""
        chunker, _ = _make_chunker()
        para_a = "a" * 5000
        para_b = "b" * 5000
        para_c = "c" * 5000
        text = f"{para_a}\n\n{para_b}\n\n{para_c}"

        rejoined = "\n\n".join(chunker._pre_split(text))

        for paragraph in (para_a, para_b, para_c):
            assert paragraph in rejoined

    def test_each_segment_within_limit(self):
        """No segment exceeds PRE_SPLIT_SIZE."""
        chunker, _ = _make_chunker()
        text = "\n\n".join(["p" * 5000] * 10)

        segments = chunker._pre_split(text)

        for segment in segments:
            assert len(segment) <= chunker.PRE_SPLIT_SIZE


class TestCallApi:
    """Tests for the ``_call_api()`` method."""

    def test_calls_api_with_correct_prompts(self):
        """The API is called with the expected prompts."""
        chunker, api_stub = _make_chunker(api_response="标题\n\n内容")

        chunker._call_api("测试文本")

        api_stub.chat.assert_called_once()
        # call_args is a (positional args, keyword args) pair.
        positional, keyword = api_stub.chat.call_args
        if positional:
            system_prompt = positional[0]
        else:
            system_prompt = keyword.get("system_prompt")
        if len(positional) > 1:
            user_content = positional[1]
        else:
            user_content = keyword.get("user_content")

        # The system prompt should contain the delimiter.
        assert "---" in system_prompt
        # The user prompt should contain the text being chunked.
        assert "测试文本" in user_content

    def test_returns_parsed_chunks(self):
        """The parsed list of Chunk objects is returned."""
        chunker, _ = _make_chunker(api_response="标题\n\n内容")

        chunks = chunker._call_api("文本")

        assert len(chunks) == 1
        only = chunks[0]
        assert isinstance(only, Chunk)
        assert only.title == "标题"

    def test_api_error_propagates(self):
        """An ApiError raised by the client propagates to the caller."""
        chunker, api_stub = _make_chunker()
        api_stub.chat.side_effect = ApiError("调用失败")

        with pytest.raises(ApiError, match="调用失败"):
            chunker._call_api("文本")


class TestChunk:
    """Tests for the ``chunk()`` method."""

    def test_short_text_single_api_call(self):
        """Short text (<= PRE_SPLIT_SIZE) triggers exactly one API call."""
        chunker, api_stub = _make_chunker(api_response="标题\n\n内容")
        text = "短文本" * 100  # well under 12000

        chunks = chunker.chunk(text)

        assert len(chunks) == 1
        assert api_stub.chat.call_count == 1

    def test_long_text_multiple_api_calls(self):
        """Long text (> PRE_SPLIT_SIZE) is pre-split and calls the API repeatedly."""
        chunker, api_stub = _make_chunker(api_response="标题\n\n内容")
        # Build a text longer than PRE_SPLIT_SIZE.
        paragraph = "段落内容" * 2000 + "\n\n"
        text = paragraph * 5

        chunks = chunker.chunk(text)

        assert api_stub.chat.call_count > 1
        assert len(chunks) >= 1

    def test_progress_callback_called(self):
        """For long text the on_progress callback is invoked correctly."""
        chunker, _ = _make_chunker(api_response="标题\n\n内容")
        # Build a long text that requires pre-splitting.
        text = "\n\n".join(letter * 7000 for letter in "abc")
        seen = []

        chunker.chunk(
            text,
            on_progress=lambda current, total: seen.append((current, total)),
        )

        # Verify the progress callbacks: consecutive steps, constant total.
        assert len(seen) > 0
        expected_total = seen[0][1]
        for step, (current, reported_total) in enumerate(seen, start=1):
            assert current == step
            assert reported_total == expected_total

    def test_no_progress_callback_no_error(self):
        """Omitting on_progress does not raise."""
        chunker, _ = _make_chunker(api_response="标题\n\n内容")

        chunks = chunker.chunk("短文本")

        assert len(chunks) == 1

    def test_short_text_no_progress_callback(self):
        """Short text does not trigger the on_progress callback."""
        chunker, _ = _make_chunker(api_response="标题\n\n内容")
        seen = []

        def record(current, total):
            seen.append((current, total))

        chunker.chunk("短文本", on_progress=record)

        assert len(seen) == 0

    def test_chunks_aggregated_from_segments(self):
        """Chunks from several segments are aggregated in order."""
        api_stub = MagicMock()
        # Each API call returns a different response.
        api_stub.chat.side_effect = [
            "标题A\n\n内容A",
            "标题B\n\n内容B",
        ]
        chunker = AIChunker(api_client=api_stub, delimiter="---")
        # Text needing 2 segments (each > 12000/2, so merging would exceed the limit).
        text = "a" * 7000 + "\n\n" + "b" * 7000

        chunks = chunker.chunk(text)

        assert len(chunks) == 2
        assert chunks[0].title == "标题A"
        assert chunks[1].title == "标题B"


class TestHardSplit:
    """Tests for ``_hard_split()`` and ``_find_sentence_boundary()``."""

    def test_hard_split_preserves_content(self):
        """Hard splitting loses no content."""
        chunker, _ = _make_chunker()
        text = "x" * 30000

        pieces = chunker._hard_split(text)

        assert "".join(pieces) == text

    def test_hard_split_respects_limit(self):
        """Every hard-split piece stays within PRE_SPLIT_SIZE."""
        chunker, _ = _make_chunker()

        pieces = chunker._hard_split("x" * 30000)

        for piece in pieces:
            assert len(piece) <= chunker.PRE_SPLIT_SIZE

    def test_sentence_boundary_split(self):
        """Hard splitting prefers to cut at a sentence boundary."""
        chunker, _ = _make_chunker()
        # Over-long text (total > 12000) with a full stop in the middle.
        text = "a" * 9000 + "。" + "b" * 5000

        pieces = chunker._hard_split(text)

        assert len(pieces) == 2
        # The first piece should be cut right after the full stop.
        assert pieces[0].endswith("。")

    def test_find_sentence_boundary_chinese(self):
        """A Chinese full stop counts as a sentence boundary."""
        position = AIChunker._find_sentence_boundary("你好世界。再见")

        assert position == 5  # length of "你好世界。"

    def test_find_sentence_boundary_english(self):
        """An English period counts as a sentence boundary."""
        position = AIChunker._find_sentence_boundary("Hello world. Bye")

        assert position == 12  # index of '.' is 11, return 11+1=12