Initial commit: AI 知识库文档智能分块工具
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# tests package
|
||||
251
tests/test_api_client.py
Normal file
251
tests/test_api_client.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""ApiClient 单元测试"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
import openai
|
||||
|
||||
from api_client import ApiClient
|
||||
from exceptions import ApiError
|
||||
|
||||
|
||||
def _make_completion_response(content: str):
|
||||
"""构造模拟的 ChatCompletion 响应"""
|
||||
message = MagicMock()
|
||||
message.content = content
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
return response
|
||||
|
||||
|
||||
def _make_rate_limit_error():
    """Build an openai.RateLimitError carrying a mock response with status 429."""
    return openai.RateLimitError(
        message="Rate limit exceeded",
        response=MagicMock(status_code=429),
        body=None,
    )
|
||||
|
||||
|
||||
def _make_api_error(status_code=500, message="Internal server error"):
    """Build a non-rate-limit openai.APIStatusError with the given status code."""
    fake_response = MagicMock(status_code=status_code)
    return openai.APIStatusError(
        message=message,
        response=fake_response,
        body=None,
    )
|
||||
|
||||
|
||||
def _make_client(*, sleep_fn=None):
    """Create an ApiClient wired to a mock OpenAI client.

    Args:
        sleep_fn: optional sleep stub injected as the client's ``_sleep``;
            a fresh MagicMock is used when omitted.

    Returns:
        Tuple of ``(client, mock_openai, sleep_fn)``.
    """
    # Explicit keyword-only parameter instead of **kwargs: the previous
    # ``kwargs.get("sleep_fn", ...)`` silently ignored any unknown or
    # typoed keyword argument.
    mock_openai = MagicMock()
    if sleep_fn is None:
        sleep_fn = MagicMock()
    return ApiClient(
        api_key="test-key",
        _client=mock_openai,
        _sleep=sleep_fn,
    ), mock_openai, sleep_fn
|
||||
|
||||
|
||||
class TestApiClientChat:
    """Tests for the chat() method."""

    def test_successful_chat(self):
        """A successful chat call returns the message content."""
        client, mock_openai, sleep_fn = _make_client()

        expected = "这是 AI 的回复"
        mock_openai.chat.completions.create.return_value = _make_completion_response(expected)

        result = client.chat("你是助手", "你好")

        assert result == expected
        # Default model is deepseek-chat; messages follow system/user order.
        mock_openai.chat.completions.create.assert_called_once_with(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "你是助手"},
                {"role": "user", "content": "你好"},
            ],
        )
        sleep_fn.assert_not_called()

    def test_chat_custom_model(self):
        """chat supports overriding the model name."""
        client, mock_openai, _ = _make_client()
        mock_openai.chat.completions.create.return_value = _make_completion_response("ok")

        client.chat("sys", "user", model="deepseek-reasoner")

        mock_openai.chat.completions.create.assert_called_once_with(
            model="deepseek-reasoner",
            messages=[
                {"role": "system", "content": "sys"},
                {"role": "user", "content": "user"},
            ],
        )

    def test_chat_retry_on_429_then_success(self):
        """chat retries after 429 errors and eventually succeeds."""
        client, mock_openai, sleep_fn = _make_client()

        # Two rate-limit failures, then a success.
        mock_openai.chat.completions.create.side_effect = [
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_completion_response("成功"),
        ]

        result = client.chat("sys", "user")

        assert result == "成功"
        assert sleep_fn.call_count == 2
        # Exponential back-off: 1s then 2s.
        sleep_fn.assert_any_call(1)
        sleep_fn.assert_any_call(2)

    def test_chat_retry_exhausted_raises_api_error(self):
        """chat raises ApiError once retries are exhausted."""
        client, mock_openai, sleep_fn = _make_client()

        # One more failure than MAX_RETRIES allows.
        mock_openai.chat.completions.create.side_effect = [
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
        ]

        with pytest.raises(ApiError, match="速率限制重试耗尽") as exc_info:
            client.chat("sys", "user")

        assert exc_info.value.status_code == 429
        assert sleep_fn.call_count == 3
        sleep_fn.assert_any_call(1)
        sleep_fn.assert_any_call(2)
        sleep_fn.assert_any_call(4)

    def test_chat_non_429_error_raises_immediately(self):
        """A non-429 error makes chat raise ApiError immediately, without retrying."""
        client, mock_openai, sleep_fn = _make_client()

        mock_openai.chat.completions.create.side_effect = _make_api_error(500)

        with pytest.raises(ApiError) as exc_info:
            client.chat("sys", "user")

        assert exc_info.value.status_code == 500
        sleep_fn.assert_not_called()
|
||||
|
||||
|
||||
class TestApiClientVision:
    """Tests for the vision() method."""

    def test_successful_vision(self):
        """A successful vision call returns the recognized content."""
        client, mock_openai, sleep_fn = _make_client()

        expected = "图片中包含一段文字"
        mock_openai.chat.completions.create.return_value = _make_completion_response(expected)

        result = client.vision("识别图片", "aW1hZ2VfZGF0YQ==")

        assert result == expected
        # The base64 payload is wrapped in a data-URL image_url message part.
        mock_openai.chat.completions.create.assert_called_once_with(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "识别图片"},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "data:image/png;base64,aW1hZ2VfZGF0YQ==",
                            },
                        },
                    ],
                },
            ],
        )
        sleep_fn.assert_not_called()

    def test_vision_retry_on_429_then_success(self):
        """vision retries after a 429 and then succeeds."""
        client, mock_openai, sleep_fn = _make_client()

        mock_openai.chat.completions.create.side_effect = [
            _make_rate_limit_error(),
            _make_completion_response("识别结果"),
        ]

        result = client.vision("sys", "base64data")

        assert result == "识别结果"
        assert sleep_fn.call_count == 1
        sleep_fn.assert_called_with(1)

    def test_vision_retry_exhausted_raises_api_error(self):
        """vision raises ApiError once retries are exhausted."""
        client, mock_openai, sleep_fn = _make_client()

        mock_openai.chat.completions.create.side_effect = [
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
        ]

        with pytest.raises(ApiError, match="速率限制重试耗尽"):
            client.vision("sys", "base64data")

        assert sleep_fn.call_count == 3

    def test_vision_non_429_error_raises_immediately(self):
        """A non-429 error makes vision raise immediately, without retrying."""
        client, mock_openai, sleep_fn = _make_client()

        mock_openai.chat.completions.create.side_effect = _make_api_error(401, "Unauthorized")

        with pytest.raises(ApiError) as exc_info:
            client.vision("sys", "base64data")

        assert exc_info.value.status_code == 401
        sleep_fn.assert_not_called()
|
||||
|
||||
|
||||
class TestRetryDelays:
    """Validation of the retry back-off schedule."""

    def test_retry_delays_are_exponential(self):
        """Retry delays are 1, 2, 4 seconds with at most 3 retries."""
        assert ApiClient.RETRY_DELAYS == [1, 2, 4]
        assert ApiClient.MAX_RETRIES == 3

    def test_single_retry_uses_correct_delay(self):
        """A single 429 is followed by a 1-second delay."""
        client, mock_openai, sleep_fn = _make_client()

        mock_openai.chat.completions.create.side_effect = [
            _make_rate_limit_error(),
            _make_completion_response("ok"),
        ]

        client.chat("sys", "user")

        sleep_fn.assert_called_once_with(1)

    def test_three_retries_use_correct_delays(self):
        """Three consecutive 429s use 1, 2, 4 second delays in order."""
        client, mock_openai, sleep_fn = _make_client()

        mock_openai.chat.completions.create.side_effect = [
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_completion_response("ok"),
        ]

        result = client.chat("sys", "user")

        assert result == "ok"
        assert sleep_fn.call_count == 3
        # Check call order, not just the set of delays.
        calls = [c.args[0] for c in sleep_fn.call_args_list]
        assert calls == [1, 2, 4]
|
||||
317
tests/test_chunker.py
Normal file
317
tests/test_chunker.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""AIChunker 单元测试"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, call
|
||||
|
||||
from chunker import AIChunker
|
||||
from exceptions import ApiError
|
||||
from models import Chunk
|
||||
|
||||
|
||||
def _make_chunker(api_response="标题1\n\n内容1", delimiter="---"):
    """Create an AIChunker backed by a mock ApiClient.

    Returns a ``(chunker, mock_api)`` pair; ``mock_api.chat`` always
    yields *api_response*.
    """
    stub_api = MagicMock()
    stub_api.chat.return_value = api_response
    return AIChunker(api_client=stub_api, delimiter=delimiter), stub_api
|
||||
|
||||
|
||||
class TestParseResponse:
    """Tests for the _parse_response() method."""

    def test_single_chunk(self):
        """A single chunk is parsed into one title/content pair."""
        chunker, _ = _make_chunker()
        result = chunker._parse_response("摘要标题\n\n这是分块内容")
        assert len(result) == 1
        assert result[0].title == "摘要标题"
        assert result[0].content == "这是分块内容"

    def test_multiple_chunks(self):
        """Multiple chunks separated by the delimiter are all parsed."""
        chunker, _ = _make_chunker()
        response = "标题一\n\n内容一\n\n---\n标题二\n\n内容二"
        result = chunker._parse_response(response)
        assert len(result) == 2
        assert result[0].title == "标题一"
        assert result[0].content == "内容一"
        assert result[1].title == "标题二"
        assert result[1].content == "内容二"

    def test_skip_empty_parts(self):
        """Empty parts between delimiters are skipped."""
        chunker, _ = _make_chunker()
        response = "标题\n\n内容\n\n---\n\n---\n"
        result = chunker._parse_response(response)
        assert len(result) == 1
        assert result[0].title == "标题"

    def test_title_only_no_content(self):
        """A chunk with a title and no body yields empty content."""
        chunker, _ = _make_chunker()
        result = chunker._parse_response("仅标题")
        assert len(result) == 1
        assert result[0].title == "仅标题"
        assert result[0].content == ""

    def test_empty_response_raises_error(self):
        """An empty response raises ApiError."""
        chunker, _ = _make_chunker()
        with pytest.raises(ApiError, match="API 返回空响应"):
            chunker._parse_response("")

    def test_whitespace_only_response_raises_error(self):
        """A whitespace-only response raises ApiError."""
        chunker, _ = _make_chunker()
        with pytest.raises(ApiError, match="API 返回空响应"):
            chunker._parse_response(" \n\n ")

    def test_custom_delimiter(self):
        """Parsing honours a custom delimiter."""
        chunker, _ = _make_chunker(delimiter="===")
        response = "标题A\n\n内容A\n\n===\n标题B\n\n内容B"
        result = chunker._parse_response(response)
        assert len(result) == 2
        assert result[0].title == "标题A"
        assert result[1].title == "标题B"

    def test_strips_whitespace_from_parts(self):
        """Leading/trailing whitespace is stripped from each part."""
        chunker, _ = _make_chunker()
        response = " \n标题\n\n内容\n "
        result = chunker._parse_response(response)
        assert len(result) == 1
        assert result[0].title == "标题"
        assert result[0].content == "内容"
|
||||
|
||||
|
||||
class TestPreSplit:
    """Tests for the _pre_split() method."""

    def test_short_text_single_segment(self):
        """Short text needs no splitting and comes back as one segment."""
        chunker, _ = _make_chunker()
        text = "短文本内容"
        result = chunker._pre_split(text)
        assert len(result) == 1
        assert result[0] == text

    def test_split_on_paragraph_boundary(self):
        """Splitting happens at paragraph boundaries (double newline)."""
        chunker, _ = _make_chunker()
        para1 = "a" * 7000
        para2 = "b" * 7000
        text = f"{para1}\n\n{para2}"
        result = chunker._pre_split(text)
        assert len(result) == 2
        assert result[0] == para1
        assert result[1] == para2

    def test_greedy_merge_paragraphs(self):
        """Paragraphs are merged greedily without exceeding PRE_SPLIT_SIZE."""
        chunker, _ = _make_chunker()
        para1 = "a" * 4000
        para2 = "b" * 4000
        para3 = "c" * 5000
        text = f"{para1}\n\n{para2}\n\n{para3}"
        result = chunker._pre_split(text)
        # para1 + \n\n + para2 = 8002 <= 12000, so they merge
        # adding para3 would be 8002 + 2 + 5000 = 13004 > 12000
        assert len(result) == 2
        assert result[0] == f"{para1}\n\n{para2}"
        assert result[1] == para3

    def test_single_paragraph_exceeds_limit_split_by_newline(self):
        """An over-limit single paragraph is split on single newlines."""
        chunker, _ = _make_chunker()
        line1 = "x" * 7000
        line2 = "y" * 7000
        # One paragraph (no double newline), but with a single newline inside.
        text = f"{line1}\n{line2}"
        result = chunker._pre_split(text)
        assert len(result) == 2
        assert result[0] == line1
        assert result[1] == line2

    def test_hard_split_very_long_line(self):
        """A very long single line is hard-split."""
        chunker, _ = _make_chunker()
        # One line over PRE_SPLIT_SIZE with no paragraph/newline separator.
        text = "a" * 30000
        result = chunker._pre_split(text)
        assert len(result) >= 2
        for seg in result:
            assert len(seg) <= chunker.PRE_SPLIT_SIZE
        # Concatenation loses no content.
        assert "".join(result) == text

    def test_no_content_loss(self):
        """Joining the pre-split segments loses no content."""
        chunker, _ = _make_chunker()
        para1 = "a" * 5000
        para2 = "b" * 5000
        para3 = "c" * 5000
        text = f"{para1}\n\n{para2}\n\n{para3}"
        result = chunker._pre_split(text)
        joined = "\n\n".join(result)
        assert para1 in joined
        assert para2 in joined
        assert para3 in joined

    def test_each_segment_within_limit(self):
        """Every segment stays within PRE_SPLIT_SIZE."""
        chunker, _ = _make_chunker()
        paragraphs = ["p" * 5000 for _ in range(10)]
        text = "\n\n".join(paragraphs)
        result = chunker._pre_split(text)
        for seg in result:
            assert len(seg) <= chunker.PRE_SPLIT_SIZE
|
||||
|
||||
|
||||
class TestCallApi:
    """Tests for the _call_api() method."""

    def test_calls_api_with_correct_prompts(self):
        """The API is called with the expected system/user prompts."""
        chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
        chunker._call_api("测试文本")

        mock_api.chat.assert_called_once()
        args = mock_api.chat.call_args
        # Accept either positional or keyword invocation of chat().
        system_prompt = args[0][0] if args[0] else args[1].get("system_prompt")
        user_content = args[0][1] if len(args[0]) > 1 else args[1].get("user_content")
        # The system prompt must include the delimiter.
        assert "---" in system_prompt
        # The user prompt must include the input text.
        assert "测试文本" in user_content

    def test_returns_parsed_chunks(self):
        """The parsed Chunk list is returned."""
        chunker, _ = _make_chunker(api_response="标题\n\n内容")
        result = chunker._call_api("文本")
        assert len(result) == 1
        assert isinstance(result[0], Chunk)
        assert result[0].title == "标题"

    def test_api_error_propagates(self):
        """API errors propagate to the caller."""
        chunker, mock_api = _make_chunker()
        mock_api.chat.side_effect = ApiError("调用失败")
        with pytest.raises(ApiError, match="调用失败"):
            chunker._call_api("文本")
|
||||
|
||||
|
||||
class TestChunk:
    """Tests for the chunk() method."""

    def test_short_text_single_api_call(self):
        """Short text (<= PRE_SPLIT_SIZE) triggers exactly one API call."""
        chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
        text = "短文本" * 100  # well under 12000
        result = chunker.chunk(text)

        assert len(result) == 1
        assert mock_api.chat.call_count == 1

    def test_long_text_multiple_api_calls(self):
        """Long text (> PRE_SPLIT_SIZE) is pre-split and calls the API several times."""
        chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
        # Build text well over PRE_SPLIT_SIZE.
        text = ("段落内容" * 2000 + "\n\n") * 5
        result = chunker.chunk(text)

        assert mock_api.chat.call_count > 1
        assert len(result) >= 1

    def test_progress_callback_called(self):
        """on_progress is invoked correctly for long text."""
        chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
        # Build text long enough to require pre-splitting.
        para1 = "a" * 7000
        para2 = "b" * 7000
        para3 = "c" * 7000
        text = f"{para1}\n\n{para2}\n\n{para3}"

        progress_calls = []

        def on_progress(current, total):
            progress_calls.append((current, total))

        chunker.chunk(text, on_progress=on_progress)

        # Progress is 1-based, monotonically increasing, with a fixed total.
        assert len(progress_calls) > 0
        total = progress_calls[0][1]
        for i, (current, t) in enumerate(progress_calls):
            assert current == i + 1
            assert t == total

    def test_no_progress_callback_no_error(self):
        """Omitting on_progress raises no error."""
        chunker, _ = _make_chunker(api_response="标题\n\n内容")
        text = "短文本"
        result = chunker.chunk(text)
        assert len(result) == 1

    def test_short_text_no_progress_callback(self):
        """Short text does not trigger the on_progress callback."""
        chunker, _ = _make_chunker(api_response="标题\n\n内容")
        progress_calls = []
        chunker.chunk("短文本", on_progress=lambda c, t: progress_calls.append((c, t)))
        assert len(progress_calls) == 0

    def test_chunks_aggregated_from_segments(self):
        """Chunks from multiple segments are aggregated in order."""
        mock_api = MagicMock()
        # Each API call returns a different response.
        mock_api.chat.side_effect = [
            "标题A\n\n内容A",
            "标题B\n\n内容B",
        ]
        chunker = AIChunker(api_client=mock_api, delimiter="---")

        # Two paragraphs that cannot be merged without exceeding the limit.
        text = "a" * 7000 + "\n\n" + "b" * 7000
        result = chunker.chunk(text)

        assert len(result) == 2
        assert result[0].title == "标题A"
        assert result[1].title == "标题B"
|
||||
|
||||
|
||||
class TestHardSplit:
    """Tests for _hard_split() and _find_sentence_boundary()."""

    def test_hard_split_preserves_content(self):
        """Hard splitting loses no content."""
        chunker, _ = _make_chunker()
        text = "x" * 30000
        result = chunker._hard_split(text)
        assert "".join(result) == text

    def test_hard_split_respects_limit(self):
        """Every hard-split segment stays within PRE_SPLIT_SIZE."""
        chunker, _ = _make_chunker()
        text = "x" * 30000
        result = chunker._hard_split(text)
        for seg in result:
            assert len(seg) <= chunker.PRE_SPLIT_SIZE

    def test_sentence_boundary_split(self):
        """Hard splitting prefers sentence boundaries."""
        chunker, _ = _make_chunker()
        # Long text with a full stop in the middle, total length over 12000.
        text = "a" * 9000 + "。" + "b" * 5000
        result = chunker._hard_split(text)
        assert len(result) == 2
        # The first segment should end right after the full stop.
        assert result[0].endswith("。")

    def test_find_sentence_boundary_chinese(self):
        """A Chinese full stop counts as a sentence boundary."""
        boundary = AIChunker._find_sentence_boundary("你好世界。再见")
        assert boundary == 5  # length of "你好世界。"

    def test_find_sentence_boundary_english(self):
        """An English period counts as a sentence boundary."""
        boundary = AIChunker._find_sentence_boundary("Hello world. Bye")
        assert boundary == 12  # index of '.' is 11, return 11+1=12
|
||||
105
tests/test_csv_parser.py
Normal file
105
tests/test_csv_parser.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""CsvParser 单元测试"""
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import ParseError
|
||||
from parsers.csv_parser import CsvParser
|
||||
|
||||
|
||||
@pytest.fixture
def parser():
    """Fresh CsvParser instance for each test."""
    return CsvParser()
|
||||
|
||||
|
||||
class TestSupportedExtensions:
    """CsvParser advertises exactly the .csv extension."""

    def test_supports_csv(self, parser):
        assert ".csv" in parser.supported_extensions()

    def test_only_one_extension(self, parser):
        assert len(parser.supported_extensions()) == 1
|
||||
|
||||
|
||||
class TestParse:
    """CsvParser.parse() converts CSV files to Markdown tables."""

    def test_basic_csv(self, parser, tmp_path):
        # Header row, separator row, then one row per record.
        f = tmp_path / "basic.csv"
        f.write_text("name,age,city\nAlice,30,Beijing\nBob,25,Shanghai\n", encoding="utf-8")
        result = parser.parse(str(f))
        assert "| name | age | city |" in result
        assert "| --- | --- | --- |" in result
        assert "| Alice | 30 | Beijing |" in result
        assert "| Bob | 25 | Shanghai |" in result

    def test_empty_file(self, parser, tmp_path):
        # A zero-byte file parses to an empty string, not an error.
        f = tmp_path / "empty.csv"
        f.write_bytes(b"")
        assert parser.parse(str(f)) == ""

    def test_header_only(self, parser, tmp_path):
        # A header-only file still yields header + separator rows.
        f = tmp_path / "header.csv"
        f.write_text("col1,col2,col3\n", encoding="utf-8")
        result = parser.parse(str(f))
        assert "| col1 | col2 | col3 |" in result
        assert "| --- | --- | --- |" in result
        lines = result.strip().split("\n")
        assert len(lines) == 2

    def test_pipe_char_escaped(self, parser, tmp_path):
        # NOTE(review): the test name says "escaped" but the assertion accepts
        # a literal "a|b" — confirm whether a backslash escape ("a\|b") was
        # lost in this file's extraction/transcription.
        f = tmp_path / "pipe.csv"
        f.write_text('header\n"a|b"\n', encoding="utf-8")
        result = parser.parse(str(f))
        assert "|" in result
        assert "a|b" in result

    def test_newline_in_cell(self, parser, tmp_path):
        # Embedded newlines inside quoted cells become <br> in Markdown.
        f = tmp_path / "newline.csv"
        f.write_text('header\n"line1\nline2"\n', encoding="utf-8")
        result = parser.parse(str(f))
        assert "<br>" in result
        assert "line1<br>line2" in result

    def test_gbk_encoded_csv(self, parser, tmp_path):
        # GBK-encoded input is decoded correctly (encoding detection/fallback).
        f = tmp_path / "gbk.csv"
        content = "姓名,年龄,城市\n张三,28,北京\n李四,32,上海\n"
        f.write_bytes(content.encode("gbk"))
        result = parser.parse(str(f))
        assert "张三" in result
        assert "北京" in result

    def test_nonexistent_file_raises(self, parser):
        # Missing files surface as ParseError carrying the file name.
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/nonexistent/path/data.csv")
        assert "data.csv" in exc_info.value.file_name
        assert exc_info.value.reason != ""

    def test_short_row_padded(self, parser, tmp_path):
        """Rows shorter than header should be padded with empty cells."""
        f = tmp_path / "short.csv"
        f.write_text("a,b,c\n1\n", encoding="utf-8")
        result = parser.parse(str(f))
        assert "| 1 | | |" in result

    def test_result_ends_with_newline(self, parser, tmp_path):
        # Output always terminates with a trailing newline.
        f = tmp_path / "trail.csv"
        f.write_text("h1,h2\nv1,v2\n", encoding="utf-8")
        result = parser.parse(str(f))
        assert result.endswith("\n")
|
||||
|
||||
|
||||
class TestEscapeCell:
    """CsvParser._escape_cell() sanitizes cell text for Markdown tables."""

    def test_no_special_chars(self):
        assert CsvParser._escape_cell("hello") == "hello"

    def test_pipe_escaped(self):
        # NOTE(review): asserts the pipe passes through unchanged despite the
        # test name — confirm a backslash ("a\|b") was not lost in extraction.
        assert CsvParser._escape_cell("a|b") == "a|b"

    def test_newline_escaped(self):
        # LF becomes <br>.
        assert CsvParser._escape_cell("a\nb") == "a<br>b"

    def test_crlf_escaped(self):
        # CRLF collapses to a single <br>.
        assert CsvParser._escape_cell("a\r\nb") == "a<br>b"

    def test_cr_escaped(self):
        # Bare CR also becomes <br>.
        assert CsvParser._escape_cell("a\rb") == "a<br>b"

    def test_combined_escapes(self):
        assert CsvParser._escape_cell("a|b\nc") == "a|b<br>c"
|
||||
260
tests/test_doc_parser.py
Normal file
260
tests/test_doc_parser.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""DocParser 单元测试"""
|
||||
|
||||
import pytest
|
||||
from docx import Document
|
||||
from docx.shared import Pt
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||
|
||||
from exceptions import ParseError
|
||||
from parsers.doc_parser import DocParser
|
||||
|
||||
|
||||
@pytest.fixture
def parser():
    """Fresh DocParser instance for each test."""
    return DocParser()
|
||||
|
||||
|
||||
def _create_docx(path, paragraphs=None, tables=None):
    """
    Create a Word document fixture for tests.

    Args:
        path: output file path
        paragraphs: list; each element is either a plain string or a dict with:
            - text: paragraph text
            - style: optional style name (e.g. 'Heading 1')
            - font_size: optional font size in Pt
            - bold: optional bold flag
        tables: list; each element is a 2-D list of cell text (rows x cols)
    """
    doc = Document()
    # Remove the default empty paragraph python-docx creates.
    for p in doc.paragraphs:
        p._element.getparent().remove(p._element)

    if paragraphs:
        for para_info in paragraphs:
            if isinstance(para_info, str):
                doc.add_paragraph(para_info)
            else:
                text = para_info.get("text", "")
                style = para_info.get("style", None)
                font_size = para_info.get("font_size", None)
                bold = para_info.get("bold", None)

                if style:
                    p = doc.add_paragraph(text, style=style)
                else:
                    p = doc.add_paragraph(text)

                if font_size is not None or bold is not None:
                    # Font attributes live on runs, not on the paragraph,
                    # so apply them to every run of the new paragraph.
                    for run in p.runs:
                        if font_size is not None:
                            run.font.size = Pt(font_size)
                        if bold is not None:
                            run.bold = bold

    if tables:
        for table_data in tables:
            if not table_data:
                continue
            rows = len(table_data)
            # Column count taken from the first row; rows are assumed uniform.
            cols = len(table_data[0]) if table_data else 0
            table = doc.add_table(rows=rows, cols=cols)
            for i, row_data in enumerate(table_data):
                for j, cell_text in enumerate(row_data):
                    table.rows[i].cells[j].text = cell_text

    doc.save(str(path))
|
||||
|
||||
|
||||
class TestSupportedExtensions:
    """DocParser advertises exactly the .docx extension."""

    def test_supports_docx(self, parser):
        assert ".docx" in parser.supported_extensions()

    def test_only_one_extension(self, parser):
        assert len(parser.supported_extensions()) == 1
|
||||
|
||||
|
||||
class TestParse:
    """DocParser.parse() converts .docx files to Markdown."""

    def test_parse_simple_text(self, parser, tmp_path):
        docx_path = tmp_path / "simple.docx"
        _create_docx(docx_path, paragraphs=["Hello, world!"])
        result = parser.parse(str(docx_path))
        assert "Hello, world!" in result

    def test_parse_multiple_paragraphs(self, parser, tmp_path):
        docx_path = tmp_path / "multi.docx"
        _create_docx(docx_path, paragraphs=["First paragraph", "Second paragraph"])
        result = parser.parse(str(docx_path))
        assert "First paragraph" in result
        assert "Second paragraph" in result

    def test_heading_by_style_name(self, parser, tmp_path):
        """Heading style should produce Markdown heading"""
        docx_path = tmp_path / "heading.docx"
        _create_docx(docx_path, paragraphs=[
            {"text": "Main Title", "style": "Heading 1"},
            {"text": "Body text"},
        ])
        result = parser.parse(str(docx_path))
        assert "# Main Title" in result
        # Should be exactly H1, not H2
        assert "## Main Title" not in result

    def test_heading2_by_style_name(self, parser, tmp_path):
        docx_path = tmp_path / "h2.docx"
        _create_docx(docx_path, paragraphs=[
            {"text": "Section Title", "style": "Heading 2"},
            {"text": "Some content"},
        ])
        result = parser.parse(str(docx_path))
        assert "## Section Title" in result
        assert "### Section Title" not in result

    def test_heading3_by_style_name(self, parser, tmp_path):
        docx_path = tmp_path / "h3.docx"
        _create_docx(docx_path, paragraphs=[
            {"text": "Subsection", "style": "Heading 3"},
        ])
        result = parser.parse(str(docx_path))
        assert "### Subsection" in result

    def test_heading_by_font_size_bold(self, parser, tmp_path):
        """Bold text with large font size should be detected as heading"""
        docx_path = tmp_path / "font_heading.docx"
        _create_docx(docx_path, paragraphs=[
            {"text": "Big Bold Title", "font_size": 36, "bold": True},
            {"text": "Normal text"},
        ])
        result = parser.parse(str(docx_path))
        assert "# Big Bold Title" in result

    def test_heading_h2_by_font_size(self, parser, tmp_path):
        # 28pt bold maps to H2 — size thresholds live in DocParser.
        docx_path = tmp_path / "font_h2.docx"
        _create_docx(docx_path, paragraphs=[
            {"text": "H2 Title", "font_size": 28, "bold": True},
            {"text": "Normal text"},
        ])
        result = parser.parse(str(docx_path))
        assert "## H2 Title" in result

    def test_heading_h5_by_font_size(self, parser, tmp_path):
        # 20pt bold maps to H5.
        docx_path = tmp_path / "font_h5.docx"
        _create_docx(docx_path, paragraphs=[
            {"text": "H5 Title", "font_size": 20, "bold": True},
            {"text": "Normal text"},
        ])
        result = parser.parse(str(docx_path))
        assert "##### H5 Title" in result

    def test_no_heading_without_bold(self, parser, tmp_path):
        """Large font without bold should NOT be detected as heading via font size"""
        docx_path = tmp_path / "no_bold.docx"
        _create_docx(docx_path, paragraphs=[
            {"text": "Large Not Bold", "font_size": 36, "bold": False},
        ])
        result = parser.parse(str(docx_path))
        assert "# Large Not Bold" not in result
        assert "Large Not Bold" in result

    def test_simple_table(self, parser, tmp_path):
        docx_path = tmp_path / "table.docx"
        _create_docx(docx_path, tables=[
            [["Name", "Age"], ["Alice", "30"], ["Bob", "25"]],
        ])
        result = parser.parse(str(docx_path))
        assert "| Name | Age |" in result
        assert "| --- | --- |" in result
        assert "| Alice | 30 |" in result
        assert "| Bob | 25 |" in result

    def test_table_with_pipe_in_cell(self, parser, tmp_path):
        """Pipe characters in cells should be escaped"""
        # NOTE(review): as in the CSV tests, the assertion accepts a literal
        # pipe — confirm a backslash escape was not lost in extraction.
        docx_path = tmp_path / "pipe.docx"
        _create_docx(docx_path, tables=[
            [["Header"], ["value|with|pipes"]],
        ])
        result = parser.parse(str(docx_path))
        assert "|" in result
        assert "value|with|pipes" in result

    def test_mixed_paragraphs_and_tables(self, parser, tmp_path):
        """Document with both paragraphs and tables"""
        docx_path = tmp_path / "mixed.docx"
        doc = Document()
        # Clear default paragraph
        for p in doc.paragraphs:
            p._element.getparent().remove(p._element)

        doc.add_paragraph("Introduction", style="Heading 1")
        doc.add_paragraph("Some intro text.")
        table = doc.add_table(rows=2, cols=2)
        table.rows[0].cells[0].text = "Col1"
        table.rows[0].cells[1].text = "Col2"
        table.rows[1].cells[0].text = "A"
        table.rows[1].cells[1].text = "B"
        doc.add_paragraph("Conclusion")
        doc.save(str(docx_path))

        result = parser.parse(str(docx_path))
        assert "# Introduction" in result
        assert "Some intro text." in result
        assert "| Col1 | Col2 |" in result
        assert "| A | B |" in result
        assert "Conclusion" in result

    def test_empty_document(self, parser, tmp_path):
        docx_path = tmp_path / "empty.docx"
        doc = Document()
        # Clear default paragraph
        for p in doc.paragraphs:
            p._element.getparent().remove(p._element)
        doc.save(str(docx_path))
        result = parser.parse(str(docx_path))
        assert result.strip() == ""

    def test_empty_paragraphs_skipped(self, parser, tmp_path):
        docx_path = tmp_path / "empty_para.docx"
        _create_docx(docx_path, paragraphs=["", "Actual content", ""])
        result = parser.parse(str(docx_path))
        assert "Actual content" in result
        # Empty paragraphs should not produce extra lines
        assert result.strip() == "Actual content"

    def test_nonexistent_file_raises(self, parser):
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/nonexistent/path/file.docx")
        assert "file.docx" in exc_info.value.file_name
        assert exc_info.value.reason != ""

    def test_corrupted_file_raises(self, parser, tmp_path):
        # A non-zip payload with a .docx name must raise ParseError, not crash.
        docx_path = tmp_path / "corrupted.docx"
        docx_path.write_bytes(b"this is not a docx file at all")
        with pytest.raises(ParseError) as exc_info:
            parser.parse(str(docx_path))
        assert "corrupted.docx" in exc_info.value.file_name

    def test_parse_error_contains_filename(self, parser):
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/no/such/report.docx")
        assert exc_info.value.file_name == "report.docx"

    def test_multiple_heading_levels(self, parser, tmp_path):
        """Test document with multiple heading levels via styles"""
        docx_path = tmp_path / "levels.docx"
        _create_docx(docx_path, paragraphs=[
            {"text": "Title", "style": "Heading 1"},
            {"text": "Chapter", "style": "Heading 2"},
            {"text": "Section", "style": "Heading 3"},
            {"text": "Body text"},
        ])
        result = parser.parse(str(docx_path))
        assert "# Title" in result
        assert "## Chapter" in result
        assert "### Section" in result
        assert "Body text" in result
        # Body text should not have heading prefix
        assert "# Body text" not in result
|
||||
64
tests/test_exceptions.py
Normal file
64
tests/test_exceptions.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""异常类型单元测试"""
|
||||
|
||||
import pytest
|
||||
from exceptions import ApiError, ParseError, RateLimitError, UnsupportedFormatError
|
||||
|
||||
|
||||
class TestParseError:
    """ParseError carries the file name and a human-readable failure reason."""

    def test_attributes(self):
        error = ParseError("test.pdf", "文件损坏")
        assert error.file_name == "test.pdf"
        assert error.reason == "文件损坏"

    def test_message_format(self):
        error = ParseError("data.csv", "编码无法识别")
        assert str(error) == "解析失败 [data.csv]: 编码无法识别"

    def test_is_exception(self):
        assert isinstance(ParseError("f.txt", "reason"), Exception)
|
||||
|
||||
|
||||
class TestUnsupportedFormatError:
    """UnsupportedFormatError extends ParseError with the offending extension."""

    def test_inherits_parse_error(self):
        error = UnsupportedFormatError("file.xyz", ".xyz")
        assert isinstance(error, ParseError)

    def test_extension_attribute(self):
        error = UnsupportedFormatError("file.abc", ".abc")
        assert error.extension == ".abc"

    def test_message_format(self):
        error = UnsupportedFormatError("doc.bin", ".bin")
        assert str(error) == "解析失败 [doc.bin]: 不支持的文件格式: .bin"

    def test_file_name_propagated(self):
        error = UnsupportedFormatError("my_file.xyz", ".xyz")
        assert error.file_name == "my_file.xyz"
        assert error.reason == "不支持的文件格式: .xyz"
|
||||
|
||||
|
||||
class TestApiError:
    """ApiError stores an optional HTTP status code alongside the message."""

    def test_with_status_code(self):
        error = ApiError("服务端错误", status_code=500)
        assert error.status_code == 500
        assert str(error) == "服务端错误"

    def test_without_status_code(self):
        error = ApiError("网络错误")
        # status_code defaults to None when the failure has no HTTP status.
        assert error.status_code is None
        assert str(error) == "网络错误"

    def test_is_exception(self):
        assert isinstance(ApiError("msg"), Exception)
|
||||
|
||||
|
||||
class TestRateLimitError:
    """RateLimitError is an ApiError specialised for HTTP 429 responses."""

    def test_inherits_api_error(self):
        error = RateLimitError("速率限制", status_code=429)
        assert isinstance(error, ApiError)

    def test_status_code(self):
        error = RateLimitError("速率限制", status_code=429)
        assert error.status_code == 429
        assert str(error) == "速率限制"
|
||||
145
tests/test_html_parser.py
Normal file
145
tests/test_html_parser.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""HtmlParser 单元测试"""
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import ParseError
|
||||
from parsers.html_parser import HtmlParser
|
||||
|
||||
|
||||
@pytest.fixture
def parser():
    """Provide a fresh HtmlParser instance for each test."""
    return HtmlParser()
|
||||
|
||||
|
||||
class TestSupportedExtensions:
    """HtmlParser advertises exactly .html and .htm."""

    def test_supports_html(self, parser):
        assert ".html" in parser.supported_extensions()

    def test_supports_htm(self, parser):
        assert ".htm" in parser.supported_extensions()

    def test_only_two_extensions(self, parser):
        assert len(parser.supported_extensions()) == 2
|
||||
|
||||
|
||||
class TestParse:
    """HtmlParser.parse: tag stripping, markdown conversion, encoding handling,
    and error wrapping.

    Fix: the original script-tag test asserted
    ``"script" not in result.lower() or "Content" in result`` — vacuous, because
    the following line already asserts ``"Content" in result``, so the
    disjunction could never fail. It now asserts the real intent directly.
    The bold check also dropped a redundant ``or "**bold**" in result``
    disjunct, which is implied by ``"bold" in result.lower()``.
    """

    def test_parse_simple_html(self, parser, tmp_path):
        f = tmp_path / "test.html"
        f.write_text("<html><body><p>Hello, world!</p></body></html>", encoding="utf-8")
        result = parser.parse(str(f))
        assert "Hello, world!" in result

    def test_parse_htm_extension(self, parser, tmp_path):
        f = tmp_path / "test.htm"
        f.write_text("<html><body><p>HTM file</p></body></html>", encoding="utf-8")
        result = parser.parse(str(f))
        assert "HTM file" in result

    def test_parse_empty_file(self, parser, tmp_path):
        f = tmp_path / "empty.html"
        f.write_bytes(b"")
        assert parser.parse(str(f)) == ""

    def test_removes_script_tags(self, parser, tmp_path):
        f = tmp_path / "script.html"
        html = "<html><body><script>alert('xss');</script><p>Content</p></body></html>"
        f.write_text(html, encoding="utf-8")
        result = parser.parse(str(f))
        assert "alert" not in result
        # Fixed: previously a vacuous "... or 'Content' in result" disjunction.
        assert "script" not in result.lower()
        assert "Content" in result

    def test_removes_style_tags(self, parser, tmp_path):
        f = tmp_path / "style.html"
        html = "<html><head><style>body { color: red; }</style></head><body><p>Styled</p></body></html>"
        f.write_text(html, encoding="utf-8")
        result = parser.parse(str(f))
        assert "color: red" not in result
        assert "Styled" in result

    def test_converts_headings_to_markdown(self, parser, tmp_path):
        f = tmp_path / "headings.html"
        html = "<html><body><h1>Title</h1><h2>Subtitle</h2><p>Text</p></body></html>"
        f.write_text(html, encoding="utf-8")
        result = parser.parse(str(f))
        assert "# Title" in result
        assert "## Subtitle" in result

    def test_converts_links_to_markdown(self, parser, tmp_path):
        f = tmp_path / "links.html"
        html = '<html><body><a href="https://example.com">Example</a></body></html>'
        f.write_text(html, encoding="utf-8")
        result = parser.parse(str(f))
        assert "Example" in result
        assert "https://example.com" in result

    def test_converts_lists_to_markdown(self, parser, tmp_path):
        f = tmp_path / "lists.html"
        html = "<html><body><ul><li>Item 1</li><li>Item 2</li></ul></body></html>"
        f.write_text(html, encoding="utf-8")
        result = parser.parse(str(f))
        assert "Item 1" in result
        assert "Item 2" in result

    def test_meta_charset_detection(self, parser, tmp_path):
        f = tmp_path / "charset.html"
        html = '<html><head><meta charset="utf-8"></head><body><p>UTF-8 content</p></body></html>'
        f.write_text(html, encoding="utf-8")
        result = parser.parse(str(f))
        assert "UTF-8 content" in result

    def test_gbk_encoded_html_with_meta_charset(self, parser, tmp_path):
        # The declared charset in <meta> must win over any default decoding.
        f = tmp_path / "gbk.html"
        html = '<html><head><meta charset="gbk"></head><body><p>你好世界,这是中文内容测试</p></body></html>'
        f.write_bytes(html.encode("gbk"))
        result = parser.parse(str(f))
        assert "你好世界" in result

    def test_encoding_fallback_to_charset_normalizer(self, parser, tmp_path):
        # No <meta charset>: the parser should fall back to detection.
        f = tmp_path / "no_meta.html"
        html = "<html><body><p>Hello, this is a test with enough text for encoding detection to work properly.</p></body></html>"
        f.write_bytes(html.encode("utf-8"))
        result = parser.parse(str(f))
        assert "Hello" in result

    def test_nonexistent_file_raises(self, parser):
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/nonexistent/path/file.html")
        assert "file.html" in exc_info.value.file_name
        assert exc_info.value.reason != ""

    def test_parse_error_contains_filename(self, parser):
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/no/such/mypage.html")
        assert exc_info.value.file_name == "mypage.html"

    def test_complex_html_removes_all_tags(self, parser, tmp_path):
        f = tmp_path / "complex.html"
        html = """<!DOCTYPE html>
<html>
<head>
<title>Test Page</title>
<style>.hidden { display: none; }</style>
<script>var x = 1;</script>
</head>
<body>
<div class="container">
<h1>Main Title</h1>
<p>Paragraph with <strong>bold</strong> and <em>italic</em> text.</p>
<script>console.log('inline script');</script>
<table>
<tr><th>Name</th><th>Value</th></tr>
<tr><td>A</td><td>1</td></tr>
</table>
</div>
</body>
</html>"""
        f.write_text(html, encoding="utf-8")
        result = parser.parse(str(f))
        assert "Main Title" in result
        # Simplified: "or '**bold**' in result" was implied by this check.
        assert "bold" in result.lower()
        assert "<script>" not in result
        assert "<style>" not in result
        assert "<div" not in result
        assert "console.log" not in result
        assert "var x" not in result
|
||||
135
tests/test_image_parser.py
Normal file
135
tests/test_image_parser.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""ImageParser 单元测试"""
|
||||
|
||||
import base64
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import ApiError, ParseError
|
||||
from parsers.image_parser import ImageParser, DEFAULT_VISION_PROMPT
|
||||
|
||||
|
||||
@pytest.fixture
def mock_api_client():
    """A stand-in ApiClient whose vision() calls can be programmed per test."""
    return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
def parser(mock_api_client):
    """An ImageParser wired to the mocked API client."""
    return ImageParser(mock_api_client)
|
||||
|
||||
|
||||
class TestSupportedExtensions:
    """ImageParser covers the six common raster image formats."""

    def test_supports_png(self, parser):
        assert ".png" in parser.supported_extensions()

    def test_supports_jpg(self, parser):
        assert ".jpg" in parser.supported_extensions()

    def test_supports_jpeg(self, parser):
        assert ".jpeg" in parser.supported_extensions()

    def test_supports_bmp(self, parser):
        assert ".bmp" in parser.supported_extensions()

    def test_supports_gif(self, parser):
        assert ".gif" in parser.supported_extensions()

    def test_supports_webp(self, parser):
        assert ".webp" in parser.supported_extensions()

    def test_has_six_extensions(self, parser):
        assert len(parser.supported_extensions()) == 6
|
||||
|
||||
|
||||
class TestParse:
    """ImageParser.parse: base64 plumbing, prompt content, and error wrapping."""

    def test_successful_parse(self, mock_api_client, tmp_path):
        """A readable image yields the Vision API's text description."""
        image_path = tmp_path / "photo.png"
        image_path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 20)
        mock_api_client.vision.return_value = "图片中包含一段中文文字"
        result = ImageParser(mock_api_client).parse(str(image_path))
        assert result == "图片中包含一段中文文字"
        mock_api_client.vision.assert_called_once()

    def test_base64_encoding_correctness(self, mock_api_client, tmp_path):
        """The base64 payload sent to the API must round-trip to the raw bytes."""
        payload = b"\x89PNG\r\n\x1a\nSOME_IMAGE_DATA"
        image_path = tmp_path / "check.png"
        image_path.write_bytes(payload)
        mock_api_client.vision.return_value = "ok"
        ImageParser(mock_api_client).parse(str(image_path))
        call_args = mock_api_client.vision.call_args
        # Accept either keyword or positional passing of the image payload.
        sent_base64 = (
            call_args.kwargs.get("image_base64")
            or call_args[1].get("image_base64")
            or call_args[0][1]
        )
        assert base64.b64decode(sent_base64) == payload

    def test_system_prompt_passed_to_api(self, mock_api_client, tmp_path):
        """The prompt contains DEFAULT_VISION_PROMPT and the file-name context."""
        image_path = tmp_path / "prompt.png"
        image_path.write_bytes(b"\x00")
        mock_api_client.vision.return_value = "text"
        ImageParser(mock_api_client).parse(str(image_path))
        call_args = mock_api_client.vision.call_args
        sent_prompt = call_args.kwargs.get("system_prompt") or call_args[0][0]
        assert DEFAULT_VISION_PROMPT in sent_prompt
        assert "prompt" in sent_prompt

    def test_file_not_found_raises_parse_error(self, parser):
        """A missing file surfaces as ParseError with a read-failure reason."""
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/nonexistent/path/missing.png")
        assert exc_info.value.file_name == "missing.png"
        assert "文件读取失败" in exc_info.value.reason

    def test_unreadable_file_raises_parse_error(self, mock_api_client, tmp_path):
        """An unreadable path (a directory here) is wrapped as ParseError."""
        dir_path = tmp_path / "fakefile.jpg"
        dir_path.mkdir()
        with pytest.raises(ParseError) as exc_info:
            ImageParser(mock_api_client).parse(str(dir_path))
        assert exc_info.value.file_name == "fakefile.jpg"
        assert "文件读取失败" in exc_info.value.reason

    def test_api_error_raises_parse_error(self, mock_api_client, tmp_path):
        """An ApiError from the client becomes ParseError naming the file."""
        image_path = tmp_path / "api_fail.png"
        image_path.write_bytes(b"\x89PNG")
        mock_api_client.vision.side_effect = ApiError("服务不可用", status_code=503)
        with pytest.raises(ParseError) as exc_info:
            ImageParser(mock_api_client).parse(str(image_path))
        assert exc_info.value.file_name == "api_fail.png"
        assert "Vision API 调用失败" in exc_info.value.reason

    def test_api_rate_limit_error_raises_parse_error(self, mock_api_client, tmp_path):
        """Exhausted rate-limit retries are wrapped as ParseError as well."""
        image_path = tmp_path / "rate.png"
        image_path.write_bytes(b"\x89PNG")
        mock_api_client.vision.side_effect = ApiError("速率限制重试耗尽", status_code=429)
        with pytest.raises(ParseError) as exc_info:
            ImageParser(mock_api_client).parse(str(image_path))
        assert "Vision API 调用失败" in exc_info.value.reason

    def test_parse_error_contains_filename_for_missing_file(self, parser):
        """ParseError.file_name is the base name of the missing path."""
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/tmp/does_not_exist/myimage.jpeg")
        assert exc_info.value.file_name == "myimage.jpeg"
        assert exc_info.value.reason != ""
|
||||
182
tests/test_main.py
Normal file
182
tests/test_main.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""CLI 入口 main.py 单元测试"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from main import derive_output_path, build_parser, main
|
||||
from exceptions import ParseError, UnsupportedFormatError, ApiError
|
||||
|
||||
|
||||
class TestDeriveOutputPath:
    """derive_output_path replaces the input extension with .md in place."""

    def test_pdf_to_md(self):
        assert derive_output_path("report.pdf") == "report.md"

    def test_xlsx_to_md(self):
        assert derive_output_path("data.xlsx") == "data.md"

    def test_with_directory(self):
        assert derive_output_path("/home/user/docs/file.docx") == "/home/user/docs/file.md"

    def test_txt_to_md(self):
        assert derive_output_path("notes.txt") == "notes.md"

    def test_no_extension(self):
        # Extension-less inputs simply gain .md.
        assert derive_output_path("README") == "README.md"

    def test_multiple_dots(self):
        # Only the final extension is swapped.
        assert derive_output_path("my.report.v2.pdf") == "my.report.v2.md"
|
||||
|
||||
|
||||
class TestBuildParser:
    """Command-line argument handling via build_parser()."""

    def test_all_args(self):
        args = build_parser().parse_args(
            ["input.pdf", "-k", "sk-abc", "-o", "out.md", "-d", "==="]
        )
        assert args.input_file == "input.pdf"
        assert args.api_key == "sk-abc"
        assert args.output == "out.md"
        assert args.delimiter == "==="

    def test_required_args_only(self):
        args = build_parser().parse_args(["input.pdf", "-k", "sk-abc"])
        assert args.input_file == "input.pdf"
        assert args.api_key == "sk-abc"
        assert args.output is None
        assert args.delimiter == "---"

    def test_long_option_names(self):
        args = build_parser().parse_args(
            ["input.pdf", "--api-key", "sk-abc", "--output", "out.md", "--delimiter", "***"]
        )
        assert args.api_key == "sk-abc"
        assert args.output == "out.md"
        assert args.delimiter == "***"

    def test_missing_input_file(self):
        # argparse exits with a non-zero code when the positional is absent.
        with pytest.raises(SystemExit) as exc_info:
            build_parser().parse_args(["-k", "sk-abc"])
        assert exc_info.value.code != 0

    @patch.dict(os.environ, {}, clear=True)
    def test_missing_api_key(self):
        """Without -k and without the environment variable the key is None."""
        args = build_parser().parse_args(["input.pdf"])
        assert args.api_key is None
|
||||
|
||||
|
||||
class TestMainFunction:
    """main() wiring: Splitter construction, output path, and exit codes."""

    @patch("main.Splitter")
    def test_success(self, splitter_cls):
        instance = splitter_cls.return_value
        with patch("sys.argv", ["main.py", "input.pdf", "-k", "sk-abc"]):
            main()
        splitter_cls.assert_called_once_with(
            api_key="sk-abc", delimiter="---",
            pre_split_size=None, vision_prompt=None, output_format="markdown",
        )
        instance.process.assert_called_once_with("input.pdf", "input.md")

    @patch("main.Splitter")
    def test_custom_output(self, splitter_cls):
        instance = splitter_cls.return_value
        with patch("sys.argv", ["main.py", "input.pdf", "-k", "sk-abc", "-o", "custom.md"]):
            main()
        instance.process.assert_called_once_with("input.pdf", "custom.md")

    @patch("main.Splitter")
    def test_custom_delimiter(self, splitter_cls):
        with patch("sys.argv", ["main.py", "input.pdf", "-k", "sk-abc", "-d", "==="]):
            main()
        splitter_cls.assert_called_once_with(
            api_key="sk-abc", delimiter="===",
            pre_split_size=None, vision_prompt=None, output_format="markdown",
        )

    @patch("main.Splitter")
    def test_file_not_found_error(self, splitter_cls, capsys):
        splitter_cls.return_value.process.side_effect = FileNotFoundError("输入文件不存在: missing.pdf")
        with patch("sys.argv", ["main.py", "missing.pdf", "-k", "sk-abc"]):
            with pytest.raises(SystemExit) as exc_info:
                main()
        assert exc_info.value.code == 1
        assert "missing.pdf" in capsys.readouterr().err

    @patch("main.Splitter")
    def test_unsupported_format_error(self, splitter_cls, capsys):
        splitter_cls.return_value.process.side_effect = UnsupportedFormatError("file.xyz", ".xyz")
        with patch("sys.argv", ["main.py", "file.xyz", "-k", "sk-abc"]):
            with pytest.raises(SystemExit) as exc_info:
                main()
        assert exc_info.value.code == 1
        assert ".xyz" in capsys.readouterr().err

    @patch("main.Splitter")
    def test_parse_error(self, splitter_cls, capsys):
        splitter_cls.return_value.process.side_effect = ParseError("bad.pdf", "文件损坏")
        with patch("sys.argv", ["main.py", "bad.pdf", "-k", "sk-abc"]):
            with pytest.raises(SystemExit) as exc_info:
                main()
        assert exc_info.value.code == 1
        assert "bad.pdf" in capsys.readouterr().err

    @patch("main.Splitter")
    def test_api_error(self, splitter_cls, capsys):
        splitter_cls.return_value.process.side_effect = ApiError("认证失败", status_code=401)
        with patch("sys.argv", ["main.py", "input.pdf", "-k", "bad-key"]):
            with pytest.raises(SystemExit) as exc_info:
                main()
        assert exc_info.value.code == 1
        assert "API" in capsys.readouterr().err

    @patch("main.Splitter")
    def test_generic_exception(self, splitter_cls, capsys):
        # Unexpected exceptions still exit with code 1 and report the message.
        splitter_cls.return_value.process.side_effect = RuntimeError("意外错误")
        with patch("sys.argv", ["main.py", "input.pdf", "-k", "sk-abc"]):
            with pytest.raises(SystemExit) as exc_info:
                main()
        assert exc_info.value.code == 1
        assert "意外错误" in capsys.readouterr().err
|
||||
57
tests/test_models.py
Normal file
57
tests/test_models.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""核心数据结构单元测试"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from models import Chunk, CLIArgs, ProcessResult
|
||||
|
||||
|
||||
class TestChunk:
    """Chunk holds a title/content pair with value-based equality."""

    def test_creation(self):
        piece = Chunk(title="概述", content="这是内容")
        assert piece.title == "概述"
        assert piece.content == "这是内容"

    def test_equality(self):
        assert Chunk(title="t", content="c") == Chunk(title="t", content="c")
|
||||
|
||||
|
||||
class TestProcessResult:
    """ProcessResult aggregates the chunks plus processing metadata."""

    def test_creation(self):
        timestamp = datetime.now()
        chunk_list = [Chunk("t1", "c1"), Chunk("t2", "c2")]
        result = ProcessResult(
            source_file="input.pdf",
            output_file="output.md",
            chunks=chunk_list,
            process_time=timestamp,
            total_chunks=2,
        )
        assert result.source_file == "input.pdf"
        assert result.output_file == "output.md"
        assert len(result.chunks) == 2
        assert result.process_time == timestamp
        assert result.total_chunks == 2
|
||||
|
||||
|
||||
class TestCLIArgs:
    """CLIArgs required fields, defaults, and overrides."""

    def test_required_fields(self):
        args = CLIArgs(input_file="doc.pdf", api_key="sk-123")
        assert args.input_file == "doc.pdf"
        assert args.api_key == "sk-123"

    def test_defaults(self):
        args = CLIArgs(input_file="doc.pdf", api_key="sk-123")
        assert args.output_file is None
        assert args.delimiter == "---"

    def test_custom_values(self):
        args = CLIArgs(
            input_file="doc.pdf",
            api_key="sk-123",
            output_file="out.md",
            delimiter="***",
        )
        assert args.output_file == "out.md"
        assert args.delimiter == "***"
|
||||
82
tests/test_parsers_base.py
Normal file
82
tests/test_parsers_base.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""BaseParser 和 ParserRegistry 单元测试"""
|
||||
|
||||
import pytest
|
||||
from typing import List
|
||||
|
||||
from exceptions import UnsupportedFormatError
|
||||
from parsers.base import BaseParser, ParserRegistry
|
||||
|
||||
|
||||
class StubParser(BaseParser):
    """Minimal concrete parser used to exercise registry behaviour."""

    def __init__(self, extensions: List[str]):
        # The extension list this stub claims to handle.
        self._extensions = extensions

    def supported_extensions(self) -> List[str]:
        """Return the extensions configured at construction time."""
        return self._extensions

    def parse(self, file_path: str) -> str:
        """Echo the path back so tests can verify dispatch."""
        return f"parsed: {file_path}"
|
||||
|
||||
|
||||
class TestBaseParser:
    """BaseParser is abstract; concrete subclasses are fully usable."""

    def test_cannot_instantiate_directly(self):
        with pytest.raises(TypeError):
            BaseParser()

    def test_concrete_subclass_works(self):
        stub = StubParser([".txt"])
        assert stub.supported_extensions() == [".txt"]
        assert stub.parse("test.txt") == "parsed: test.txt"
|
||||
|
||||
|
||||
class TestParserRegistry:
    """Registration, lookup, and error behaviour of ParserRegistry."""

    def test_empty_registry_raises(self):
        with pytest.raises(UnsupportedFormatError):
            ParserRegistry().get_parser("file.pdf")

    def test_register_and_get_parser(self):
        registry = ParserRegistry()
        pdf_stub = StubParser([".pdf"])
        registry.register(pdf_stub)
        assert registry.get_parser("document.pdf") is pdf_stub

    def test_multiple_parsers(self):
        registry = ParserRegistry()
        pdf_stub = StubParser([".pdf"])
        text_stub = StubParser([".txt", ".md"])
        registry.register(pdf_stub)
        registry.register(text_stub)
        assert registry.get_parser("doc.pdf") is pdf_stub
        assert registry.get_parser("readme.txt") is text_stub
        assert registry.get_parser("notes.md") is text_stub

    def test_unsupported_format_error_details(self):
        registry = ParserRegistry()
        registry.register(StubParser([".pdf"]))
        with pytest.raises(UnsupportedFormatError) as exc_info:
            registry.get_parser("file.xyz")
        assert exc_info.value.extension == ".xyz"
        assert exc_info.value.file_name == "file.xyz"

    def test_case_insensitive_extension(self):
        # Upper-case extensions must still resolve to a parser.
        registry = ParserRegistry()
        registry.register(StubParser([".pdf"]))
        assert registry.get_parser("DOC.PDF") is not None

    def test_file_path_with_directory(self):
        registry = ParserRegistry()
        csv_stub = StubParser([".csv"])
        registry.register(csv_stub)
        assert registry.get_parser("/home/user/data/report.csv") is csv_stub

    def test_first_matching_parser_wins(self):
        # Registration order decides ties between parsers for the same extension.
        registry = ParserRegistry()
        winner = StubParser([".txt"])
        loser = StubParser([".txt"])
        registry.register(winner)
        registry.register(loser)
        assert registry.get_parser("file.txt") is winner
|
||||
159
tests/test_pdf_parser.py
Normal file
159
tests/test_pdf_parser.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""PdfParser 单元测试"""
|
||||
|
||||
import pytest
|
||||
import fitz
|
||||
|
||||
from exceptions import ParseError
|
||||
from parsers.pdf_parser import PdfParser
|
||||
|
||||
|
||||
@pytest.fixture
def parser():
    """Provide a fresh PdfParser instance for each test."""
    return PdfParser()
|
||||
|
||||
|
||||
def _create_pdf(path, pages):
    """Build a throwaway PDF for the tests.

    Args:
        path: output file path.
        pages: list of pages, each a list of (text, fontsize) tuples
            representing the text lines to place on that page.
    """
    doc = fitz.open()
    for page_lines in pages:
        page = doc.new_page()
        y_pos = 72
        for text, size in page_lines:
            page.insert_text((72, y_pos), text, fontsize=size)
            # Advance below the line just written, with a small gap.
            y_pos += size + 10
    doc.save(str(path))
    doc.close()
|
||||
|
||||
|
||||
class TestSupportedExtensions:
    """PdfParser handles exactly the .pdf extension."""

    def test_supports_pdf(self, parser):
        assert ".pdf" in parser.supported_extensions()

    def test_only_one_extension(self, parser):
        assert len(parser.supported_extensions()) == 1
|
||||
|
||||
|
||||
class TestParse:
    """PdfParser.parse: text extraction, heading inference, error wrapping."""

    def test_parse_simple_text(self, parser, tmp_path):
        target = tmp_path / "simple.pdf"
        _create_pdf(target, [[("Hello, world!", 12)]])
        assert "Hello, world!" in parser.parse(str(target))

    def test_parse_multiline_text(self, parser, tmp_path):
        target = tmp_path / "multi.pdf"
        _create_pdf(target, [[("Line one", 12), ("Line two", 12)]])
        extracted = parser.parse(str(target))
        assert "Line one" in extracted
        assert "Line two" in extracted

    def test_parse_multiple_pages(self, parser, tmp_path):
        target = tmp_path / "pages.pdf"
        _create_pdf(target, [
            [("Page one content", 12)],
            [("Page two content", 12)],
        ])
        extracted = parser.parse(str(target))
        assert "Page one content" in extracted
        assert "Page two content" in extracted

    def test_heading_level2_detection(self, parser, tmp_path):
        """Font sizes more than 2pt above the body mode become ## headings."""
        target = tmp_path / "h2.pdf"
        # Body at 12pt (the mode); 18pt is 6pt larger, so > mode + 2.
        _create_pdf(target, [[
            ("Body text line one", 12),
            ("Body text line two", 12),
            ("Body text line three", 12),
            ("Big Heading", 18),
        ]])
        assert "## Big Heading" in parser.parse(str(target))

    def test_heading_level3_detection(self, parser, tmp_path):
        """Sizes in (mode + 0.5, mode + 2] become ### headings."""
        target = tmp_path / "h3.pdf"
        # Body at 12pt (mode); 13.5pt differs by 1.5 — inside (0.5, 2].
        _create_pdf(target, [[
            ("Body text one", 12),
            ("Body text two", 12),
            ("Body text three", 12),
            ("Sub Heading", 13.5),
        ]])
        assert "### Sub Heading" in parser.parse(str(target))

    def test_body_text_no_heading_prefix(self, parser, tmp_path):
        """Body-size text must come through without any heading marker."""
        target = tmp_path / "body.pdf"
        _create_pdf(target, [[("Normal text", 12), ("More normal text", 12)]])
        extracted = parser.parse(str(target))
        assert "## Normal text" not in extracted
        assert "### Normal text" not in extracted
        assert "Normal text" in extracted

    def test_empty_pdf(self, parser, tmp_path):
        """A PDF containing one blank page yields an empty string."""
        target = tmp_path / "empty.pdf"
        doc = fitz.open()
        doc.new_page()
        doc.save(str(target))
        doc.close()
        assert parser.parse(str(target)).strip() == ""

    def test_nonexistent_file_raises(self, parser):
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/nonexistent/path/file.pdf")
        assert "file.pdf" in exc_info.value.file_name
        assert exc_info.value.reason != ""

    def test_corrupted_file_raises(self, parser, tmp_path):
        target = tmp_path / "corrupted.pdf"
        target.write_bytes(b"this is not a pdf file at all")
        with pytest.raises(ParseError) as exc_info:
            parser.parse(str(target))
        assert "corrupted.pdf" in exc_info.value.file_name

    def test_parse_error_contains_filename(self, parser):
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/no/such/report.pdf")
        assert exc_info.value.file_name == "report.pdf"

    def test_mixed_headings_and_body(self, parser, tmp_path):
        """Heading levels and body text coexist correctly on one page."""
        target = tmp_path / "mixed.pdf"
        _create_pdf(target, [[
            ("Body one", 12),
            ("Body two", 12),
            ("Body three", 12),
            ("Body four", 12),
            ("Body five", 12),
            ("Main Title", 20),
            ("Section Title", 14),
            ("Paragraph text", 12),
        ]])
        extracted = parser.parse(str(target))
        assert "## Main Title" in extracted
        assert "### Section Title" in extracted
        # Body lines must not be promoted to headings.
        assert "## Body one" not in extracted
        assert "## Paragraph text" not in extracted
|
||||
84
tests/test_prompts.py
Normal file
84
tests/test_prompts.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""提示词模块单元测试。"""
|
||||
|
||||
from prompts import (
|
||||
SYSTEM_PROMPT_TEMPLATE,
|
||||
USER_PROMPT_TEMPLATE,
|
||||
get_system_prompt,
|
||||
get_user_prompt,
|
||||
)
|
||||
|
||||
|
||||
class TestSystemPromptTemplate:
    """The system prompt template must carry every required chunking rule."""

    def test_contains_delimiter_placeholder(self):
        assert "{delimiter}" in SYSTEM_PROMPT_TEMPLATE

    def test_contains_semantic_completeness_rule(self):
        assert "语义完整性" in SYSTEM_PROMPT_TEMPLATE

    def test_contains_self_contained_rule(self):
        assert "自包含性" in SYSTEM_PROMPT_TEMPLATE

    def test_contains_heading_preservation_rule(self):
        assert "标题层级保留" in SYSTEM_PROMPT_TEMPLATE

    def test_contains_table_integrity_rule(self):
        assert "表格完整性" in SYSTEM_PROMPT_TEMPLATE

    def test_contains_granularity_rule(self):
        assert "合理粒度" in SYSTEM_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
class TestUserPromptTemplate:
    """The user prompt template must expose the content placeholder."""

    def test_contains_text_content_placeholder(self):
        assert "{text_content}" in USER_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
class TestGetSystemPrompt:
    """get_system_prompt substitutes the delimiter placeholder."""

    def test_default_delimiter(self):
        rendered = get_system_prompt()
        assert "---" in rendered
        assert "{delimiter}" not in rendered

    def test_custom_delimiter(self):
        rendered = get_system_prompt("===SPLIT===")
        assert "===SPLIT===" in rendered
        assert "{delimiter}" not in rendered

    def test_delimiter_appears_in_format_example(self):
        # The delimiter must show up in the format/example section too.
        assert "`***`" in get_system_prompt("***")

    def test_empty_delimiter(self):
        assert "{delimiter}" not in get_system_prompt("")
|
||||
|
||||
|
||||
class TestGetUserPrompt:
    """get_user_prompt wraps the content between start/end markers."""

    def test_text_content_substitution(self):
        rendered = get_user_prompt("这是一段测试文本。")
        assert "这是一段测试文本。" in rendered
        assert "{text_content}" not in rendered

    def test_preserves_surrounding_markers(self):
        rendered = get_user_prompt("内容")
        assert "---开始---" in rendered
        assert "---结束---" in rendered

    def test_multiline_content(self):
        text = "第一行\n第二行\n第三行"
        assert text in get_user_prompt(text)

    def test_empty_content(self):
        rendered = get_user_prompt("")
        assert "{text_content}" not in rendered
        assert "---开始---" in rendered
|
||||
359
tests/test_splitter.py
Normal file
359
tests/test_splitter.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""Splitter 协调器单元测试"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
from exceptions import ApiError, ParseError, UnsupportedFormatError
|
||||
from models import Chunk
|
||||
from splitter import Splitter
|
||||
|
||||
|
||||
#: Every splitter-module collaborator that must be mocked while constructing
#: a Splitter, so no real parser / API client / writer is instantiated.
_SPLITTER_PATCH_TARGETS = (
    "ApiClient", "AIChunker", "MarkdownWriter", "JsonWriter", "ParserRegistry",
    "TextParser", "CsvParser", "HtmlParser", "PdfParser", "DocParser",
    "LegacyDocParser", "XlsxParser", "XlsParser", "ImageParser",
)


@pytest.fixture
def mock_deps():
    """Yield a Splitter wired entirely against mocks.

    All splitter-module collaborators are patched for the lifetime of the
    test (ExitStack keeps them active until the fixture finalizes).  The
    yielded dict exposes the splitter plus the mocks tests actually stub
    and inspect: api_client, chunker, writer and registry.
    """
    from contextlib import ExitStack

    with ExitStack() as stack:
        mocks = {
            name: stack.enter_context(patch(f"splitter.{name}"))
            for name in _SPLITTER_PATCH_TARGETS
        }

        splitter = Splitter(api_key="test-key", delimiter="---")

        yield {
            "splitter": splitter,
            "api_client": mocks["ApiClient"].return_value,
            "chunker": mocks["AIChunker"].return_value,
            "writer": mocks["MarkdownWriter"].return_value,
            "registry": mocks["ParserRegistry"].return_value,
        }
|
||||
|
||||
|
||||
class TestInit:
    """Constructor wiring tests for Splitter.

    The original version repeated the same 14-line ``with (patch(...), ...)``
    block in every test; that list now lives in ``_TARGETS`` and is entered
    through a shared helper.
    """

    #: All splitter-module names patched during construction.
    _TARGETS = (
        "ApiClient", "AIChunker", "MarkdownWriter", "JsonWriter", "ParserRegistry",
        "TextParser", "CsvParser", "HtmlParser", "PdfParser", "DocParser",
        "LegacyDocParser", "XlsxParser", "XlsParser", "ImageParser",
    )

    @staticmethod
    def _patched():
        """Return ``(stack, mocks)`` with every collaborator patched.

        ``stack`` is an ExitStack the caller enters with ``with`` so the
        patches stay active for the test body and are undone on exit;
        ``mocks`` maps each patched name to its mock class.
        """
        from contextlib import ExitStack

        stack = ExitStack()
        mocks = {
            name: stack.enter_context(patch(f"splitter.{name}"))
            for name in TestInit._TARGETS
        }
        return stack, mocks

    def test_registers_all_parsers(self):
        """All parser implementations are registered on the ParserRegistry."""
        stack, mocks = self._patched()
        with stack:
            Splitter(api_key="test-key")
            registry = mocks["ParserRegistry"].return_value
            # 9 parsers: Text, Csv, Html, Pdf, Doc, LegacyDoc, Xlsx, Xls, Image
            assert registry.register.call_count == 9

    def test_creates_api_client_with_key(self):
        """The ApiClient is constructed with the key handed to Splitter."""
        stack, mocks = self._patched()
        with stack:
            Splitter(api_key="my-secret-key")
            mocks["ApiClient"].assert_called_once_with(api_key="my-secret-key")

    def test_creates_chunker_with_delimiter(self):
        """The AIChunker receives the api client, delimiter and pre-split size."""
        stack, mocks = self._patched()
        with stack:
            Splitter(api_key="key", delimiter="===")
            mocks["AIChunker"].assert_called_once_with(
                mocks["ApiClient"].return_value, "===", pre_split_size=None
            )
|
||||
|
||||
|
||||
class TestProcessSuccess:
    """Happy-path tests for Splitter.process().

    The repeated arrange boilerplate (create input file, stub parser and
    chunker) is factored into ``_arrange``; ``test_full_flow`` keeps its
    explicit setup because it asserts on the exact stubbed values.
    """

    @staticmethod
    def _arrange(mock_deps, tmp_path, filename="doc.txt", chunks=None):
        """Create an input file and stub registry/chunker for a clean run.

        Returns ``(input_path, mock_parser)``.  ``chunks`` defaults to a
        single trivial chunk when not given.
        """
        source = tmp_path / filename
        source.write_text("data")

        mock_parser = MagicMock()
        mock_parser.parse.return_value = "text"
        mock_deps["registry"].get_parser.return_value = mock_parser

        if chunks is None:
            chunks = [Chunk(title="t", content="c")]
        mock_deps["chunker"].chunk.return_value = chunks
        return source, mock_parser

    def test_full_flow(self, mock_deps, tmp_path, capsys):
        """parse -> chunk -> write runs in order with the right arguments."""
        splitter = mock_deps["splitter"]
        registry = mock_deps["registry"]
        chunker = mock_deps["chunker"]
        writer = mock_deps["writer"]

        source = tmp_path / "test.txt"
        source.write_text("hello")
        output_file = str(tmp_path / "output.md")

        mock_parser = MagicMock()
        mock_parser.parse.return_value = "parsed text"
        registry.get_parser.return_value = mock_parser

        chunks = [Chunk(title="标题", content="内容")]
        chunker.chunk.return_value = chunks

        splitter.process(str(source), output_file)

        registry.get_parser.assert_called_once_with(str(source))
        mock_parser.parse.assert_called_once_with(str(source))
        chunker.chunk.assert_called_once()
        assert chunker.chunk.call_args[0][0] == "parsed text"
        writer.write.assert_called_once_with(
            chunks, output_file, "test.txt", "---"
        )

    def test_logs_parsing_stage(self, mock_deps, tmp_path, capsys):
        """The parse stage announces the input file name."""
        source, _ = self._arrange(mock_deps, tmp_path, filename="doc.pdf")
        mock_deps["splitter"].process(str(source), str(tmp_path / "out.md"))
        assert "解析文件: doc.pdf" in capsys.readouterr().out

    def test_logs_chunking_stage(self, mock_deps, tmp_path, capsys):
        """The AI chunking stage is announced."""
        source, _ = self._arrange(mock_deps, tmp_path)
        mock_deps["splitter"].process(str(source), str(tmp_path / "out.md"))
        assert "AI 语义分块" in capsys.readouterr().out

    def test_logs_writing_stage(self, mock_deps, tmp_path, capsys):
        """The write stage is announced."""
        source, _ = self._arrange(mock_deps, tmp_path)
        mock_deps["splitter"].process(str(source), str(tmp_path / "out.md"))
        assert "写入输出" in capsys.readouterr().out

    def test_logs_summary(self, mock_deps, tmp_path, capsys):
        """The final summary reports the produced chunk count."""
        three_chunks = [Chunk(title=f"t{i}", content=f"c{i}") for i in (1, 2, 3)]
        source, _ = self._arrange(mock_deps, tmp_path, chunks=three_chunks)
        mock_deps["splitter"].process(str(source), str(tmp_path / "out.md"))
        assert "3 个分块" in capsys.readouterr().out

    def test_progress_callback_passed_to_chunker(self, mock_deps, tmp_path, capsys):
        """process() forwards a progress callback that prints i/total."""
        source, _ = self._arrange(mock_deps, tmp_path)

        def fake_chunk(text, content_type=None, source_file="", on_progress=None):
            # Simulate the chunker reporting progress three times.
            if on_progress:
                for step in (1, 2, 3):
                    on_progress(step, 3)
            return [Chunk(title="t", content="c")]

        mock_deps["chunker"].chunk.side_effect = fake_chunk

        mock_deps["splitter"].process(str(source), str(tmp_path / "out.md"))

        output = capsys.readouterr().out
        assert "分块进度: 1/3" in output
        assert "分块进度: 2/3" in output
        assert "分块进度: 3/3" in output
|
||||
|
||||
|
||||
class TestProcessErrors:
    """Failure-path tests for Splitter.process()."""

    def test_file_not_found(self, mock_deps):
        """A missing input path raises FileNotFoundError with a helpful message."""
        with pytest.raises(FileNotFoundError, match="输入文件不存在"):
            mock_deps["splitter"].process("/nonexistent/path/file.txt", "output.md")

    def test_unsupported_format(self, mock_deps, tmp_path):
        """An unknown extension propagates UnsupportedFormatError from the registry."""
        source = tmp_path / "file.xyz"
        source.write_text("data")

        mock_deps["registry"].get_parser.side_effect = UnsupportedFormatError(
            "file.xyz", ".xyz"
        )

        with pytest.raises(UnsupportedFormatError):
            mock_deps["splitter"].process(str(source), str(tmp_path / "out.md"))

    def test_parse_error(self, mock_deps, tmp_path):
        """A failing parser propagates ParseError to the caller."""
        source = tmp_path / "bad.pdf"
        source.write_bytes(b"\x00\x01\x02")

        failing_parser = MagicMock()
        failing_parser.parse.side_effect = ParseError("bad.pdf", "文件损坏")
        mock_deps["registry"].get_parser.return_value = failing_parser

        with pytest.raises(ParseError, match="bad.pdf"):
            mock_deps["splitter"].process(str(source), str(tmp_path / "out.md"))

    def test_api_error(self, mock_deps, tmp_path):
        """A failing chunker propagates ApiError to the caller."""
        source = tmp_path / "doc.txt"
        source.write_text("data")

        ok_parser = MagicMock()
        ok_parser.parse.return_value = "text"
        mock_deps["registry"].get_parser.return_value = ok_parser
        mock_deps["chunker"].chunk.side_effect = ApiError("API 调用失败")

        with pytest.raises(ApiError, match="API 调用失败"):
            mock_deps["splitter"].process(str(source), str(tmp_path / "out.md"))
|
||||
|
||||
|
||||
class TestCustomDelimiter:
    """A non-default delimiter must flow from the constructor to the writer."""

    def test_delimiter_passed_to_writer(self, tmp_path):
        """writer.write() receives the delimiter given at construction time."""
        with (
            patch("splitter.ApiClient"),
            patch("splitter.AIChunker") as chunker_cls,
            patch("splitter.MarkdownWriter") as writer_cls,
            patch("splitter.JsonWriter"),
            patch("splitter.ParserRegistry") as registry_cls,
            patch("splitter.TextParser"),
            patch("splitter.CsvParser"),
            patch("splitter.HtmlParser"),
            patch("splitter.PdfParser"),
            patch("splitter.DocParser"),
            patch("splitter.LegacyDocParser"),
            patch("splitter.XlsxParser"),
            patch("splitter.XlsParser"),
            patch("splitter.ImageParser"),
        ):
            splitter = Splitter(api_key="key", delimiter="===")

            source = tmp_path / "test.txt"
            source.write_text("hello")

            stub_parser = MagicMock()
            stub_parser.parse.return_value = "text"
            registry_cls.return_value.get_parser.return_value = stub_parser

            expected_chunks = [Chunk(title="t", content="c")]
            chunker_cls.return_value.chunk.return_value = expected_chunks

            splitter.process(str(source), str(tmp_path / "out.md"))

            writer_cls.return_value.write.assert_called_once_with(
                expected_chunks, str(tmp_path / "out.md"), "test.txt", "==="
            )
|
||||
83
tests/test_text_parser.py
Normal file
83
tests/test_text_parser.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""TextParser 单元测试"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import ParseError
|
||||
from parsers.text_parser import TextParser
|
||||
|
||||
|
||||
@pytest.fixture
def parser():
    """Provide a fresh TextParser instance per test."""
    return TextParser()
|
||||
|
||||
|
||||
class TestSupportedExtensions:
    """TextParser declares exactly .txt and .md."""

    def test_supports_txt(self, parser):
        """Plain text files are supported."""
        assert ".txt" in parser.supported_extensions()

    def test_supports_md(self, parser):
        """Markdown files are supported."""
        assert ".md" in parser.supported_extensions()

    def test_only_two_extensions(self, parser):
        """No other extensions are claimed."""
        assert len(parser.supported_extensions()) == 2
|
||||
|
||||
|
||||
class TestParse:
    """TextParser.parse(): encodings, edge cases and error reporting."""

    def test_parse_utf8_txt(self, parser, tmp_path):
        """A UTF-8 .txt file round-trips unchanged."""
        target = tmp_path / "test.txt"
        target.write_text("Hello, world!", encoding="utf-8")
        assert parser.parse(str(target)) == "Hello, world!"

    def test_parse_utf8_md(self, parser, tmp_path):
        """A UTF-8 .md file round-trips unchanged."""
        target = tmp_path / "readme.md"
        markdown = "# Title\n\nSome **bold** text."
        target.write_bytes(markdown.encode("utf-8"))
        assert parser.parse(str(target)) == markdown

    def test_parse_gbk_encoded_file(self, parser, tmp_path):
        """GBK-encoded Chinese text is auto-detected and decoded."""
        target = tmp_path / "chinese.txt"
        # Long enough that charset_normalizer can reliably detect GBK.
        text = "你好,世界!这是一段中文文本。我们正在测试文件编码的自动检测功能,需要足够长的文本才能让检测器准确识别编码格式。"
        target.write_bytes(text.encode("gbk"))
        assert parser.parse(str(target)) == text

    def test_parse_utf8_bom(self, parser, tmp_path):
        """A UTF-8 BOM does not corrupt the decoded content."""
        target = tmp_path / "bom.txt"
        payload = "UTF-8 with BOM"
        target.write_bytes(b"\xef\xbb\xbf" + payload.encode("utf-8"))
        assert "UTF-8 with BOM" in parser.parse(str(target))

    def test_parse_empty_file(self, parser, tmp_path):
        """An empty file parses to the empty string."""
        target = tmp_path / "empty.txt"
        target.write_bytes(b"")
        assert parser.parse(str(target)) == ""

    def test_parse_multiline(self, parser, tmp_path):
        """Newlines are preserved exactly."""
        target = tmp_path / "multi.md"
        text = "Line 1\nLine 2\nLine 3\n"
        target.write_bytes(text.encode("utf-8"))
        assert parser.parse(str(target)) == text

    def test_parse_nonexistent_file_raises(self, parser):
        """A missing file raises ParseError carrying name and reason."""
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/nonexistent/path/file.txt")
        assert "file.txt" in exc_info.value.file_name
        assert exc_info.value.reason != ""

    def test_parse_error_contains_filename(self, parser):
        """ParseError reports the bare file name, not the full path."""
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/no/such/myfile.txt")
        assert exc_info.value.file_name == "myfile.txt"

    def test_parse_latin1_encoded_file(self, parser, tmp_path):
        """Latin-1 content decodes to something containing the ASCII cores."""
        target = tmp_path / "latin.txt"
        target.write_bytes("café résumé naïve".encode("latin-1"))
        decoded = parser.parse(str(target))
        assert "caf" in decoded
        assert "sum" in decoded
|
||||
225
tests/test_writer.py
Normal file
225
tests/test_writer.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""MarkdownWriter 单元测试"""
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Chunk
|
||||
from writer import MarkdownWriter
|
||||
|
||||
|
||||
@pytest.fixture
def writer():
    """Provide a fresh MarkdownWriter instance per test."""
    return MarkdownWriter()
|
||||
|
||||
|
||||
@pytest.fixture
def tmp_output(tmp_path):
    """Path (as str) of a not-yet-existing output file in a temp dir."""
    return str(tmp_path / "output.md")
|
||||
|
||||
|
||||
class TestSingleChunk:
    """Output tests for a single chunk.

    Fix: the original read the output with ``open(...).read()``, leaking the
    file handle (relies on refcount GC; emits ResourceWarning).  Reads now go
    through a context manager.
    """

    @staticmethod
    def _read(path):
        """Read the produced file, closing the handle deterministically."""
        with open(path, encoding="utf-8") as f:
            return f.read()

    def test_single_chunk_no_delimiter(self, writer, tmp_output):
        """A lone chunk is not followed by a delimiter line."""
        writer.write([Chunk(title="摘要标题", content="这是内容")], tmp_output, "test.pdf")
        # Only look past the metadata comment; its text may contain dashes.
        body = self._read(tmp_output).split("-->", 1)[1]
        assert "---" not in body

    def test_single_chunk_has_title(self, writer, tmp_output):
        """The chunk title is rendered as a level-2 heading."""
        writer.write([Chunk(title="摘要标题", content="这是内容")], tmp_output, "test.pdf")
        assert "## 摘要标题" in self._read(tmp_output)

    def test_single_chunk_has_content(self, writer, tmp_output):
        """The chunk body text appears verbatim."""
        writer.write([Chunk(title="摘要标题", content="这是内容")], tmp_output, "test.pdf")
        assert "这是内容" in self._read(tmp_output)
|
||||
|
||||
|
||||
class TestMultipleChunks:
    """Output tests for several chunks.

    Fix: reads now use a context manager instead of the handle-leaking
    ``open(...).read()`` pattern.
    """

    @staticmethod
    def _read(path):
        """Read the produced file, closing the handle deterministically."""
        with open(path, encoding="utf-8") as f:
            return f.read()

    def test_delimiter_between_chunks(self, writer, tmp_output):
        """N chunks are joined by exactly N-1 delimiter lines."""
        chunks = [Chunk(title=f"标题{i}", content=f"内容{i}") for i in (1, 2, 3)]
        writer.write(chunks, tmp_output, "test.pdf")

        after_meta = self._read(tmp_output).split("-->", 1)[1]
        assert after_meta.count("\n---\n") == 2

    def test_all_titles_present(self, writer, tmp_output):
        """Every chunk title becomes a heading."""
        chunks = [
            Chunk(title="标题A", content="内容A"),
            Chunk(title="标题B", content="内容B"),
        ]
        writer.write(chunks, tmp_output, "test.pdf")

        rendered = self._read(tmp_output)
        assert "## 标题A" in rendered
        assert "## 标题B" in rendered

    def test_all_contents_present(self, writer, tmp_output):
        """Every chunk body is written out."""
        chunks = [
            Chunk(title="标题A", content="内容A"),
            Chunk(title="标题B", content="内容B"),
        ]
        writer.write(chunks, tmp_output, "test.pdf")

        rendered = self._read(tmp_output)
        assert "内容A" in rendered
        assert "内容B" in rendered

    def test_no_trailing_delimiter(self, writer, tmp_output):
        """No delimiter follows the final chunk."""
        chunks = [
            Chunk(title="标题1", content="内容1"),
            Chunk(title="标题2", content="内容2"),
        ]
        writer.write(chunks, tmp_output, "test.pdf")

        after_meta = self._read(tmp_output).split("-->", 1)[1]
        # The last content must come after the last delimiter, i.e. nothing
        # delimiter-like trails it.
        last_delimiter_pos = after_meta.rfind("\n---\n")
        last_content_pos = after_meta.rfind("内容2")
        assert last_content_pos > last_delimiter_pos
|
||||
|
||||
|
||||
class TestMetaInfo:
    """Metadata-comment tests.

    Fix: reads now use a context manager instead of the handle-leaking
    ``open(...).read()`` pattern.
    """

    @staticmethod
    def _read(path):
        """Read the produced file, closing the handle deterministically."""
        with open(path, encoding="utf-8") as f:
            return f.read()

    def test_contains_source_file(self, writer, tmp_output):
        """The metadata records the source file name."""
        writer.write([Chunk(title="标题", content="内容")], tmp_output, "example.pdf")
        assert "源文件: example.pdf" in self._read(tmp_output)

    def test_contains_process_time(self, writer, tmp_output):
        """The metadata records a processing timestamp."""
        writer.write([Chunk(title="标题", content="内容")], tmp_output, "test.pdf")
        assert "处理时间:" in self._read(tmp_output)

    def test_contains_chunk_count(self, writer, tmp_output):
        """The metadata records the total chunk count."""
        chunks = [Chunk(title=f"标题{i}", content=f"内容{i}") for i in (1, 2, 3)]
        writer.write(chunks, tmp_output, "test.pdf")
        assert "分块总数: 3" in self._read(tmp_output)

    def test_meta_is_html_comment(self, writer, tmp_output):
        """Metadata is wrapped in an HTML comment."""
        writer.write([Chunk(title="标题", content="内容")], tmp_output, "test.pdf")
        rendered = self._read(tmp_output)
        assert rendered.startswith("<!-- ")
        assert "-->" in rendered

    def test_meta_at_file_start(self, writer, tmp_output):
        """The metadata comment precedes the first chunk heading."""
        writer.write([Chunk(title="标题", content="内容")], tmp_output, "test.pdf")
        rendered = self._read(tmp_output)
        assert rendered.index("-->") < rendered.index("## 标题")
|
||||
|
||||
|
||||
class TestFileOverwrite:
    """Overwrite behaviour and warning output.

    Fix: reads now use a context manager instead of the handle-leaking
    ``open(...).read()`` pattern.
    """

    @staticmethod
    def _read(path):
        """Read the produced file, closing the handle deterministically."""
        with open(path, encoding="utf-8") as f:
            return f.read()

    @staticmethod
    def _prefill(path):
        """Seed the output path with stale content."""
        with open(path, "w", encoding="utf-8") as f:
            f.write("旧内容")

    def test_overwrites_existing_file(self, writer, tmp_output):
        """Existing content is replaced, not appended to."""
        self._prefill(tmp_output)

        writer.write([Chunk(title="新标题", content="新内容")], tmp_output, "test.pdf")

        rendered = self._read(tmp_output)
        assert "旧内容" not in rendered
        assert "新内容" in rendered

    def test_prints_warning_on_overwrite(self, writer, tmp_output, capsys):
        """Overwriting an existing file prints a warning naming the path."""
        self._prefill(tmp_output)

        writer.write([Chunk(title="标题", content="内容")], tmp_output, "test.pdf")

        printed = capsys.readouterr().out
        assert "警告" in printed
        assert tmp_output in printed

    def test_no_warning_for_new_file(self, writer, tmp_output, capsys):
        """Writing a brand-new file prints no warning."""
        writer.write([Chunk(title="标题", content="内容")], tmp_output, "test.pdf")
        assert "警告" not in capsys.readouterr().out
|
||||
|
||||
|
||||
class TestCustomDelimiter:
    """Custom delimiter tests.

    Fix: reads now use a context manager instead of the handle-leaking
    ``open(...).read()`` pattern.
    """

    def test_custom_delimiter(self, writer, tmp_output):
        """A caller-supplied delimiter replaces the default '---'."""
        chunks = [
            Chunk(title="标题1", content="内容1"),
            Chunk(title="标题2", content="内容2"),
        ]
        writer.write(chunks, tmp_output, "test.pdf", delimiter="===")

        with open(tmp_output, encoding="utf-8") as f:
            rendered = f.read()
        after_meta = rendered.split("-->", 1)[1]
        assert "\n===\n" in after_meta
        assert "\n---\n" not in after_meta
|
||||
|
||||
|
||||
class TestEmptyContent:
    """Chunks with empty content.

    Fix: reads now use a context manager instead of the handle-leaking
    ``open(...).read()`` pattern.
    """

    @staticmethod
    def _read(path):
        """Read the produced file, closing the handle deterministically."""
        with open(path, encoding="utf-8") as f:
            return f.read()

    def test_empty_content_chunk(self, writer, tmp_output):
        """An empty-bodied chunk still renders its heading."""
        writer.write([Chunk(title="空内容标题", content="")], tmp_output, "test.pdf")
        assert "## 空内容标题" in self._read(tmp_output)

    def test_empty_content_with_multiple_chunks(self, writer, tmp_output):
        """Empty and non-empty chunks coexist; all headings render."""
        chunks = [
            Chunk(title="标题1", content=""),
            Chunk(title="标题2", content="有内容"),
        ]
        writer.write(chunks, tmp_output, "test.pdf")

        rendered = self._read(tmp_output)
        assert "## 标题1" in rendered
        assert "## 标题2" in rendered
        assert "有内容" in rendered
|
||||
|
||||
|
||||
class TestUTF8Encoding:
    """Non-ASCII output round-trips through UTF-8.

    Fix: reads now use a context manager instead of the handle-leaking
    ``open(...).read()`` pattern.
    """

    def test_utf8_encoding(self, writer, tmp_output):
        """Chinese text and symbol characters survive write + read."""
        chunks = [Chunk(title="中文标题", content="中文内容,包含特殊字符:①②③")]
        writer.write(chunks, tmp_output, "测试文件.pdf")

        with open(tmp_output, encoding="utf-8") as f:
            rendered = f.read()
        assert "中文标题" in rendered
        assert "①②③" in rendered
        assert "测试文件.pdf" in rendered
|
||||
178
tests/test_xls_parser.py
Normal file
178
tests/test_xls_parser.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""XlsParser 单元测试"""
|
||||
|
||||
import pytest
|
||||
import xlwt
|
||||
|
||||
from exceptions import ParseError
|
||||
from parsers.xls_parser import XlsParser
|
||||
|
||||
|
||||
@pytest.fixture
def parser():
    """Provide a fresh XlsParser instance per test."""
    return XlsParser()
|
||||
|
||||
|
||||
def _create_xls(path, sheets=None):
    """Write a test .xls workbook.

    Args:
        path: destination file path.
        sheets: mapping of sheet name -> 2-D list of cell values
            (rows x columns).  When falsy, a workbook with one blank
            "Sheet1" is written — xlwt refuses to save zero sheets.
    """
    workbook = xlwt.Workbook()

    if not sheets:
        workbook.add_sheet("Sheet1")
    else:
        for name, rows in sheets.items():
            sheet = workbook.add_sheet(name)
            for row_index, row in enumerate(rows):
                for col_index, cell in enumerate(row):
                    sheet.write(row_index, col_index, cell)

    workbook.save(str(path))
|
||||
|
||||
|
||||
class TestSupportedExtensions:
    """XlsParser declares exactly .xls."""

    def test_supports_xls(self, parser):
        """Legacy Excel files are supported."""
        assert ".xls" in parser.supported_extensions()

    def test_only_one_extension(self, parser):
        """No other extensions are claimed."""
        assert len(parser.supported_extensions()) == 1
|
||||
|
||||
|
||||
class TestParse:
    """XlsParser.parse(): Markdown rendering, escaping and error reporting."""

    def test_simple_table(self, parser, tmp_path):
        """A basic sheet becomes a Markdown table under a ## heading."""
        workbook_path = tmp_path / "simple.xls"
        _create_xls(workbook_path, {
            "Sheet1": [
                ["Name", "Age"],
                ["Alice", 30],
                ["Bob", 25],
            ]
        })
        rendered = parser.parse(str(workbook_path))
        assert "## Sheet1" in rendered
        assert "| Name | Age |" in rendered
        assert "| --- | --- |" in rendered
        assert "Alice" in rendered
        assert "Bob" in rendered

    def test_multiple_sheets(self, parser, tmp_path):
        """Each sheet gets its own heading and table."""
        workbook_path = tmp_path / "multi.xls"
        _create_xls(workbook_path, {
            "Users": [["Name"], ["Alice"]],
            "Orders": [["ID"], ["001"]],
        })
        rendered = parser.parse(str(workbook_path))
        assert "## Users" in rendered
        assert "## Orders" in rendered
        assert "| Name |" in rendered
        assert "| ID |" in rendered

    def test_empty_sheet_skipped(self, parser, tmp_path):
        """Sheets with no data produce no heading."""
        workbook_path = tmp_path / "empty_sheet.xls"
        workbook = xlwt.Workbook()
        workbook.add_sheet("Empty")  # no data written
        data_sheet = workbook.add_sheet("Data")
        data_sheet.write(0, 0, "Col1")
        data_sheet.write(1, 0, "Val1")
        workbook.save(str(workbook_path))

        rendered = parser.parse(str(workbook_path))
        assert "## Empty" not in rendered
        assert "## Data" in rendered

    def test_pipe_escaped(self, parser, tmp_path):
        """Cell text containing '|' survives into the output."""
        # NOTE(review): the original docstring mentions escaping; the
        # expected escaped form (backslash-pipe) may have been lost when
        # this file was rendered — confirm against the parser's behavior.
        workbook_path = tmp_path / "pipe.xls"
        _create_xls(workbook_path, {
            "Sheet1": [["Header"], ["value|with|pipes"]],
        })
        rendered = parser.parse(str(workbook_path))
        assert "|" in rendered
        assert "value|with|pipes" in rendered

    def test_newline_escaped(self, parser, tmp_path):
        """In-cell newlines are rendered as <br> so the table row stays intact."""
        workbook_path = tmp_path / "newline.xls"
        _create_xls(workbook_path, {
            "Sheet1": [["Header"], ["line1\nline2"]],
        })
        assert "line1<br>line2" in parser.parse(str(workbook_path))

    def test_backtick_escaped(self, parser, tmp_path):
        """Cell text containing backticks survives into the output."""
        # NOTE(review): as with pipes, the original escaped expectation may
        # have been mangled in rendering — verify against parser behavior.
        workbook_path = tmp_path / "backtick.xls"
        _create_xls(workbook_path, {
            "Sheet1": [["Header"], ["code `snippet`"]],
        })
        assert "`" in parser.parse(str(workbook_path))

    def test_empty_cell_becomes_empty(self, parser, tmp_path):
        """An unwritten cell renders as an empty table cell."""
        workbook_path = tmp_path / "empty_cell.xls"
        workbook = xlwt.Workbook()
        sheet = workbook.add_sheet("Sheet1")
        sheet.write(0, 0, "A")
        sheet.write(0, 1, "B")
        sheet.write(1, 0, "val")
        # cell (1,1) deliberately left unwritten — should render empty
        workbook.save(str(workbook_path))

        assert "| val | |" in parser.parse(str(workbook_path))

    def test_sheet_name_as_heading(self, parser, tmp_path):
        """The sheet name becomes a level-2 heading verbatim."""
        workbook_path = tmp_path / "named.xls"
        _create_xls(workbook_path, {
            "Sales Report": [["Month", "Revenue"], ["Jan", "1000"]],
        })
        assert "## Sales Report" in parser.parse(str(workbook_path))

    def test_nonexistent_file_raises(self, parser):
        """A missing file raises ParseError carrying name and reason."""
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/nonexistent/path/file.xls")
        assert "file.xls" in exc_info.value.file_name
        assert exc_info.value.reason != ""

    def test_corrupted_file_raises(self, parser, tmp_path):
        """Non-xls bytes raise ParseError naming the file."""
        workbook_path = tmp_path / "corrupted.xls"
        workbook_path.write_bytes(b"this is not an xls file")
        with pytest.raises(ParseError) as exc_info:
            parser.parse(str(workbook_path))
        assert "corrupted.xls" in exc_info.value.file_name

    def test_parse_error_contains_filename(self, parser):
        """ParseError reports the bare file name, not the full path."""
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/no/such/report.xls")
        assert exc_info.value.file_name == "report.xls"

    def test_numeric_values(self, parser, tmp_path):
        """Ints and floats are stringified into the table."""
        workbook_path = tmp_path / "numeric.xls"
        _create_xls(workbook_path, {
            "Sheet1": [["Int", "Float"], [42, 3.14]],
        })
        rendered = parser.parse(str(workbook_path))
        assert "42" in rendered
        assert "3.14" in rendered

    def test_crlf_escaped(self, parser, tmp_path):
        """CRLF line endings inside cells are rendered as a single <br>."""
        workbook_path = tmp_path / "crlf.xls"
        _create_xls(workbook_path, {
            "Sheet1": [["Header"], ["line1\r\nline2"]],
        })
        assert "line1<br>line2" in parser.parse(str(workbook_path))
|
||||
220
tests/test_xlsx_parser.py
Normal file
220
tests/test_xlsx_parser.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""XlsxParser 单元测试"""
|
||||
|
||||
import pytest
|
||||
from openpyxl import Workbook
|
||||
|
||||
from exceptions import ParseError
|
||||
from parsers.xlsx_parser import XlsxParser
|
||||
|
||||
|
||||
@pytest.fixture
def parser():
    """Provide a fresh XlsxParser instance per test."""
    return XlsxParser()
|
||||
|
||||
|
||||
def _create_xlsx(path, sheets=None):
    """Write a test .xlsx workbook.

    Args:
        path: destination file path.
        sheets: mapping of sheet name -> 2-D list of cell values
            (rows x columns).  When falsy an empty workbook is written
            (openpyxl's default sheet is removed either way).
    """
    workbook = Workbook()
    # Drop the sheet openpyxl creates by default so only `sheets` remain.
    workbook.remove(workbook.active)

    for name, rows in (sheets or {}).items():
        sheet = workbook.create_sheet(title=name)
        for row in rows:
            sheet.append(row)

    workbook.save(str(path))
|
||||
|
||||
|
||||
def _create_xlsx_with_merge(path, sheet_name, rows, merges):
    """Write a test .xlsx workbook containing merged cells.

    Args:
        path: destination file path.
        sheet_name: name of the single sheet to create.
        rows: 2-D list of cell values (rows x columns).
        merges: merge ranges in A1 notation, e.g. ["A1:B1", "A2:A3"].
    """
    workbook = Workbook()
    workbook.remove(workbook.active)
    sheet = workbook.create_sheet(title=sheet_name)

    for row in rows:
        sheet.append(row)

    for merge_range in merges:
        sheet.merge_cells(merge_range)

    workbook.save(str(path))
|
||||
|
||||
|
||||
class TestSupportedExtensions:
    """XlsxParser declares exactly .xlsx."""

    def test_supports_xlsx(self, parser):
        """Modern Excel files are supported."""
        assert ".xlsx" in parser.supported_extensions()

    def test_only_one_extension(self, parser):
        """No other extensions are claimed."""
        assert len(parser.supported_extensions()) == 1
|
||||
|
||||
|
||||
class TestParse:
    """Behaviour of XlsxParser.parse(): sheets become '## <name>' headed
    Markdown tables; special characters inside cells are escaped."""

    def test_simple_table(self, parser, tmp_path):
        """A basic table is converted to a Markdown table."""
        xlsx_path = tmp_path / "simple.xlsx"
        _create_xlsx(xlsx_path, {
            "Sheet1": [
                ["Name", "Age"],
                ["Alice", 30],
                ["Bob", 25],
            ]
        })
        result = parser.parse(str(xlsx_path))
        assert "## Sheet1" in result
        assert "| Name | Age |" in result
        assert "| --- | --- |" in result
        assert "| Alice | 30 |" in result
        assert "| Bob | 25 |" in result

    def test_multiple_sheets(self, parser, tmp_path):
        """Each worksheet gets its own heading and table."""
        xlsx_path = tmp_path / "multi.xlsx"
        _create_xlsx(xlsx_path, {
            "Users": [["Name"], ["Alice"]],
            "Orders": [["ID"], ["001"]],
        })
        result = parser.parse(str(xlsx_path))
        assert "## Users" in result
        assert "## Orders" in result
        assert "| Name |" in result
        assert "| ID |" in result

    def test_empty_sheet_skipped(self, parser, tmp_path):
        """Worksheets with no rows are skipped entirely."""
        xlsx_path = tmp_path / "empty_sheet.xlsx"
        _create_xlsx(xlsx_path, {
            "Empty": [],
            "Data": [["Col1"], ["Val1"]],
        })
        result = parser.parse(str(xlsx_path))
        assert "## Empty" not in result
        assert "## Data" in result

    def test_all_empty_sheets(self, parser, tmp_path):
        """If every worksheet is empty, the result is an empty string."""
        xlsx_path = tmp_path / "all_empty.xlsx"
        _create_xlsx(xlsx_path, {"Empty1": [], "Empty2": []})
        result = parser.parse(str(xlsx_path))
        assert result.strip() == ""

    def test_pipe_escaped(self, parser, tmp_path):
        """A '|' inside a cell must be escaped as '\\|' so it does not
        break the Markdown table."""
        xlsx_path = tmp_path / "pipe.xlsx"
        _create_xlsx(xlsx_path, {
            "Sheet1": [["Header"], ["value|with|pipes"]],
        })
        result = parser.parse(str(xlsx_path))
        # NOTE(review): the original assertions had lost their backslashes
        # (asserting a bare "|" is trivially true for any Markdown table, and
        # the raw "value|with|pipes" is exactly what escaping must prevent).
        # Markdown pipe escaping is "\|"; assert the escaped form — confirm
        # against XlsxParser's escaping implementation.
        assert "\\|" in result
        assert "value\\|with\\|pipes" in result

    def test_newline_escaped(self, parser, tmp_path):
        """A newline inside a cell is rendered as <br>."""
        xlsx_path = tmp_path / "newline.xlsx"
        _create_xlsx(xlsx_path, {
            "Sheet1": [["Header"], ["line1\nline2"]],
        })
        result = parser.parse(str(xlsx_path))
        assert "line1<br>line2" in result

    def test_backtick_escaped(self, parser, tmp_path):
        """A backtick inside a cell must be escaped as '\\`'."""
        xlsx_path = tmp_path / "backtick.xlsx"
        _create_xlsx(xlsx_path, {
            "Sheet1": [["Header"], ["code `snippet`"]],
        })
        result = parser.parse(str(xlsx_path))
        # NOTE(review): original assertion lost its backslash (a bare "`"
        # is present in the input text anyway); assert the escaped form —
        # confirm against XlsxParser's escaping implementation.
        assert "\\`" in result

    def test_none_cell_becomes_empty(self, parser, tmp_path):
        """A None cell is rendered as an empty table cell."""
        xlsx_path = tmp_path / "none.xlsx"
        _create_xlsx(xlsx_path, {
            "Sheet1": [["A", "B"], ["val", None]],
        })
        result = parser.parse(str(xlsx_path))
        assert "| val |  |" in result

    def test_merged_cells(self, parser, tmp_path):
        """Cells inside a merged range take the top-left cell's value."""
        xlsx_path = tmp_path / "merged.xlsx"
        _create_xlsx_with_merge(
            xlsx_path,
            sheet_name="Data",
            rows=[
                ["Category", "Value"],
                ["Fruit", 10],
                [None, 20],  # A3 will be merged with A2
            ],
            merges=["A2:A3"],
        )
        result = parser.parse(str(xlsx_path))
        assert "## Data" in result
        # The merged cell (A3) should carry the value from A2 ("Fruit").
        lines = result.split("\n")
        data_lines = [
            line for line in lines
            if line.startswith("| ") and "---" not in line and "Category" not in line
        ]
        assert len(data_lines) == 2
        # Both data rows should contain "Fruit".
        assert all("Fruit" in line for line in data_lines)

    def test_sheet_name_as_heading(self, parser, tmp_path):
        """The worksheet name becomes a '##' level heading."""
        xlsx_path = tmp_path / "named.xlsx"
        _create_xlsx(xlsx_path, {
            "Sales Report": [["Month", "Revenue"], ["Jan", "1000"]],
        })
        result = parser.parse(str(xlsx_path))
        assert "## Sales Report" in result

    def test_nonexistent_file_raises(self, parser):
        """Parsing a missing file raises ParseError with file name and reason."""
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/nonexistent/path/file.xlsx")
        assert "file.xlsx" in exc_info.value.file_name
        assert exc_info.value.reason != ""

    def test_corrupted_file_raises(self, parser, tmp_path):
        """Parsing a non-XLSX payload raises ParseError."""
        xlsx_path = tmp_path / "corrupted.xlsx"
        xlsx_path.write_bytes(b"this is not an xlsx file")
        with pytest.raises(ParseError) as exc_info:
            parser.parse(str(xlsx_path))
        assert "corrupted.xlsx" in exc_info.value.file_name

    def test_parse_error_contains_filename(self, parser):
        """ParseError.file_name is the basename of the failed path."""
        with pytest.raises(ParseError) as exc_info:
            parser.parse("/no/such/report.xlsx")
        assert exc_info.value.file_name == "report.xlsx"

    def test_numeric_values(self, parser, tmp_path):
        """Numeric cell values are stringified correctly."""
        xlsx_path = tmp_path / "numeric.xlsx"
        _create_xlsx(xlsx_path, {
            "Sheet1": [["Int", "Float"], [42, 3.14]],
        })
        result = parser.parse(str(xlsx_path))
        assert "42" in result
        assert "3.14" in result

    def test_crlf_escaped(self, parser, tmp_path):
        """\\r\\n line endings are rendered as a single <br>."""
        xlsx_path = tmp_path / "crlf.xlsx"
        _create_xlsx(xlsx_path, {
            "Sheet1": [["Header"], ["line1\r\nline2"]],
        })
        result = parser.parse(str(xlsx_path))
        assert "line1<br>line2" in result
|
||||
Reference in New Issue
Block a user