Initial commit: AI 知识库文档智能分块工具

This commit is contained in:
AI Knowledge Splitter
2026-03-02 17:38:28 +08:00
commit 92e7fc5bda
160 changed files with 9577 additions and 0 deletions

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
# tests package

251
tests/test_api_client.py Normal file
View File

@@ -0,0 +1,251 @@
"""ApiClient 单元测试"""
import pytest
from unittest.mock import MagicMock
import openai
from api_client import ApiClient
from exceptions import ApiError
def _make_completion_response(content: str):
"""构造模拟的 ChatCompletion 响应"""
message = MagicMock()
message.content = content
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
return response
def _make_rate_limit_error():
"""构造 openai.RateLimitError"""
return openai.RateLimitError(
message="Rate limit exceeded",
response=MagicMock(status_code=429),
body=None,
)
def _make_api_error(status_code=500, message="Internal server error"):
"""构造非速率限制的 openai.APIStatusError"""
return openai.APIStatusError(
message=message,
response=MagicMock(status_code=status_code),
body=None,
)
def _make_client(**kwargs):
"""创建注入 mock OpenAI client 的 ApiClient"""
mock_openai = MagicMock()
sleep_fn = kwargs.get("sleep_fn", MagicMock())
return ApiClient(
api_key="test-key",
_client=mock_openai,
_sleep=sleep_fn,
), mock_openai, sleep_fn
class TestApiClientChat:
"""chat() 方法测试"""
def test_successful_chat(self):
"""成功调用 chat 返回内容"""
client, mock_openai, sleep_fn = _make_client()
expected = "这是 AI 的回复"
mock_openai.chat.completions.create.return_value = _make_completion_response(expected)
result = client.chat("你是助手", "你好")
assert result == expected
mock_openai.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=[
{"role": "system", "content": "你是助手"},
{"role": "user", "content": "你好"},
],
)
sleep_fn.assert_not_called()
def test_chat_custom_model(self):
"""chat 支持自定义模型"""
client, mock_openai, _ = _make_client()
mock_openai.chat.completions.create.return_value = _make_completion_response("ok")
client.chat("sys", "user", model="deepseek-reasoner")
mock_openai.chat.completions.create.assert_called_once_with(
model="deepseek-reasoner",
messages=[
{"role": "system", "content": "sys"},
{"role": "user", "content": "user"},
],
)
def test_chat_retry_on_429_then_success(self):
"""chat 遇到 429 后重试成功"""
client, mock_openai, sleep_fn = _make_client()
mock_openai.chat.completions.create.side_effect = [
_make_rate_limit_error(),
_make_rate_limit_error(),
_make_completion_response("成功"),
]
result = client.chat("sys", "user")
assert result == "成功"
assert sleep_fn.call_count == 2
sleep_fn.assert_any_call(1)
sleep_fn.assert_any_call(2)
def test_chat_retry_exhausted_raises_api_error(self):
"""chat 重试耗尽抛出 ApiError"""
client, mock_openai, sleep_fn = _make_client()
mock_openai.chat.completions.create.side_effect = [
_make_rate_limit_error(),
_make_rate_limit_error(),
_make_rate_limit_error(),
_make_rate_limit_error(),
]
with pytest.raises(ApiError, match="速率限制重试耗尽") as exc_info:
client.chat("sys", "user")
assert exc_info.value.status_code == 429
assert sleep_fn.call_count == 3
sleep_fn.assert_any_call(1)
sleep_fn.assert_any_call(2)
sleep_fn.assert_any_call(4)
def test_chat_non_429_error_raises_immediately(self):
"""chat 遇到非 429 错误立即抛出 ApiError不重试"""
client, mock_openai, sleep_fn = _make_client()
mock_openai.chat.completions.create.side_effect = _make_api_error(500)
with pytest.raises(ApiError) as exc_info:
client.chat("sys", "user")
assert exc_info.value.status_code == 500
sleep_fn.assert_not_called()
class TestApiClientVision:
"""vision() 方法测试"""
def test_successful_vision(self):
"""成功调用 vision 返回内容"""
client, mock_openai, sleep_fn = _make_client()
expected = "图片中包含一段文字"
mock_openai.chat.completions.create.return_value = _make_completion_response(expected)
result = client.vision("识别图片", "aW1hZ2VfZGF0YQ==")
assert result == expected
mock_openai.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=[
{"role": "system", "content": "识别图片"},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,aW1hZ2VfZGF0YQ==",
},
},
],
},
],
)
sleep_fn.assert_not_called()
def test_vision_retry_on_429_then_success(self):
"""vision 遇到 429 后重试成功"""
client, mock_openai, sleep_fn = _make_client()
mock_openai.chat.completions.create.side_effect = [
_make_rate_limit_error(),
_make_completion_response("识别结果"),
]
result = client.vision("sys", "base64data")
assert result == "识别结果"
assert sleep_fn.call_count == 1
sleep_fn.assert_called_with(1)
def test_vision_retry_exhausted_raises_api_error(self):
"""vision 重试耗尽抛出 ApiError"""
client, mock_openai, sleep_fn = _make_client()
mock_openai.chat.completions.create.side_effect = [
_make_rate_limit_error(),
_make_rate_limit_error(),
_make_rate_limit_error(),
_make_rate_limit_error(),
]
with pytest.raises(ApiError, match="速率限制重试耗尽"):
client.vision("sys", "base64data")
assert sleep_fn.call_count == 3
def test_vision_non_429_error_raises_immediately(self):
"""vision 遇到非 429 错误立即抛出"""
client, mock_openai, sleep_fn = _make_client()
mock_openai.chat.completions.create.side_effect = _make_api_error(401, "Unauthorized")
with pytest.raises(ApiError) as exc_info:
client.vision("sys", "base64data")
assert exc_info.value.status_code == 401
sleep_fn.assert_not_called()
class TestRetryDelays:
"""重试延迟验证"""
def test_retry_delays_are_exponential(self):
"""验证重试延迟为 1, 2, 4 秒"""
assert ApiClient.RETRY_DELAYS == [1, 2, 4]
assert ApiClient.MAX_RETRIES == 3
def test_single_retry_uses_correct_delay(self):
"""单次 429 后使用 1 秒延迟"""
client, mock_openai, sleep_fn = _make_client()
mock_openai.chat.completions.create.side_effect = [
_make_rate_limit_error(),
_make_completion_response("ok"),
]
client.chat("sys", "user")
sleep_fn.assert_called_once_with(1)
def test_three_retries_use_correct_delays(self):
"""三次 429 后使用 1, 2, 4 秒延迟"""
client, mock_openai, sleep_fn = _make_client()
mock_openai.chat.completions.create.side_effect = [
_make_rate_limit_error(),
_make_rate_limit_error(),
_make_rate_limit_error(),
_make_completion_response("ok"),
]
result = client.chat("sys", "user")
assert result == "ok"
assert sleep_fn.call_count == 3
calls = [c.args[0] for c in sleep_fn.call_args_list]
assert calls == [1, 2, 4]

317
tests/test_chunker.py Normal file
View File

@@ -0,0 +1,317 @@
"""AIChunker 单元测试"""
import pytest
from unittest.mock import MagicMock, call
from chunker import AIChunker
from exceptions import ApiError
from models import Chunk
def _make_chunker(api_response="标题1\n\n内容1", delimiter="---"):
"""创建注入 mock ApiClient 的 AIChunker"""
mock_api = MagicMock()
mock_api.chat.return_value = api_response
chunker = AIChunker(api_client=mock_api, delimiter=delimiter)
return chunker, mock_api
class TestParseResponse:
"""_parse_response() 方法测试"""
def test_single_chunk(self):
"""解析单个分块"""
chunker, _ = _make_chunker()
result = chunker._parse_response("摘要标题\n\n这是分块内容")
assert len(result) == 1
assert result[0].title == "摘要标题"
assert result[0].content == "这是分块内容"
def test_multiple_chunks(self):
"""解析多个分块(用 delimiter 分隔)"""
chunker, _ = _make_chunker()
response = "标题一\n\n内容一\n\n---\n标题二\n\n内容二"
result = chunker._parse_response(response)
assert len(result) == 2
assert result[0].title == "标题一"
assert result[0].content == "内容一"
assert result[1].title == "标题二"
assert result[1].content == "内容二"
def test_skip_empty_parts(self):
"""跳过空片段"""
chunker, _ = _make_chunker()
response = "标题\n\n内容\n\n---\n\n---\n"
result = chunker._parse_response(response)
assert len(result) == 1
assert result[0].title == "标题"
def test_title_only_no_content(self):
"""只有标题没有内容的分块"""
chunker, _ = _make_chunker()
result = chunker._parse_response("仅标题")
assert len(result) == 1
assert result[0].title == "仅标题"
assert result[0].content == ""
def test_empty_response_raises_error(self):
"""空响应抛出 ApiError"""
chunker, _ = _make_chunker()
with pytest.raises(ApiError, match="API 返回空响应"):
chunker._parse_response("")
def test_whitespace_only_response_raises_error(self):
"""纯空白响应抛出 ApiError"""
chunker, _ = _make_chunker()
with pytest.raises(ApiError, match="API 返回空响应"):
chunker._parse_response(" \n\n ")
def test_custom_delimiter(self):
"""使用自定义分隔符解析"""
chunker, _ = _make_chunker(delimiter="===")
response = "标题A\n\n内容A\n\n===\n标题B\n\n内容B"
result = chunker._parse_response(response)
assert len(result) == 2
assert result[0].title == "标题A"
assert result[1].title == "标题B"
def test_strips_whitespace_from_parts(self):
"""去除片段首尾空白"""
chunker, _ = _make_chunker()
response = " \n标题\n\n内容\n "
result = chunker._parse_response(response)
assert len(result) == 1
assert result[0].title == "标题"
assert result[0].content == "内容"
class TestPreSplit:
"""_pre_split() 方法测试"""
def test_short_text_single_segment(self):
"""短文本不需要切分,返回单段"""
chunker, _ = _make_chunker()
text = "短文本内容"
result = chunker._pre_split(text)
assert len(result) == 1
assert result[0] == text
def test_split_on_paragraph_boundary(self):
"""在段落边界(双换行)处切分"""
chunker, _ = _make_chunker()
para1 = "a" * 7000
para2 = "b" * 7000
text = f"{para1}\n\n{para2}"
result = chunker._pre_split(text)
assert len(result) == 2
assert result[0] == para1
assert result[1] == para2
def test_greedy_merge_paragraphs(self):
"""贪心合并段落,不超过 PRE_SPLIT_SIZE"""
chunker, _ = _make_chunker()
para1 = "a" * 4000
para2 = "b" * 4000
para3 = "c" * 5000
text = f"{para1}\n\n{para2}\n\n{para3}"
result = chunker._pre_split(text)
# para1 + \n\n + para2 = 8002 <= 12000, so they merge
# adding para3 would be 8002 + 2 + 5000 = 13004 > 12000
assert len(result) == 2
assert result[0] == f"{para1}\n\n{para2}"
assert result[1] == para3
def test_single_paragraph_exceeds_limit_split_by_newline(self):
"""单段落超限时按单换行符切分"""
chunker, _ = _make_chunker()
line1 = "x" * 7000
line2 = "y" * 7000
# 单个段落(无双换行),但有单换行
text = f"{line1}\n{line2}"
result = chunker._pre_split(text)
assert len(result) == 2
assert result[0] == line1
assert result[1] == line2
def test_hard_split_very_long_line(self):
"""超长单行硬切分"""
chunker, _ = _make_chunker()
# 一行超过 PRE_SPLIT_SIZE 且无段落/换行分隔
text = "a" * 30000
result = chunker._pre_split(text)
assert len(result) >= 2
for seg in result:
assert len(seg) <= chunker.PRE_SPLIT_SIZE
# 拼接后内容不丢失
assert "".join(result) == text
def test_no_content_loss(self):
"""预切分后拼接不丢失内容"""
chunker, _ = _make_chunker()
para1 = "a" * 5000
para2 = "b" * 5000
para3 = "c" * 5000
text = f"{para1}\n\n{para2}\n\n{para3}"
result = chunker._pre_split(text)
joined = "\n\n".join(result)
assert para1 in joined
assert para2 in joined
assert para3 in joined
def test_each_segment_within_limit(self):
"""每段不超过 PRE_SPLIT_SIZE"""
chunker, _ = _make_chunker()
paragraphs = ["p" * 5000 for _ in range(10)]
text = "\n\n".join(paragraphs)
result = chunker._pre_split(text)
for seg in result:
assert len(seg) <= chunker.PRE_SPLIT_SIZE
class TestCallApi:
"""_call_api() 方法测试"""
def test_calls_api_with_correct_prompts(self):
"""调用 API 时使用正确的提示词"""
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
chunker._call_api("测试文本")
mock_api.chat.assert_called_once()
args = mock_api.chat.call_args
system_prompt = args[0][0] if args[0] else args[1].get("system_prompt")
user_content = args[0][1] if len(args[0]) > 1 else args[1].get("user_content")
# 系统提示词应包含 delimiter
assert "---" in system_prompt
# 用户提示词应包含文本内容
assert "测试文本" in user_content
def test_returns_parsed_chunks(self):
"""返回解析后的 Chunk 列表"""
chunker, _ = _make_chunker(api_response="标题\n\n内容")
result = chunker._call_api("文本")
assert len(result) == 1
assert isinstance(result[0], Chunk)
assert result[0].title == "标题"
def test_api_error_propagates(self):
"""API 错误向上传播"""
chunker, mock_api = _make_chunker()
mock_api.chat.side_effect = ApiError("调用失败")
with pytest.raises(ApiError, match="调用失败"):
chunker._call_api("文本")
class TestChunk:
"""chunk() 方法测试"""
def test_short_text_single_api_call(self):
"""短文本(≤ PRE_SPLIT_SIZE只调用一次 API"""
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
text = "短文本" * 100 # well under 12000
result = chunker.chunk(text)
assert len(result) == 1
assert mock_api.chat.call_count == 1
def test_long_text_multiple_api_calls(self):
"""长文本(> PRE_SPLIT_SIZE预切分后多次调用 API"""
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
# 创建超过 PRE_SPLIT_SIZE 的文本
text = ("段落内容" * 2000 + "\n\n") * 5
result = chunker.chunk(text)
assert mock_api.chat.call_count > 1
assert len(result) >= 1
def test_progress_callback_called(self):
"""长文本时 on_progress 回调被正确调用"""
chunker, mock_api = _make_chunker(api_response="标题\n\n内容")
# 构造需要 pre_split 的长文本
para1 = "a" * 7000
para2 = "b" * 7000
para3 = "c" * 7000
text = f"{para1}\n\n{para2}\n\n{para3}"
progress_calls = []
def on_progress(current, total):
progress_calls.append((current, total))
chunker.chunk(text, on_progress=on_progress)
# 验证 progress 回调
assert len(progress_calls) > 0
total = progress_calls[0][1]
for i, (current, t) in enumerate(progress_calls):
assert current == i + 1
assert t == total
def test_no_progress_callback_no_error(self):
"""不传 on_progress 不报错"""
chunker, _ = _make_chunker(api_response="标题\n\n内容")
text = "短文本"
result = chunker.chunk(text)
assert len(result) == 1
def test_short_text_no_progress_callback(self):
"""短文本不触发 on_progress 回调"""
chunker, _ = _make_chunker(api_response="标题\n\n内容")
progress_calls = []
chunker.chunk("短文本", on_progress=lambda c, t: progress_calls.append((c, t)))
assert len(progress_calls) == 0
def test_chunks_aggregated_from_segments(self):
"""多段的 chunks 被正确聚合"""
mock_api = MagicMock()
# 每次 API 调用返回不同的响应
mock_api.chat.side_effect = [
"标题A\n\n内容A",
"标题B\n\n内容B",
]
chunker = AIChunker(api_client=mock_api, delimiter="---")
# 构造需要 2 段的文本(每段 > 12000/2 使得合并后超限)
text = "a" * 7000 + "\n\n" + "b" * 7000
result = chunker.chunk(text)
assert len(result) == 2
assert result[0].title == "标题A"
assert result[1].title == "标题B"
class TestHardSplit:
"""_hard_split() 和 _find_sentence_boundary() 测试"""
def test_hard_split_preserves_content(self):
"""硬切分不丢失内容"""
chunker, _ = _make_chunker()
text = "x" * 30000
result = chunker._hard_split(text)
assert "".join(result) == text
def test_hard_split_respects_limit(self):
"""硬切分每段不超过 PRE_SPLIT_SIZE"""
chunker, _ = _make_chunker()
text = "x" * 30000
result = chunker._hard_split(text)
for seg in result:
assert len(seg) <= chunker.PRE_SPLIT_SIZE
def test_sentence_boundary_split(self):
"""硬切分优先在句子边界切分"""
chunker, _ = _make_chunker()
# 构造在中间有句号的超长文本,总长超过 12000
text = "a" * 9000 + "" + "b" * 5000
result = chunker._hard_split(text)
assert len(result) == 2
# 第一段应在句号后切分
assert result[0].endswith("")
def test_find_sentence_boundary_chinese(self):
"""中文句号作为句子边界"""
boundary = AIChunker._find_sentence_boundary("你好世界。再见")
assert boundary == 5 # "你好世界。" 的长度
def test_find_sentence_boundary_english(self):
"""英文句号作为句子边界"""
boundary = AIChunker._find_sentence_boundary("Hello world. Bye")
assert boundary == 12 # index of '.' is 11, return 11+1=12

105
tests/test_csv_parser.py Normal file
View File

@@ -0,0 +1,105 @@
"""CsvParser 单元测试"""
import pytest
from exceptions import ParseError
from parsers.csv_parser import CsvParser
@pytest.fixture
def parser():
return CsvParser()
class TestSupportedExtensions:
def test_supports_csv(self, parser):
assert ".csv" in parser.supported_extensions()
def test_only_one_extension(self, parser):
assert len(parser.supported_extensions()) == 1
class TestParse:
def test_basic_csv(self, parser, tmp_path):
f = tmp_path / "basic.csv"
f.write_text("name,age,city\nAlice,30,Beijing\nBob,25,Shanghai\n", encoding="utf-8")
result = parser.parse(str(f))
assert "| name | age | city |" in result
assert "| --- | --- | --- |" in result
assert "| Alice | 30 | Beijing |" in result
assert "| Bob | 25 | Shanghai |" in result
def test_empty_file(self, parser, tmp_path):
f = tmp_path / "empty.csv"
f.write_bytes(b"")
assert parser.parse(str(f)) == ""
def test_header_only(self, parser, tmp_path):
f = tmp_path / "header.csv"
f.write_text("col1,col2,col3\n", encoding="utf-8")
result = parser.parse(str(f))
assert "| col1 | col2 | col3 |" in result
assert "| --- | --- | --- |" in result
lines = result.strip().split("\n")
assert len(lines) == 2
def test_pipe_char_escaped(self, parser, tmp_path):
f = tmp_path / "pipe.csv"
f.write_text('header\n"a|b"\n', encoding="utf-8")
result = parser.parse(str(f))
assert "&#124;" in result
assert "a&#124;b" in result
def test_newline_in_cell(self, parser, tmp_path):
f = tmp_path / "newline.csv"
f.write_text('header\n"line1\nline2"\n', encoding="utf-8")
result = parser.parse(str(f))
assert "<br>" in result
assert "line1<br>line2" in result
def test_gbk_encoded_csv(self, parser, tmp_path):
f = tmp_path / "gbk.csv"
content = "姓名,年龄,城市\n张三,28,北京\n李四,32,上海\n"
f.write_bytes(content.encode("gbk"))
result = parser.parse(str(f))
assert "张三" in result
assert "北京" in result
def test_nonexistent_file_raises(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/nonexistent/path/data.csv")
assert "data.csv" in exc_info.value.file_name
assert exc_info.value.reason != ""
def test_short_row_padded(self, parser, tmp_path):
"""Rows shorter than header should be padded with empty cells."""
f = tmp_path / "short.csv"
f.write_text("a,b,c\n1\n", encoding="utf-8")
result = parser.parse(str(f))
assert "| 1 | | |" in result
def test_result_ends_with_newline(self, parser, tmp_path):
f = tmp_path / "trail.csv"
f.write_text("h1,h2\nv1,v2\n", encoding="utf-8")
result = parser.parse(str(f))
assert result.endswith("\n")
class TestEscapeCell:
def test_no_special_chars(self):
assert CsvParser._escape_cell("hello") == "hello"
def test_pipe_escaped(self):
assert CsvParser._escape_cell("a|b") == "a&#124;b"
def test_newline_escaped(self):
assert CsvParser._escape_cell("a\nb") == "a<br>b"
def test_crlf_escaped(self):
assert CsvParser._escape_cell("a\r\nb") == "a<br>b"
def test_cr_escaped(self):
assert CsvParser._escape_cell("a\rb") == "a<br>b"
def test_combined_escapes(self):
assert CsvParser._escape_cell("a|b\nc") == "a&#124;b<br>c"

260
tests/test_doc_parser.py Normal file
View File

@@ -0,0 +1,260 @@
"""DocParser 单元测试"""
import pytest
from docx import Document
from docx.shared import Pt
from docx.enum.text import WD_ALIGN_PARAGRAPH
from exceptions import ParseError
from parsers.doc_parser import DocParser
@pytest.fixture
def parser():
return DocParser()
def _create_docx(path, paragraphs=None, tables=None):
"""
创建测试用 Word 文档。
Args:
path: 输出文件路径
paragraphs: 列表,每个元素是 dict:
- text: 段落文本
- style: 可选,样式名(如 'Heading 1'
- font_size: 可选,字体大小 (Pt)
- bold: 可选,是否加粗
tables: 列表,每个元素是二维列表(行×列的文本)
"""
doc = Document()
# 清除默认的空段落
for p in doc.paragraphs:
p._element.getparent().remove(p._element)
if paragraphs:
for para_info in paragraphs:
if isinstance(para_info, str):
doc.add_paragraph(para_info)
else:
text = para_info.get("text", "")
style = para_info.get("style", None)
font_size = para_info.get("font_size", None)
bold = para_info.get("bold", None)
if style:
p = doc.add_paragraph(text, style=style)
else:
p = doc.add_paragraph(text)
if font_size is not None or bold is not None:
# 需要通过 run 设置字体属性
# 清除默认 run重新添加
for run in p.runs:
if font_size is not None:
run.font.size = Pt(font_size)
if bold is not None:
run.bold = bold
if tables:
for table_data in tables:
if not table_data:
continue
rows = len(table_data)
cols = len(table_data[0]) if table_data else 0
table = doc.add_table(rows=rows, cols=cols)
for i, row_data in enumerate(table_data):
for j, cell_text in enumerate(row_data):
table.rows[i].cells[j].text = cell_text
doc.save(str(path))
class TestSupportedExtensions:
def test_supports_docx(self, parser):
assert ".docx" in parser.supported_extensions()
def test_only_one_extension(self, parser):
assert len(parser.supported_extensions()) == 1
class TestParse:
def test_parse_simple_text(self, parser, tmp_path):
docx_path = tmp_path / "simple.docx"
_create_docx(docx_path, paragraphs=["Hello, world!"])
result = parser.parse(str(docx_path))
assert "Hello, world!" in result
def test_parse_multiple_paragraphs(self, parser, tmp_path):
docx_path = tmp_path / "multi.docx"
_create_docx(docx_path, paragraphs=["First paragraph", "Second paragraph"])
result = parser.parse(str(docx_path))
assert "First paragraph" in result
assert "Second paragraph" in result
def test_heading_by_style_name(self, parser, tmp_path):
"""Heading style should produce Markdown heading"""
docx_path = tmp_path / "heading.docx"
_create_docx(docx_path, paragraphs=[
{"text": "Main Title", "style": "Heading 1"},
{"text": "Body text"},
])
result = parser.parse(str(docx_path))
assert "# Main Title" in result
# Should be exactly H1, not H2
assert "## Main Title" not in result
def test_heading2_by_style_name(self, parser, tmp_path):
docx_path = tmp_path / "h2.docx"
_create_docx(docx_path, paragraphs=[
{"text": "Section Title", "style": "Heading 2"},
{"text": "Some content"},
])
result = parser.parse(str(docx_path))
assert "## Section Title" in result
assert "### Section Title" not in result
def test_heading3_by_style_name(self, parser, tmp_path):
docx_path = tmp_path / "h3.docx"
_create_docx(docx_path, paragraphs=[
{"text": "Subsection", "style": "Heading 3"},
])
result = parser.parse(str(docx_path))
assert "### Subsection" in result
def test_heading_by_font_size_bold(self, parser, tmp_path):
"""Bold text with large font size should be detected as heading"""
docx_path = tmp_path / "font_heading.docx"
_create_docx(docx_path, paragraphs=[
{"text": "Big Bold Title", "font_size": 36, "bold": True},
{"text": "Normal text"},
])
result = parser.parse(str(docx_path))
assert "# Big Bold Title" in result
def test_heading_h2_by_font_size(self, parser, tmp_path):
docx_path = tmp_path / "font_h2.docx"
_create_docx(docx_path, paragraphs=[
{"text": "H2 Title", "font_size": 28, "bold": True},
{"text": "Normal text"},
])
result = parser.parse(str(docx_path))
assert "## H2 Title" in result
def test_heading_h5_by_font_size(self, parser, tmp_path):
docx_path = tmp_path / "font_h5.docx"
_create_docx(docx_path, paragraphs=[
{"text": "H5 Title", "font_size": 20, "bold": True},
{"text": "Normal text"},
])
result = parser.parse(str(docx_path))
assert "##### H5 Title" in result
def test_no_heading_without_bold(self, parser, tmp_path):
"""Large font without bold should NOT be detected as heading via font size"""
docx_path = tmp_path / "no_bold.docx"
_create_docx(docx_path, paragraphs=[
{"text": "Large Not Bold", "font_size": 36, "bold": False},
])
result = parser.parse(str(docx_path))
assert "# Large Not Bold" not in result
assert "Large Not Bold" in result
def test_simple_table(self, parser, tmp_path):
docx_path = tmp_path / "table.docx"
_create_docx(docx_path, tables=[
[["Name", "Age"], ["Alice", "30"], ["Bob", "25"]],
])
result = parser.parse(str(docx_path))
assert "| Name | Age |" in result
assert "| --- | --- |" in result
assert "| Alice | 30 |" in result
assert "| Bob | 25 |" in result
def test_table_with_pipe_in_cell(self, parser, tmp_path):
"""Pipe characters in cells should be escaped"""
docx_path = tmp_path / "pipe.docx"
_create_docx(docx_path, tables=[
[["Header"], ["value|with|pipes"]],
])
result = parser.parse(str(docx_path))
assert "&#124;" in result
assert "value&#124;with&#124;pipes" in result
def test_mixed_paragraphs_and_tables(self, parser, tmp_path):
"""Document with both paragraphs and tables"""
docx_path = tmp_path / "mixed.docx"
doc = Document()
# Clear default paragraph
for p in doc.paragraphs:
p._element.getparent().remove(p._element)
doc.add_paragraph("Introduction", style="Heading 1")
doc.add_paragraph("Some intro text.")
table = doc.add_table(rows=2, cols=2)
table.rows[0].cells[0].text = "Col1"
table.rows[0].cells[1].text = "Col2"
table.rows[1].cells[0].text = "A"
table.rows[1].cells[1].text = "B"
doc.add_paragraph("Conclusion")
doc.save(str(docx_path))
result = parser.parse(str(docx_path))
assert "# Introduction" in result
assert "Some intro text." in result
assert "| Col1 | Col2 |" in result
assert "| A | B |" in result
assert "Conclusion" in result
def test_empty_document(self, parser, tmp_path):
docx_path = tmp_path / "empty.docx"
doc = Document()
# Clear default paragraph
for p in doc.paragraphs:
p._element.getparent().remove(p._element)
doc.save(str(docx_path))
result = parser.parse(str(docx_path))
assert result.strip() == ""
def test_empty_paragraphs_skipped(self, parser, tmp_path):
docx_path = tmp_path / "empty_para.docx"
_create_docx(docx_path, paragraphs=["", "Actual content", ""])
result = parser.parse(str(docx_path))
assert "Actual content" in result
# Empty paragraphs should not produce extra lines
assert result.strip() == "Actual content"
def test_nonexistent_file_raises(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/nonexistent/path/file.docx")
assert "file.docx" in exc_info.value.file_name
assert exc_info.value.reason != ""
def test_corrupted_file_raises(self, parser, tmp_path):
docx_path = tmp_path / "corrupted.docx"
docx_path.write_bytes(b"this is not a docx file at all")
with pytest.raises(ParseError) as exc_info:
parser.parse(str(docx_path))
assert "corrupted.docx" in exc_info.value.file_name
def test_parse_error_contains_filename(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/no/such/report.docx")
assert exc_info.value.file_name == "report.docx"
def test_multiple_heading_levels(self, parser, tmp_path):
"""Test document with multiple heading levels via styles"""
docx_path = tmp_path / "levels.docx"
_create_docx(docx_path, paragraphs=[
{"text": "Title", "style": "Heading 1"},
{"text": "Chapter", "style": "Heading 2"},
{"text": "Section", "style": "Heading 3"},
{"text": "Body text"},
])
result = parser.parse(str(docx_path))
assert "# Title" in result
assert "## Chapter" in result
assert "### Section" in result
assert "Body text" in result
# Body text should not have heading prefix
assert "# Body text" not in result

64
tests/test_exceptions.py Normal file
View File

@@ -0,0 +1,64 @@
"""异常类型单元测试"""
import pytest
from exceptions import ApiError, ParseError, RateLimitError, UnsupportedFormatError
class TestParseError:
def test_attributes(self):
err = ParseError("test.pdf", "文件损坏")
assert err.file_name == "test.pdf"
assert err.reason == "文件损坏"
def test_message_format(self):
err = ParseError("data.csv", "编码无法识别")
assert str(err) == "解析失败 [data.csv]: 编码无法识别"
def test_is_exception(self):
err = ParseError("f.txt", "reason")
assert isinstance(err, Exception)
class TestUnsupportedFormatError:
def test_inherits_parse_error(self):
err = UnsupportedFormatError("file.xyz", ".xyz")
assert isinstance(err, ParseError)
def test_extension_attribute(self):
err = UnsupportedFormatError("file.abc", ".abc")
assert err.extension == ".abc"
def test_message_format(self):
err = UnsupportedFormatError("doc.bin", ".bin")
assert str(err) == "解析失败 [doc.bin]: 不支持的文件格式: .bin"
def test_file_name_propagated(self):
err = UnsupportedFormatError("my_file.xyz", ".xyz")
assert err.file_name == "my_file.xyz"
assert err.reason == "不支持的文件格式: .xyz"
class TestApiError:
def test_with_status_code(self):
err = ApiError("服务端错误", status_code=500)
assert err.status_code == 500
assert str(err) == "服务端错误"
def test_without_status_code(self):
err = ApiError("网络错误")
assert err.status_code is None
assert str(err) == "网络错误"
def test_is_exception(self):
assert isinstance(ApiError("msg"), Exception)
class TestRateLimitError:
def test_inherits_api_error(self):
err = RateLimitError("速率限制", status_code=429)
assert isinstance(err, ApiError)
def test_status_code(self):
err = RateLimitError("速率限制", status_code=429)
assert err.status_code == 429
assert str(err) == "速率限制"

145
tests/test_html_parser.py Normal file
View File

@@ -0,0 +1,145 @@
"""HtmlParser 单元测试"""
import pytest
from exceptions import ParseError
from parsers.html_parser import HtmlParser
@pytest.fixture
def parser():
return HtmlParser()
class TestSupportedExtensions:
def test_supports_html(self, parser):
assert ".html" in parser.supported_extensions()
def test_supports_htm(self, parser):
assert ".htm" in parser.supported_extensions()
def test_only_two_extensions(self, parser):
assert len(parser.supported_extensions()) == 2
class TestParse:
def test_parse_simple_html(self, parser, tmp_path):
f = tmp_path / "test.html"
f.write_text("<html><body><p>Hello, world!</p></body></html>", encoding="utf-8")
result = parser.parse(str(f))
assert "Hello, world!" in result
def test_parse_htm_extension(self, parser, tmp_path):
f = tmp_path / "test.htm"
f.write_text("<html><body><p>HTM file</p></body></html>", encoding="utf-8")
result = parser.parse(str(f))
assert "HTM file" in result
def test_parse_empty_file(self, parser, tmp_path):
f = tmp_path / "empty.html"
f.write_bytes(b"")
assert parser.parse(str(f)) == ""
def test_removes_script_tags(self, parser, tmp_path):
f = tmp_path / "script.html"
html = "<html><body><script>alert('xss');</script><p>Content</p></body></html>"
f.write_text(html, encoding="utf-8")
result = parser.parse(str(f))
assert "alert" not in result
assert "script" not in result.lower() or "Content" in result
assert "Content" in result
def test_removes_style_tags(self, parser, tmp_path):
f = tmp_path / "style.html"
html = "<html><head><style>body { color: red; }</style></head><body><p>Styled</p></body></html>"
f.write_text(html, encoding="utf-8")
result = parser.parse(str(f))
assert "color: red" not in result
assert "Styled" in result
def test_converts_headings_to_markdown(self, parser, tmp_path):
f = tmp_path / "headings.html"
html = "<html><body><h1>Title</h1><h2>Subtitle</h2><p>Text</p></body></html>"
f.write_text(html, encoding="utf-8")
result = parser.parse(str(f))
assert "# Title" in result
assert "## Subtitle" in result
def test_converts_links_to_markdown(self, parser, tmp_path):
f = tmp_path / "links.html"
html = '<html><body><a href="https://example.com">Example</a></body></html>'
f.write_text(html, encoding="utf-8")
result = parser.parse(str(f))
assert "Example" in result
assert "https://example.com" in result
def test_converts_lists_to_markdown(self, parser, tmp_path):
f = tmp_path / "lists.html"
html = "<html><body><ul><li>Item 1</li><li>Item 2</li></ul></body></html>"
f.write_text(html, encoding="utf-8")
result = parser.parse(str(f))
assert "Item 1" in result
assert "Item 2" in result
def test_meta_charset_detection(self, parser, tmp_path):
f = tmp_path / "charset.html"
html = '<html><head><meta charset="utf-8"></head><body><p>UTF-8 content</p></body></html>'
f.write_text(html, encoding="utf-8")
result = parser.parse(str(f))
assert "UTF-8 content" in result
def test_gbk_encoded_html_with_meta_charset(self, parser, tmp_path):
f = tmp_path / "gbk.html"
html = '<html><head><meta charset="gbk"></head><body><p>你好世界,这是中文内容测试</p></body></html>'
f.write_bytes(html.encode("gbk"))
result = parser.parse(str(f))
assert "你好世界" in result
def test_encoding_fallback_to_charset_normalizer(self, parser, tmp_path):
f = tmp_path / "no_meta.html"
html = "<html><body><p>Hello, this is a test with enough text for encoding detection to work properly.</p></body></html>"
f.write_bytes(html.encode("utf-8"))
result = parser.parse(str(f))
assert "Hello" in result
def test_nonexistent_file_raises(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/nonexistent/path/file.html")
assert "file.html" in exc_info.value.file_name
assert exc_info.value.reason != ""
def test_parse_error_contains_filename(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/no/such/mypage.html")
assert exc_info.value.file_name == "mypage.html"
def test_complex_html_removes_all_tags(self, parser, tmp_path):
f = tmp_path / "complex.html"
html = """<!DOCTYPE html>
<html>
<head>
<title>Test Page</title>
<style>.hidden { display: none; }</style>
<script>var x = 1;</script>
</head>
<body>
<div class="container">
<h1>Main Title</h1>
<p>Paragraph with <strong>bold</strong> and <em>italic</em> text.</p>
<script>console.log('inline script');</script>
<table>
<tr><th>Name</th><th>Value</th></tr>
<tr><td>A</td><td>1</td></tr>
</table>
</div>
</body>
</html>"""
f.write_text(html, encoding="utf-8")
result = parser.parse(str(f))
assert "Main Title" in result
assert "bold" in result.lower() or "**bold**" in result
assert "<script>" not in result
assert "<style>" not in result
assert "<div" not in result
assert "console.log" not in result
assert "var x" not in result

135
tests/test_image_parser.py Normal file
View File

@@ -0,0 +1,135 @@
"""ImageParser 单元测试"""
import base64
from unittest.mock import MagicMock
import pytest
from exceptions import ApiError, ParseError
from parsers.image_parser import ImageParser, DEFAULT_VISION_PROMPT
@pytest.fixture
def mock_api_client():
return MagicMock()
@pytest.fixture
def parser(mock_api_client):
return ImageParser(mock_api_client)
class TestSupportedExtensions:
def test_supports_png(self, parser):
assert ".png" in parser.supported_extensions()
def test_supports_jpg(self, parser):
assert ".jpg" in parser.supported_extensions()
def test_supports_jpeg(self, parser):
assert ".jpeg" in parser.supported_extensions()
def test_supports_bmp(self, parser):
assert ".bmp" in parser.supported_extensions()
def test_supports_gif(self, parser):
assert ".gif" in parser.supported_extensions()
def test_supports_webp(self, parser):
assert ".webp" in parser.supported_extensions()
def test_has_six_extensions(self, parser):
assert len(parser.supported_extensions()) == 6
class TestParse:
def test_successful_parse(self, mock_api_client, tmp_path):
"""成功解析图片文件,返回 Vision API 的文本描述"""
img = tmp_path / "photo.png"
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 20)
mock_api_client.vision.return_value = "图片中包含一段中文文字"
parser = ImageParser(mock_api_client)
result = parser.parse(str(img))
assert result == "图片中包含一段中文文字"
mock_api_client.vision.assert_called_once()
def test_base64_encoding_correctness(self, mock_api_client, tmp_path):
"""验证传递给 API 的 base64 编码与文件内容一致"""
raw_bytes = b"\x89PNG\r\n\x1a\nSOME_IMAGE_DATA"
img = tmp_path / "check.png"
img.write_bytes(raw_bytes)
mock_api_client.vision.return_value = "ok"
parser = ImageParser(mock_api_client)
parser.parse(str(img))
call_args = mock_api_client.vision.call_args
sent_base64 = call_args.kwargs.get("image_base64") or call_args[1].get("image_base64") or call_args[0][1]
assert base64.b64decode(sent_base64) == raw_bytes
def test_system_prompt_passed_to_api(self, mock_api_client, tmp_path):
"""验证使用了正确的系统提示词,且包含文件名上下文"""
img = tmp_path / "prompt.png"
img.write_bytes(b"\x00")
mock_api_client.vision.return_value = "text"
parser = ImageParser(mock_api_client)
parser.parse(str(img))
call_args = mock_api_client.vision.call_args
sent_prompt = call_args.kwargs.get("system_prompt") or call_args[0][0]
assert DEFAULT_VISION_PROMPT in sent_prompt
assert "prompt" in sent_prompt
def test_file_not_found_raises_parse_error(self, parser):
"""文件不存在时抛出 ParseError"""
with pytest.raises(ParseError) as exc_info:
parser.parse("/nonexistent/path/missing.png")
assert exc_info.value.file_name == "missing.png"
assert "文件读取失败" in exc_info.value.reason
def test_unreadable_file_raises_parse_error(self, mock_api_client, tmp_path):
"""文件无法读取时抛出 ParseError使用目录路径模拟不可读文件"""
dir_path = tmp_path / "fakefile.jpg"
dir_path.mkdir()
parser = ImageParser(mock_api_client)
with pytest.raises(ParseError) as exc_info:
parser.parse(str(dir_path))
assert exc_info.value.file_name == "fakefile.jpg"
assert "文件读取失败" in exc_info.value.reason
def test_api_error_raises_parse_error(self, mock_api_client, tmp_path):
"""API 调用失败时抛出 ParseError"""
img = tmp_path / "api_fail.png"
img.write_bytes(b"\x89PNG")
mock_api_client.vision.side_effect = ApiError("服务不可用", status_code=503)
parser = ImageParser(mock_api_client)
with pytest.raises(ParseError) as exc_info:
parser.parse(str(img))
assert exc_info.value.file_name == "api_fail.png"
assert "Vision API 调用失败" in exc_info.value.reason
def test_api_rate_limit_error_raises_parse_error(self, mock_api_client, tmp_path):
"""API 速率限制错误(经重试耗尽后)也被包装为 ParseError"""
img = tmp_path / "rate.png"
img.write_bytes(b"\x89PNG")
mock_api_client.vision.side_effect = ApiError("速率限制重试耗尽", status_code=429)
parser = ImageParser(mock_api_client)
with pytest.raises(ParseError) as exc_info:
parser.parse(str(img))
assert "Vision API 调用失败" in exc_info.value.reason
def test_parse_error_contains_filename_for_missing_file(self, parser):
"""ParseError 包含正确的文件名"""
with pytest.raises(ParseError) as exc_info:
parser.parse("/tmp/does_not_exist/myimage.jpeg")
assert exc_info.value.file_name == "myimage.jpeg"
assert exc_info.value.reason != ""

182
tests/test_main.py Normal file
View File

@@ -0,0 +1,182 @@
"""CLI 入口 main.py 单元测试"""
import os
import pytest
from unittest.mock import patch, MagicMock
from main import derive_output_path, build_parser, main
from exceptions import ParseError, UnsupportedFormatError, ApiError
class TestDeriveOutputPath:
"""默认输出路径推导测试"""
def test_pdf_to_md(self):
assert derive_output_path("report.pdf") == "report.md"
def test_xlsx_to_md(self):
assert derive_output_path("data.xlsx") == "data.md"
def test_with_directory(self):
assert derive_output_path("/home/user/docs/file.docx") == "/home/user/docs/file.md"
def test_txt_to_md(self):
assert derive_output_path("notes.txt") == "notes.md"
def test_no_extension(self):
assert derive_output_path("README") == "README.md"
def test_multiple_dots(self):
assert derive_output_path("my.report.v2.pdf") == "my.report.v2.md"
class TestBuildParser:
"""argparse 参数解析测试"""
def test_all_args(self):
parser = build_parser()
args = parser.parse_args(["input.pdf", "-k", "sk-abc", "-o", "out.md", "-d", "==="])
assert args.input_file == "input.pdf"
assert args.api_key == "sk-abc"
assert args.output == "out.md"
assert args.delimiter == "==="
def test_required_args_only(self):
parser = build_parser()
args = parser.parse_args(["input.pdf", "-k", "sk-abc"])
assert args.input_file == "input.pdf"
assert args.api_key == "sk-abc"
assert args.output is None
assert args.delimiter == "---"
def test_long_option_names(self):
parser = build_parser()
args = parser.parse_args(["input.pdf", "--api-key", "sk-abc", "--output", "out.md", "--delimiter", "***"])
assert args.api_key == "sk-abc"
assert args.output == "out.md"
assert args.delimiter == "***"
def test_missing_input_file(self):
parser = build_parser()
with pytest.raises(SystemExit) as exc_info:
parser.parse_args(["-k", "sk-abc"])
assert exc_info.value.code != 0
@patch.dict(os.environ, {}, clear=True)
def test_missing_api_key(self):
"""无 -k 且无环境变量时api_key 应为 None"""
parser = build_parser()
args = parser.parse_args(["input.pdf"])
assert args.api_key is None
class TestMainFunction:
"""main() 函数集成测试"""
@patch("main.Splitter")
def test_success(self, mock_splitter_cls):
mock_splitter = MagicMock()
mock_splitter_cls.return_value = mock_splitter
with patch("sys.argv", ["main.py", "input.pdf", "-k", "sk-abc"]):
main()
mock_splitter_cls.assert_called_once_with(
api_key="sk-abc", delimiter="---",
pre_split_size=None, vision_prompt=None, output_format="markdown",
)
mock_splitter.process.assert_called_once_with("input.pdf", "input.md")
@patch("main.Splitter")
def test_custom_output(self, mock_splitter_cls):
mock_splitter = MagicMock()
mock_splitter_cls.return_value = mock_splitter
with patch("sys.argv", ["main.py", "input.pdf", "-k", "sk-abc", "-o", "custom.md"]):
main()
mock_splitter.process.assert_called_once_with("input.pdf", "custom.md")
@patch("main.Splitter")
def test_custom_delimiter(self, mock_splitter_cls):
mock_splitter = MagicMock()
mock_splitter_cls.return_value = mock_splitter
with patch("sys.argv", ["main.py", "input.pdf", "-k", "sk-abc", "-d", "==="]):
main()
mock_splitter_cls.assert_called_once_with(
api_key="sk-abc", delimiter="===",
pre_split_size=None, vision_prompt=None, output_format="markdown",
)
@patch("main.Splitter")
def test_file_not_found_error(self, mock_splitter_cls, capsys):
mock_splitter = MagicMock()
mock_splitter_cls.return_value = mock_splitter
mock_splitter.process.side_effect = FileNotFoundError("输入文件不存在: missing.pdf")
with patch("sys.argv", ["main.py", "missing.pdf", "-k", "sk-abc"]):
with pytest.raises(SystemExit) as exc_info:
main()
assert exc_info.value.code == 1
captured = capsys.readouterr()
assert "missing.pdf" in captured.err
@patch("main.Splitter")
def test_unsupported_format_error(self, mock_splitter_cls, capsys):
mock_splitter = MagicMock()
mock_splitter_cls.return_value = mock_splitter
mock_splitter.process.side_effect = UnsupportedFormatError("file.xyz", ".xyz")
with patch("sys.argv", ["main.py", "file.xyz", "-k", "sk-abc"]):
with pytest.raises(SystemExit) as exc_info:
main()
assert exc_info.value.code == 1
captured = capsys.readouterr()
assert ".xyz" in captured.err
@patch("main.Splitter")
def test_parse_error(self, mock_splitter_cls, capsys):
mock_splitter = MagicMock()
mock_splitter_cls.return_value = mock_splitter
mock_splitter.process.side_effect = ParseError("bad.pdf", "文件损坏")
with patch("sys.argv", ["main.py", "bad.pdf", "-k", "sk-abc"]):
with pytest.raises(SystemExit) as exc_info:
main()
assert exc_info.value.code == 1
captured = capsys.readouterr()
assert "bad.pdf" in captured.err
@patch("main.Splitter")
def test_api_error(self, mock_splitter_cls, capsys):
mock_splitter = MagicMock()
mock_splitter_cls.return_value = mock_splitter
mock_splitter.process.side_effect = ApiError("认证失败", status_code=401)
with patch("sys.argv", ["main.py", "input.pdf", "-k", "bad-key"]):
with pytest.raises(SystemExit) as exc_info:
main()
assert exc_info.value.code == 1
captured = capsys.readouterr()
assert "API" in captured.err
@patch("main.Splitter")
def test_generic_exception(self, mock_splitter_cls, capsys):
mock_splitter = MagicMock()
mock_splitter_cls.return_value = mock_splitter
mock_splitter.process.side_effect = RuntimeError("意外错误")
with patch("sys.argv", ["main.py", "input.pdf", "-k", "sk-abc"]):
with pytest.raises(SystemExit) as exc_info:
main()
assert exc_info.value.code == 1
captured = capsys.readouterr()
assert "意外错误" in captured.err

57
tests/test_models.py Normal file
View File

@@ -0,0 +1,57 @@
"""核心数据结构单元测试"""
from datetime import datetime
from models import Chunk, CLIArgs, ProcessResult
class TestChunk:
def test_creation(self):
chunk = Chunk(title="概述", content="这是内容")
assert chunk.title == "概述"
assert chunk.content == "这是内容"
def test_equality(self):
a = Chunk(title="t", content="c")
b = Chunk(title="t", content="c")
assert a == b
class TestProcessResult:
def test_creation(self):
now = datetime.now()
chunks = [Chunk("t1", "c1"), Chunk("t2", "c2")]
result = ProcessResult(
source_file="input.pdf",
output_file="output.md",
chunks=chunks,
process_time=now,
total_chunks=2,
)
assert result.source_file == "input.pdf"
assert result.output_file == "output.md"
assert len(result.chunks) == 2
assert result.process_time == now
assert result.total_chunks == 2
class TestCLIArgs:
def test_required_fields(self):
args = CLIArgs(input_file="doc.pdf", api_key="sk-123")
assert args.input_file == "doc.pdf"
assert args.api_key == "sk-123"
def test_defaults(self):
args = CLIArgs(input_file="doc.pdf", api_key="sk-123")
assert args.output_file is None
assert args.delimiter == "---"
def test_custom_values(self):
args = CLIArgs(
input_file="doc.pdf",
api_key="sk-123",
output_file="out.md",
delimiter="***",
)
assert args.output_file == "out.md"
assert args.delimiter == "***"

View File

@@ -0,0 +1,82 @@
"""BaseParser 和 ParserRegistry 单元测试"""
import pytest
from typing import List
from exceptions import UnsupportedFormatError
from parsers.base import BaseParser, ParserRegistry
class StubParser(BaseParser):
"""用于测试的具体解析器实现"""
def __init__(self, extensions: List[str]):
self._extensions = extensions
def supported_extensions(self) -> List[str]:
return self._extensions
def parse(self, file_path: str) -> str:
return f"parsed: {file_path}"
class TestBaseParser:
def test_cannot_instantiate_directly(self):
with pytest.raises(TypeError):
BaseParser()
def test_concrete_subclass_works(self):
parser = StubParser([".txt"])
assert parser.supported_extensions() == [".txt"]
assert parser.parse("test.txt") == "parsed: test.txt"
class TestParserRegistry:
def test_empty_registry_raises(self):
registry = ParserRegistry()
with pytest.raises(UnsupportedFormatError):
registry.get_parser("file.pdf")
def test_register_and_get_parser(self):
registry = ParserRegistry()
pdf_parser = StubParser([".pdf"])
registry.register(pdf_parser)
assert registry.get_parser("document.pdf") is pdf_parser
def test_multiple_parsers(self):
registry = ParserRegistry()
pdf_parser = StubParser([".pdf"])
txt_parser = StubParser([".txt", ".md"])
registry.register(pdf_parser)
registry.register(txt_parser)
assert registry.get_parser("doc.pdf") is pdf_parser
assert registry.get_parser("readme.txt") is txt_parser
assert registry.get_parser("notes.md") is txt_parser
def test_unsupported_format_error_details(self):
registry = ParserRegistry()
registry.register(StubParser([".pdf"]))
with pytest.raises(UnsupportedFormatError) as exc_info:
registry.get_parser("file.xyz")
assert exc_info.value.extension == ".xyz"
assert exc_info.value.file_name == "file.xyz"
def test_case_insensitive_extension(self):
registry = ParserRegistry()
registry.register(StubParser([".pdf"]))
assert registry.get_parser("DOC.PDF") is not None
def test_file_path_with_directory(self):
registry = ParserRegistry()
parser = StubParser([".csv"])
registry.register(parser)
assert registry.get_parser("/home/user/data/report.csv") is parser
def test_first_matching_parser_wins(self):
registry = ParserRegistry()
first = StubParser([".txt"])
second = StubParser([".txt"])
registry.register(first)
registry.register(second)
assert registry.get_parser("file.txt") is first

159
tests/test_pdf_parser.py Normal file
View File

@@ -0,0 +1,159 @@
"""PdfParser 单元测试"""
import pytest
import fitz
from exceptions import ParseError
from parsers.pdf_parser import PdfParser
@pytest.fixture
def parser():
return PdfParser()
def _create_pdf(path, pages):
"""
创建测试用 PDF 文件。
Args:
path: 输出文件路径
pages: 列表,每个元素是 (text, fontsize) 元组的列表,代表一页中的文本行
"""
doc = fitz.open()
for page_items in pages:
page = doc.new_page()
y = 72
for text, fontsize in page_items:
page.insert_text((72, y), text, fontsize=fontsize)
y += fontsize + 10
doc.save(str(path))
doc.close()
class TestSupportedExtensions:
def test_supports_pdf(self, parser):
assert ".pdf" in parser.supported_extensions()
def test_only_one_extension(self, parser):
assert len(parser.supported_extensions()) == 1
class TestParse:
def test_parse_simple_text(self, parser, tmp_path):
pdf_path = tmp_path / "simple.pdf"
_create_pdf(pdf_path, [
[("Hello, world!", 12)],
])
result = parser.parse(str(pdf_path))
assert "Hello, world!" in result
def test_parse_multiline_text(self, parser, tmp_path):
pdf_path = tmp_path / "multi.pdf"
_create_pdf(pdf_path, [
[("Line one", 12), ("Line two", 12)],
])
result = parser.parse(str(pdf_path))
assert "Line one" in result
assert "Line two" in result
def test_parse_multiple_pages(self, parser, tmp_path):
pdf_path = tmp_path / "pages.pdf"
_create_pdf(pdf_path, [
[("Page one content", 12)],
[("Page two content", 12)],
])
result = parser.parse(str(pdf_path))
assert "Page one content" in result
assert "Page two content" in result
def test_heading_level2_detection(self, parser, tmp_path):
"""Font size > body_mode + 2 should produce ## heading"""
pdf_path = tmp_path / "h2.pdf"
# Body text at size 12 (will be the mode), heading at size 18 (diff=6 > 2)
_create_pdf(pdf_path, [
[
("Body text line one", 12),
("Body text line two", 12),
("Body text line three", 12),
("Big Heading", 18),
],
])
result = parser.parse(str(pdf_path))
assert "## Big Heading" in result
def test_heading_level3_detection(self, parser, tmp_path):
"""Font size > body_mode + 0.5 but <= body_mode + 2 should produce ### heading"""
pdf_path = tmp_path / "h3.pdf"
# Body text at size 12 (mode), heading at size 13.5 (diff=1.5, >0.5 and <=2)
_create_pdf(pdf_path, [
[
("Body text one", 12),
("Body text two", 12),
("Body text three", 12),
("Sub Heading", 13.5),
],
])
result = parser.parse(str(pdf_path))
assert "### Sub Heading" in result
def test_body_text_no_heading_prefix(self, parser, tmp_path):
"""Text at body font size should not have heading prefix"""
pdf_path = tmp_path / "body.pdf"
_create_pdf(pdf_path, [
[("Normal text", 12), ("More normal text", 12)],
])
result = parser.parse(str(pdf_path))
assert "## Normal text" not in result
assert "### Normal text" not in result
assert "Normal text" in result
def test_empty_pdf(self, parser, tmp_path):
"""Empty PDF (no text) should return empty string"""
pdf_path = tmp_path / "empty.pdf"
doc = fitz.open()
doc.new_page()
doc.save(str(pdf_path))
doc.close()
result = parser.parse(str(pdf_path))
assert result.strip() == ""
def test_nonexistent_file_raises(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/nonexistent/path/file.pdf")
assert "file.pdf" in exc_info.value.file_name
assert exc_info.value.reason != ""
def test_corrupted_file_raises(self, parser, tmp_path):
pdf_path = tmp_path / "corrupted.pdf"
pdf_path.write_bytes(b"this is not a pdf file at all")
with pytest.raises(ParseError) as exc_info:
parser.parse(str(pdf_path))
assert "corrupted.pdf" in exc_info.value.file_name
def test_parse_error_contains_filename(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/no/such/report.pdf")
assert exc_info.value.file_name == "report.pdf"
def test_mixed_headings_and_body(self, parser, tmp_path):
"""Test a document with mixed heading levels and body text"""
pdf_path = tmp_path / "mixed.pdf"
_create_pdf(pdf_path, [
[
("Body one", 12),
("Body two", 12),
("Body three", 12),
("Body four", 12),
("Body five", 12),
("Main Title", 20),
("Section Title", 14),
("Paragraph text", 12),
],
])
result = parser.parse(str(pdf_path))
assert "## Main Title" in result
assert "### Section Title" in result
# Body text should not have heading markers
assert "## Body one" not in result
assert "## Paragraph text" not in result

84
tests/test_prompts.py Normal file
View File

@@ -0,0 +1,84 @@
"""提示词模块单元测试。"""
from prompts import (
SYSTEM_PROMPT_TEMPLATE,
USER_PROMPT_TEMPLATE,
get_system_prompt,
get_user_prompt,
)
class TestSystemPromptTemplate:
"""系统提示词模板测试。"""
def test_contains_delimiter_placeholder(self):
assert "{delimiter}" in SYSTEM_PROMPT_TEMPLATE
def test_contains_semantic_completeness_rule(self):
assert "语义完整性" in SYSTEM_PROMPT_TEMPLATE
def test_contains_self_contained_rule(self):
assert "自包含性" in SYSTEM_PROMPT_TEMPLATE
def test_contains_heading_preservation_rule(self):
assert "标题层级保留" in SYSTEM_PROMPT_TEMPLATE
def test_contains_table_integrity_rule(self):
assert "表格完整性" in SYSTEM_PROMPT_TEMPLATE
def test_contains_granularity_rule(self):
assert "合理粒度" in SYSTEM_PROMPT_TEMPLATE
class TestUserPromptTemplate:
"""用户提示词模板测试。"""
def test_contains_text_content_placeholder(self):
assert "{text_content}" in USER_PROMPT_TEMPLATE
class TestGetSystemPrompt:
"""get_system_prompt 函数测试。"""
def test_default_delimiter(self):
result = get_system_prompt()
assert "---" in result
assert "{delimiter}" not in result
def test_custom_delimiter(self):
result = get_system_prompt("===SPLIT===")
assert "===SPLIT===" in result
assert "{delimiter}" not in result
def test_delimiter_appears_in_format_example(self):
result = get_system_prompt("***")
# 分隔符应出现在格式说明和示例中
assert "`***`" in result
def test_empty_delimiter(self):
result = get_system_prompt("")
assert "{delimiter}" not in result
class TestGetUserPrompt:
"""get_user_prompt 函数测试。"""
def test_text_content_substitution(self):
result = get_user_prompt("这是一段测试文本。")
assert "这是一段测试文本。" in result
assert "{text_content}" not in result
def test_preserves_surrounding_markers(self):
result = get_user_prompt("内容")
assert "---开始---" in result
assert "---结束---" in result
def test_multiline_content(self):
content = "第一行\n第二行\n第三行"
result = get_user_prompt(content)
assert content in result
def test_empty_content(self):
result = get_user_prompt("")
assert "{text_content}" not in result
assert "---开始---" in result

359
tests/test_splitter.py Normal file
View File

@@ -0,0 +1,359 @@
"""Splitter 协调器单元测试"""
import os
import pytest
from unittest.mock import MagicMock, patch, call
from exceptions import ApiError, ParseError, UnsupportedFormatError
from models import Chunk
from splitter import Splitter
@pytest.fixture
def mock_deps():
"""Patch all external dependencies and return their mocks."""
with (
patch("splitter.ApiClient") as mock_api_cls,
patch("splitter.AIChunker") as mock_chunker_cls,
patch("splitter.MarkdownWriter") as mock_writer_cls,
patch("splitter.JsonWriter"),
patch("splitter.ParserRegistry") as mock_registry_cls,
patch("splitter.TextParser"),
patch("splitter.CsvParser"),
patch("splitter.HtmlParser"),
patch("splitter.PdfParser"),
patch("splitter.DocParser"),
patch("splitter.LegacyDocParser"),
patch("splitter.XlsxParser"),
patch("splitter.XlsParser"),
patch("splitter.ImageParser"),
):
api_client = mock_api_cls.return_value
chunker = mock_chunker_cls.return_value
writer = mock_writer_cls.return_value
registry = mock_registry_cls.return_value
splitter = Splitter(api_key="test-key", delimiter="---")
yield {
"splitter": splitter,
"api_client": api_client,
"chunker": chunker,
"writer": writer,
"registry": registry,
}
class TestInit:
"""初始化测试"""
def test_registers_all_parsers(self):
"""验证所有解析器都被注册"""
with (
patch("splitter.ApiClient"),
patch("splitter.AIChunker"),
patch("splitter.MarkdownWriter"),
patch("splitter.JsonWriter"),
patch("splitter.ParserRegistry") as mock_registry_cls,
patch("splitter.TextParser"),
patch("splitter.CsvParser"),
patch("splitter.HtmlParser"),
patch("splitter.PdfParser"),
patch("splitter.DocParser"),
patch("splitter.LegacyDocParser"),
patch("splitter.XlsxParser"),
patch("splitter.XlsParser"),
patch("splitter.ImageParser"),
):
registry = mock_registry_cls.return_value
Splitter(api_key="test-key")
# 9 parsers: Text, Csv, Html, Pdf, Doc, LegacyDoc, Xlsx, Xls, Image
assert registry.register.call_count == 9
def test_creates_api_client_with_key(self):
"""验证 ApiClient 使用正确的 api_key 创建"""
with (
patch("splitter.ApiClient") as mock_api_cls,
patch("splitter.AIChunker"),
patch("splitter.MarkdownWriter"),
patch("splitter.JsonWriter"),
patch("splitter.ParserRegistry"),
patch("splitter.TextParser"),
patch("splitter.CsvParser"),
patch("splitter.HtmlParser"),
patch("splitter.PdfParser"),
patch("splitter.DocParser"),
patch("splitter.LegacyDocParser"),
patch("splitter.XlsxParser"),
patch("splitter.XlsParser"),
patch("splitter.ImageParser"),
):
Splitter(api_key="my-secret-key")
mock_api_cls.assert_called_once_with(api_key="my-secret-key")
def test_creates_chunker_with_delimiter(self):
"""验证 AIChunker 使用正确的 delimiter 创建"""
with (
patch("splitter.ApiClient") as mock_api_cls,
patch("splitter.AIChunker") as mock_chunker_cls,
patch("splitter.MarkdownWriter"),
patch("splitter.JsonWriter"),
patch("splitter.ParserRegistry"),
patch("splitter.TextParser"),
patch("splitter.CsvParser"),
patch("splitter.HtmlParser"),
patch("splitter.PdfParser"),
patch("splitter.DocParser"),
patch("splitter.LegacyDocParser"),
patch("splitter.XlsxParser"),
patch("splitter.XlsParser"),
patch("splitter.ImageParser"),
):
Splitter(api_key="key", delimiter="===")
mock_chunker_cls.assert_called_once_with(
mock_api_cls.return_value, "===", pre_split_size=None
)
class TestProcessSuccess:
"""成功处理流程测试"""
def test_full_flow(self, mock_deps, tmp_path, capsys):
"""验证完整的成功处理流程"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
writer = mock_deps["writer"]
# Setup
input_file = tmp_path / "test.txt"
input_file.write_text("hello")
output_file = str(tmp_path / "output.md")
mock_parser = MagicMock()
mock_parser.parse.return_value = "parsed text"
registry.get_parser.return_value = mock_parser
chunks = [Chunk(title="标题", content="内容")]
chunker.chunk.return_value = chunks
# Execute
splitter.process(str(input_file), output_file)
# Verify call chain
registry.get_parser.assert_called_once_with(str(input_file))
mock_parser.parse.assert_called_once_with(str(input_file))
chunker.chunk.assert_called_once()
assert chunker.chunk.call_args[0][0] == "parsed text"
writer.write.assert_called_once_with(
chunks, output_file, "test.txt", "---"
)
def test_logs_parsing_stage(self, mock_deps, tmp_path, capsys):
"""验证输出文件解析日志"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.pdf"
input_file.write_text("data")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
chunker.chunk.return_value = [Chunk(title="t", content="c")]
splitter.process(str(input_file), str(tmp_path / "out.md"))
output = capsys.readouterr().out
assert "解析文件: doc.pdf" in output
def test_logs_chunking_stage(self, mock_deps, tmp_path, capsys):
"""验证输出 AI 分块日志"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.txt"
input_file.write_text("data")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
chunker.chunk.return_value = [Chunk(title="t", content="c")]
splitter.process(str(input_file), str(tmp_path / "out.md"))
output = capsys.readouterr().out
assert "AI 语义分块" in output
def test_logs_writing_stage(self, mock_deps, tmp_path, capsys):
"""验证输出写入日志"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.txt"
input_file.write_text("data")
output_path = str(tmp_path / "out.md")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
chunker.chunk.return_value = [Chunk(title="t", content="c")]
splitter.process(str(input_file), output_path)
output = capsys.readouterr().out
assert "写入输出" in output
def test_logs_summary(self, mock_deps, tmp_path, capsys):
"""验证输出处理摘要"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.txt"
input_file.write_text("data")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
chunker.chunk.return_value = [
Chunk(title="t1", content="c1"),
Chunk(title="t2", content="c2"),
Chunk(title="t3", content="c3"),
]
splitter.process(str(input_file), str(tmp_path / "out.md"))
output = capsys.readouterr().out
assert "3 个分块" in output
def test_progress_callback_passed_to_chunker(self, mock_deps, tmp_path, capsys):
"""验证进度回调被传递给 chunker 并正确输出"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.txt"
input_file.write_text("data")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
# Simulate chunker calling the progress callback
def fake_chunk(text, content_type=None, source_file="", on_progress=None):
if on_progress:
on_progress(1, 3)
on_progress(2, 3)
on_progress(3, 3)
return [Chunk(title="t", content="c")]
chunker.chunk.side_effect = fake_chunk
splitter.process(str(input_file), str(tmp_path / "out.md"))
output = capsys.readouterr().out
assert "分块进度: 1/3" in output
assert "分块进度: 2/3" in output
assert "分块进度: 3/3" in output
class TestProcessErrors:
"""错误处理测试"""
def test_file_not_found(self, mock_deps):
"""验证文件不存在时抛出 FileNotFoundError"""
splitter = mock_deps["splitter"]
with pytest.raises(FileNotFoundError, match="输入文件不存在"):
splitter.process("/nonexistent/path/file.txt", "output.md")
def test_unsupported_format(self, mock_deps, tmp_path):
"""验证不支持的格式时抛出 UnsupportedFormatError"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
input_file = tmp_path / "file.xyz"
input_file.write_text("data")
registry.get_parser.side_effect = UnsupportedFormatError("file.xyz", ".xyz")
with pytest.raises(UnsupportedFormatError):
splitter.process(str(input_file), str(tmp_path / "out.md"))
def test_parse_error(self, mock_deps, tmp_path):
"""验证解析错误时抛出 ParseError"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
input_file = tmp_path / "bad.pdf"
input_file.write_bytes(b"\x00\x01\x02")
mock_parser = MagicMock()
mock_parser.parse.side_effect = ParseError("bad.pdf", "文件损坏")
registry.get_parser.return_value = mock_parser
with pytest.raises(ParseError, match="bad.pdf"):
splitter.process(str(input_file), str(tmp_path / "out.md"))
def test_api_error(self, mock_deps, tmp_path):
"""验证 API 错误时抛出 ApiError"""
splitter = mock_deps["splitter"]
registry = mock_deps["registry"]
chunker = mock_deps["chunker"]
input_file = tmp_path / "doc.txt"
input_file.write_text("data")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
registry.get_parser.return_value = mock_parser
chunker.chunk.side_effect = ApiError("API 调用失败")
with pytest.raises(ApiError, match="API 调用失败"):
splitter.process(str(input_file), str(tmp_path / "out.md"))
class TestCustomDelimiter:
"""自定义分隔符测试"""
def test_delimiter_passed_to_writer(self, tmp_path):
"""验证自定义分隔符传递给 writer"""
with (
patch("splitter.ApiClient"),
patch("splitter.AIChunker") as mock_chunker_cls,
patch("splitter.MarkdownWriter") as mock_writer_cls,
patch("splitter.JsonWriter"),
patch("splitter.ParserRegistry") as mock_registry_cls,
patch("splitter.TextParser"),
patch("splitter.CsvParser"),
patch("splitter.HtmlParser"),
patch("splitter.PdfParser"),
patch("splitter.DocParser"),
patch("splitter.LegacyDocParser"),
patch("splitter.XlsxParser"),
patch("splitter.XlsParser"),
patch("splitter.ImageParser"),
):
splitter = Splitter(api_key="key", delimiter="===")
input_file = tmp_path / "test.txt"
input_file.write_text("hello")
mock_parser = MagicMock()
mock_parser.parse.return_value = "text"
mock_registry_cls.return_value.get_parser.return_value = mock_parser
chunks = [Chunk(title="t", content="c")]
mock_chunker_cls.return_value.chunk.return_value = chunks
writer = mock_writer_cls.return_value
splitter.process(str(input_file), str(tmp_path / "out.md"))
writer.write.assert_called_once_with(
chunks, str(tmp_path / "out.md"), "test.txt", "==="
)

83
tests/test_text_parser.py Normal file
View File

@@ -0,0 +1,83 @@
"""TextParser 单元测试"""
import os
import tempfile
import pytest
from exceptions import ParseError
from parsers.text_parser import TextParser
@pytest.fixture
def parser():
return TextParser()
class TestSupportedExtensions:
def test_supports_txt(self, parser):
assert ".txt" in parser.supported_extensions()
def test_supports_md(self, parser):
assert ".md" in parser.supported_extensions()
def test_only_two_extensions(self, parser):
assert len(parser.supported_extensions()) == 2
class TestParse:
def test_parse_utf8_txt(self, parser, tmp_path):
f = tmp_path / "test.txt"
f.write_text("Hello, world!", encoding="utf-8")
assert parser.parse(str(f)) == "Hello, world!"
def test_parse_utf8_md(self, parser, tmp_path):
f = tmp_path / "readme.md"
content = "# Title\n\nSome **bold** text."
f.write_bytes(content.encode("utf-8"))
assert parser.parse(str(f)) == content
def test_parse_gbk_encoded_file(self, parser, tmp_path):
f = tmp_path / "chinese.txt"
# Use longer text so charset_normalizer can reliably detect GBK
content = "你好,世界!这是一段中文文本。我们正在测试文件编码的自动检测功能,需要足够长的文本才能让检测器准确识别编码格式。"
f.write_bytes(content.encode("gbk"))
result = parser.parse(str(f))
assert result == content
def test_parse_utf8_bom(self, parser, tmp_path):
f = tmp_path / "bom.txt"
content = "UTF-8 with BOM"
f.write_bytes(b"\xef\xbb\xbf" + content.encode("utf-8"))
result = parser.parse(str(f))
assert "UTF-8 with BOM" in result
def test_parse_empty_file(self, parser, tmp_path):
f = tmp_path / "empty.txt"
f.write_bytes(b"")
assert parser.parse(str(f)) == ""
def test_parse_multiline(self, parser, tmp_path):
f = tmp_path / "multi.md"
content = "Line 1\nLine 2\nLine 3\n"
f.write_bytes(content.encode("utf-8"))
assert parser.parse(str(f)) == content
def test_parse_nonexistent_file_raises(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/nonexistent/path/file.txt")
assert "file.txt" in exc_info.value.file_name
assert exc_info.value.reason != ""
def test_parse_error_contains_filename(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/no/such/myfile.txt")
assert exc_info.value.file_name == "myfile.txt"
def test_parse_latin1_encoded_file(self, parser, tmp_path):
f = tmp_path / "latin.txt"
content = "café résumé naïve"
f.write_bytes(content.encode("latin-1"))
result = parser.parse(str(f))
assert "caf" in result
assert "sum" in result

225
tests/test_writer.py Normal file
View File

@@ -0,0 +1,225 @@
"""MarkdownWriter 单元测试"""
import pytest
from models import Chunk
from writer import MarkdownWriter
@pytest.fixture
def writer():
return MarkdownWriter()
@pytest.fixture
def tmp_output(tmp_path):
return str(tmp_path / "output.md")
class TestSingleChunk:
"""单个 Chunk 输出测试"""
def test_single_chunk_no_delimiter(self, writer, tmp_output):
chunks = [Chunk(title="摘要标题", content="这是内容")]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "---" not in content.split("-->", 1)[1]
def test_single_chunk_has_title(self, writer, tmp_output):
chunks = [Chunk(title="摘要标题", content="这是内容")]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "## 摘要标题" in content
def test_single_chunk_has_content(self, writer, tmp_output):
chunks = [Chunk(title="摘要标题", content="这是内容")]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "这是内容" in content
class TestMultipleChunks:
"""多个 Chunk 输出测试"""
def test_delimiter_between_chunks(self, writer, tmp_output):
chunks = [
Chunk(title="标题1", content="内容1"),
Chunk(title="标题2", content="内容2"),
Chunk(title="标题3", content="内容3"),
]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
after_meta = content.split("-->", 1)[1]
assert after_meta.count("\n---\n") == 2
def test_all_titles_present(self, writer, tmp_output):
chunks = [
Chunk(title="标题A", content="内容A"),
Chunk(title="标题B", content="内容B"),
]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "## 标题A" in content
assert "## 标题B" in content
def test_all_contents_present(self, writer, tmp_output):
chunks = [
Chunk(title="标题A", content="内容A"),
Chunk(title="标题B", content="内容B"),
]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "内容A" in content
assert "内容B" in content
def test_no_trailing_delimiter(self, writer, tmp_output):
chunks = [
Chunk(title="标题1", content="内容1"),
Chunk(title="标题2", content="内容2"),
]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
after_meta = content.split("-->", 1)[1]
# The last chunk content should appear after the last delimiter
# and there should be no delimiter after the last content
last_delimiter_pos = after_meta.rfind("\n---\n")
last_content_pos = after_meta.rfind("内容2")
assert last_content_pos > last_delimiter_pos
class TestMetaInfo:
"""元信息注释测试"""
def test_contains_source_file(self, writer, tmp_output):
chunks = [Chunk(title="标题", content="内容")]
writer.write(chunks, tmp_output, "example.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "源文件: example.pdf" in content
def test_contains_process_time(self, writer, tmp_output):
chunks = [Chunk(title="标题", content="内容")]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "处理时间:" in content
def test_contains_chunk_count(self, writer, tmp_output):
chunks = [
Chunk(title="标题1", content="内容1"),
Chunk(title="标题2", content="内容2"),
Chunk(title="标题3", content="内容3"),
]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "分块总数: 3" in content
def test_meta_is_html_comment(self, writer, tmp_output):
chunks = [Chunk(title="标题", content="内容")]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert content.startswith("<!-- ")
assert "-->" in content
def test_meta_at_file_start(self, writer, tmp_output):
chunks = [Chunk(title="标题", content="内容")]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
comment_end = content.index("-->")
title_pos = content.index("## 标题")
assert comment_end < title_pos
class TestFileOverwrite:
"""文件覆盖测试"""
def test_overwrites_existing_file(self, writer, tmp_output):
with open(tmp_output, "w", encoding="utf-8") as f:
f.write("旧内容")
chunks = [Chunk(title="新标题", content="新内容")]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "旧内容" not in content
assert "新内容" in content
def test_prints_warning_on_overwrite(self, writer, tmp_output, capsys):
with open(tmp_output, "w", encoding="utf-8") as f:
f.write("旧内容")
chunks = [Chunk(title="标题", content="内容")]
writer.write(chunks, tmp_output, "test.pdf")
captured = capsys.readouterr()
assert "警告" in captured.out
assert tmp_output in captured.out
def test_no_warning_for_new_file(self, writer, tmp_output, capsys):
chunks = [Chunk(title="标题", content="内容")]
writer.write(chunks, tmp_output, "test.pdf")
captured = capsys.readouterr()
assert "警告" not in captured.out
class TestCustomDelimiter:
"""自定义分隔符测试"""
def test_custom_delimiter(self, writer, tmp_output):
chunks = [
Chunk(title="标题1", content="内容1"),
Chunk(title="标题2", content="内容2"),
]
writer.write(chunks, tmp_output, "test.pdf", delimiter="===")
content = open(tmp_output, encoding="utf-8").read()
after_meta = content.split("-->", 1)[1]
assert "\n===\n" in after_meta
assert "\n---\n" not in after_meta
class TestEmptyContent:
"""空内容 Chunk 测试"""
def test_empty_content_chunk(self, writer, tmp_output):
chunks = [Chunk(title="空内容标题", content="")]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "## 空内容标题" in content
def test_empty_content_with_multiple_chunks(self, writer, tmp_output):
chunks = [
Chunk(title="标题1", content=""),
Chunk(title="标题2", content="有内容"),
]
writer.write(chunks, tmp_output, "test.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "## 标题1" in content
assert "## 标题2" in content
assert "有内容" in content
class TestUTF8Encoding:
"""UTF-8 编码测试"""
def test_utf8_encoding(self, writer, tmp_output):
chunks = [Chunk(title="中文标题", content="中文内容,包含特殊字符:①②③")]
writer.write(chunks, tmp_output, "测试文件.pdf")
content = open(tmp_output, encoding="utf-8").read()
assert "中文标题" in content
assert "①②③" in content
assert "测试文件.pdf" in content

178
tests/test_xls_parser.py Normal file
View File

@@ -0,0 +1,178 @@
"""XlsParser 单元测试"""
import pytest
import xlwt
from exceptions import ParseError
from parsers.xls_parser import XlsParser
@pytest.fixture
def parser():
return XlsParser()
def _create_xls(path, sheets=None):
"""
创建测试用 XLS 文件。
Args:
path: 输出文件路径
sheets: dictkey 为 sheet 名称value 为二维列表(行×列的数据)
如果为 None创建空工作簿
"""
wb = xlwt.Workbook()
if sheets:
for sheet_name, rows in sheets.items():
ws = wb.add_sheet(sheet_name)
for row_idx, row in enumerate(rows):
for col_idx, value in enumerate(row):
ws.write(row_idx, col_idx, value)
else:
# xlwt 需要至少一个 sheet
wb.add_sheet("Sheet1")
wb.save(str(path))
class TestSupportedExtensions:
def test_supports_xls(self, parser):
assert ".xls" in parser.supported_extensions()
def test_only_one_extension(self, parser):
assert len(parser.supported_extensions()) == 1
class TestParse:
def test_simple_table(self, parser, tmp_path):
"""基本表格转换为 Markdown"""
xls_path = tmp_path / "simple.xls"
_create_xls(xls_path, {
"Sheet1": [
["Name", "Age"],
["Alice", 30],
["Bob", 25],
]
})
result = parser.parse(str(xls_path))
assert "## Sheet1" in result
assert "| Name | Age |" in result
assert "| --- | --- |" in result
assert "Alice" in result
assert "Bob" in result
def test_multiple_sheets(self, parser, tmp_path):
"""多个工作表各自生成标题和表格"""
xls_path = tmp_path / "multi.xls"
_create_xls(xls_path, {
"Users": [["Name"], ["Alice"]],
"Orders": [["ID"], ["001"]],
})
result = parser.parse(str(xls_path))
assert "## Users" in result
assert "## Orders" in result
assert "| Name |" in result
assert "| ID |" in result
def test_empty_sheet_skipped(self, parser, tmp_path):
"""空工作表应被跳过"""
xls_path = tmp_path / "empty_sheet.xls"
wb = xlwt.Workbook()
wb.add_sheet("Empty") # no data written
ws = wb.add_sheet("Data")
ws.write(0, 0, "Col1")
ws.write(1, 0, "Val1")
wb.save(str(xls_path))
result = parser.parse(str(xls_path))
assert "## Empty" not in result
assert "## Data" in result
def test_pipe_escaped(self, parser, tmp_path):
"""单元格中的 | 应被转义为 &#124;"""
xls_path = tmp_path / "pipe.xls"
_create_xls(xls_path, {
"Sheet1": [["Header"], ["value|with|pipes"]],
})
result = parser.parse(str(xls_path))
assert "&#124;" in result
assert "value&#124;with&#124;pipes" in result
def test_newline_escaped(self, parser, tmp_path):
"""单元格中的换行符应被转义为 <br>"""
xls_path = tmp_path / "newline.xls"
_create_xls(xls_path, {
"Sheet1": [["Header"], ["line1\nline2"]],
})
result = parser.parse(str(xls_path))
assert "line1<br>line2" in result
def test_backtick_escaped(self, parser, tmp_path):
"""单元格中的反引号应被转义为 &#96;"""
xls_path = tmp_path / "backtick.xls"
_create_xls(xls_path, {
"Sheet1": [["Header"], ["code `snippet`"]],
})
result = parser.parse(str(xls_path))
assert "&#96;" in result
def test_empty_cell_becomes_empty(self, parser, tmp_path):
"""空单元格应显示为空字符串"""
xls_path = tmp_path / "empty_cell.xls"
wb = xlwt.Workbook()
ws = wb.add_sheet("Sheet1")
ws.write(0, 0, "A")
ws.write(0, 1, "B")
ws.write(1, 0, "val")
# cell (1,1) is not written — will be empty
wb.save(str(xls_path))
result = parser.parse(str(xls_path))
assert "| val | |" in result
def test_sheet_name_as_heading(self, parser, tmp_path):
"""工作表名称应作为 ## 标题"""
xls_path = tmp_path / "named.xls"
_create_xls(xls_path, {
"Sales Report": [["Month", "Revenue"], ["Jan", "1000"]],
})
result = parser.parse(str(xls_path))
assert "## Sales Report" in result
def test_nonexistent_file_raises(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/nonexistent/path/file.xls")
assert "file.xls" in exc_info.value.file_name
assert exc_info.value.reason != ""
def test_corrupted_file_raises(self, parser, tmp_path):
xls_path = tmp_path / "corrupted.xls"
xls_path.write_bytes(b"this is not an xls file")
with pytest.raises(ParseError) as exc_info:
parser.parse(str(xls_path))
assert "corrupted.xls" in exc_info.value.file_name
def test_parse_error_contains_filename(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/no/such/report.xls")
assert exc_info.value.file_name == "report.xls"
def test_numeric_values(self, parser, tmp_path):
"""数值类型应正确转换为字符串"""
xls_path = tmp_path / "numeric.xls"
_create_xls(xls_path, {
"Sheet1": [["Int", "Float"], [42, 3.14]],
})
result = parser.parse(str(xls_path))
assert "42" in result
assert "3.14" in result
def test_crlf_escaped(self, parser, tmp_path):
"""\\r\\n 应被转义为 <br>"""
xls_path = tmp_path / "crlf.xls"
_create_xls(xls_path, {
"Sheet1": [["Header"], ["line1\r\nline2"]],
})
result = parser.parse(str(xls_path))
assert "line1<br>line2" in result

220
tests/test_xlsx_parser.py Normal file
View File

@@ -0,0 +1,220 @@
"""XlsxParser 单元测试"""
import pytest
from openpyxl import Workbook
from exceptions import ParseError
from parsers.xlsx_parser import XlsxParser
@pytest.fixture
def parser():
return XlsxParser()
def _create_xlsx(path, sheets=None):
"""
创建测试用 XLSX 文件。
Args:
path: 输出文件路径
sheets: dictkey 为 sheet 名称value 为二维列表(行×列的数据)
如果为 None创建空工作簿
"""
wb = Workbook()
# 删除默认 sheet
wb.remove(wb.active)
if sheets:
for sheet_name, rows in sheets.items():
ws = wb.create_sheet(title=sheet_name)
for row in rows:
ws.append(row)
wb.save(str(path))
def _create_xlsx_with_merge(path, sheet_name, rows, merges):
"""
创建带合并单元格的 XLSX 文件。
Args:
path: 输出文件路径
sheet_name: 工作表名称
rows: 二维列表(行×列的数据)
merges: 合并区域列表,如 ["A1:B1", "A2:A3"]
"""
wb = Workbook()
wb.remove(wb.active)
ws = wb.create_sheet(title=sheet_name)
for row in rows:
ws.append(row)
for merge_range in merges:
ws.merge_cells(merge_range)
wb.save(str(path))
class TestSupportedExtensions:
def test_supports_xlsx(self, parser):
assert ".xlsx" in parser.supported_extensions()
def test_only_one_extension(self, parser):
assert len(parser.supported_extensions()) == 1
class TestParse:
def test_simple_table(self, parser, tmp_path):
"""基本表格转换为 Markdown"""
xlsx_path = tmp_path / "simple.xlsx"
_create_xlsx(xlsx_path, {
"Sheet1": [
["Name", "Age"],
["Alice", 30],
["Bob", 25],
]
})
result = parser.parse(str(xlsx_path))
assert "## Sheet1" in result
assert "| Name | Age |" in result
assert "| --- | --- |" in result
assert "| Alice | 30 |" in result
assert "| Bob | 25 |" in result
def test_multiple_sheets(self, parser, tmp_path):
"""多个工作表各自生成标题和表格"""
xlsx_path = tmp_path / "multi.xlsx"
_create_xlsx(xlsx_path, {
"Users": [["Name"], ["Alice"]],
"Orders": [["ID"], ["001"]],
})
result = parser.parse(str(xlsx_path))
assert "## Users" in result
assert "## Orders" in result
assert "| Name |" in result
assert "| ID |" in result
def test_empty_sheet_skipped(self, parser, tmp_path):
"""空工作表应被跳过"""
xlsx_path = tmp_path / "empty_sheet.xlsx"
_create_xlsx(xlsx_path, {
"Empty": [],
"Data": [["Col1"], ["Val1"]],
})
result = parser.parse(str(xlsx_path))
assert "## Empty" not in result
assert "## Data" in result
def test_all_empty_sheets(self, parser, tmp_path):
"""所有工作表都为空时返回空字符串"""
xlsx_path = tmp_path / "all_empty.xlsx"
_create_xlsx(xlsx_path, {"Empty1": [], "Empty2": []})
result = parser.parse(str(xlsx_path))
assert result.strip() == ""
def test_pipe_escaped(self, parser, tmp_path):
"""单元格中的 | 应被转义为 &#124;"""
xlsx_path = tmp_path / "pipe.xlsx"
_create_xlsx(xlsx_path, {
"Sheet1": [["Header"], ["value|with|pipes"]],
})
result = parser.parse(str(xlsx_path))
assert "&#124;" in result
assert "value&#124;with&#124;pipes" in result
def test_newline_escaped(self, parser, tmp_path):
"""单元格中的换行符应被转义为 <br>"""
xlsx_path = tmp_path / "newline.xlsx"
_create_xlsx(xlsx_path, {
"Sheet1": [["Header"], ["line1\nline2"]],
})
result = parser.parse(str(xlsx_path))
assert "line1<br>line2" in result
def test_backtick_escaped(self, parser, tmp_path):
"""单元格中的反引号应被转义为 &#96;"""
xlsx_path = tmp_path / "backtick.xlsx"
_create_xlsx(xlsx_path, {
"Sheet1": [["Header"], ["code `snippet`"]],
})
result = parser.parse(str(xlsx_path))
assert "&#96;" in result
def test_none_cell_becomes_empty(self, parser, tmp_path):
"""None 值的单元格应显示为空"""
xlsx_path = tmp_path / "none.xlsx"
_create_xlsx(xlsx_path, {
"Sheet1": [["A", "B"], ["val", None]],
})
result = parser.parse(str(xlsx_path))
assert "| val | |" in result
def test_merged_cells(self, parser, tmp_path):
"""合并单元格应填充左上角的值"""
xlsx_path = tmp_path / "merged.xlsx"
_create_xlsx_with_merge(
xlsx_path,
sheet_name="Data",
rows=[
["Category", "Value"],
["Fruit", 10],
[None, 20], # A3 will be merged with A2
],
merges=["A2:A3"],
)
result = parser.parse(str(xlsx_path))
assert "## Data" in result
# The merged cell (A3) should have the value from A2 ("Fruit")
lines = result.split("\n")
data_lines = [l for l in lines if l.startswith("| ") and "---" not in l and "Category" not in l]
assert len(data_lines) == 2
# Both data rows should contain "Fruit"
assert all("Fruit" in line for line in data_lines)
def test_sheet_name_as_heading(self, parser, tmp_path):
"""工作表名称应作为 ## 标题"""
xlsx_path = tmp_path / "named.xlsx"
_create_xlsx(xlsx_path, {
"Sales Report": [["Month", "Revenue"], ["Jan", "1000"]],
})
result = parser.parse(str(xlsx_path))
assert "## Sales Report" in result
def test_nonexistent_file_raises(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/nonexistent/path/file.xlsx")
assert "file.xlsx" in exc_info.value.file_name
assert exc_info.value.reason != ""
def test_corrupted_file_raises(self, parser, tmp_path):
xlsx_path = tmp_path / "corrupted.xlsx"
xlsx_path.write_bytes(b"this is not an xlsx file")
with pytest.raises(ParseError) as exc_info:
parser.parse(str(xlsx_path))
assert "corrupted.xlsx" in exc_info.value.file_name
def test_parse_error_contains_filename(self, parser):
with pytest.raises(ParseError) as exc_info:
parser.parse("/no/such/report.xlsx")
assert exc_info.value.file_name == "report.xlsx"
def test_numeric_values(self, parser, tmp_path):
"""数值类型应正确转换为字符串"""
xlsx_path = tmp_path / "numeric.xlsx"
_create_xlsx(xlsx_path, {
"Sheet1": [["Int", "Float"], [42, 3.14]],
})
result = parser.parse(str(xlsx_path))
assert "42" in result
assert "3.14" in result
def test_crlf_escaped(self, parser, tmp_path):
"""\\r\\n 应被转义为 <br>"""
xlsx_path = tmp_path / "crlf.xlsx"
_create_xlsx(xlsx_path, {
"Sheet1": [["Header"], ["line1\r\nline2"]],
})
result = parser.parse(str(xlsx_path))
assert "line1<br>line2" in result