Initial commit: AI 知识库文档智能分块工具

This commit is contained in:
AI Knowledge Splitter
2026-03-02 17:38:28 +08:00
commit 92e7fc5bda
160 changed files with 9577 additions and 0 deletions

251
tests/test_api_client.py Normal file
View File

@@ -0,0 +1,251 @@
"""ApiClient 单元测试"""
import pytest
from unittest.mock import MagicMock
import openai
from api_client import ApiClient
from exceptions import ApiError
def _make_completion_response(content: str):
"""构造模拟的 ChatCompletion 响应"""
message = MagicMock()
message.content = content
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
return response
def _make_rate_limit_error():
    """Build an ``openai.RateLimitError`` backed by a mocked 429 response."""
    fake_response = MagicMock()
    fake_response.status_code = 429
    return openai.RateLimitError(
        message="Rate limit exceeded",
        response=fake_response,
        body=None,
    )
def _make_api_error(status_code=500, message="Internal server error"):
    """Build a non-rate-limit ``openai.APIStatusError``.

    The error carries *message* and a mocked response whose ``status_code``
    attribute is *status_code*.
    """
    fake_response = MagicMock()
    fake_response.status_code = status_code
    return openai.APIStatusError(
        message=message,
        response=fake_response,
        body=None,
    )
def _make_client(**kwargs):
    """Create an ``ApiClient`` wired to a mock OpenAI client and sleep function.

    Only the ``sleep_fn`` keyword is read; when it is absent a fresh
    ``MagicMock`` is used instead.

    Returns:
        A ``(client, mock_openai, sleep_fn)`` tuple so tests can program the
        mock OpenAI client and inspect the sleep calls.
    """
    fake_openai = MagicMock()
    if "sleep_fn" in kwargs:
        sleeper = kwargs["sleep_fn"]
    else:
        sleeper = MagicMock()
    api_client = ApiClient(
        api_key="test-key",
        _client=fake_openai,
        _sleep=sleeper,
    )
    return api_client, fake_openai, sleeper
class TestApiClientChat:
    """Tests for ``ApiClient.chat()``.

    Each test builds its client through ``_make_client()``, so the OpenAI
    client and the sleep function are both mocks: the tests program
    ``chat.completions.create`` and inspect the recorded sleep calls.
    """

    def test_successful_chat(self):
        """A successful chat call returns the message content without sleeping."""
        client, mock_openai, sleep_fn = _make_client()
        expected = "这是 AI 的回复"
        mock_openai.chat.completions.create.return_value = _make_completion_response(expected)

        result = client.chat("你是助手", "你好")

        assert result == expected
        # No model is passed above, so the call is expected to use "deepseek-chat".
        mock_openai.chat.completions.create.assert_called_once_with(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "你是助手"},
                {"role": "user", "content": "你好"},
            ],
        )
        sleep_fn.assert_not_called()

    def test_chat_custom_model(self):
        """The ``model`` keyword is forwarded to the completion call."""
        client, mock_openai, _ = _make_client()
        mock_openai.chat.completions.create.return_value = _make_completion_response("ok")

        client.chat("sys", "user", model="deepseek-reasoner")

        mock_openai.chat.completions.create.assert_called_once_with(
            model="deepseek-reasoner",
            messages=[
                {"role": "system", "content": "sys"},
                {"role": "user", "content": "user"},
            ],
        )

    def test_chat_retry_on_429_then_success(self):
        """Two 429 errors followed by a success are retried with delays 1, then 2."""
        client, mock_openai, sleep_fn = _make_client()
        mock_openai.chat.completions.create.side_effect = [
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_completion_response("成功"),
        ]

        result = client.chat("sys", "user")

        assert result == "成功"
        # Compare the whole ordered list: this pins both the number of sleeps
        # and their order, whereas assert_any_call() would also accept a
        # reversed backoff schedule.
        delays = [c.args[0] for c in sleep_fn.call_args_list]
        assert delays == [1, 2]

    def test_chat_retry_exhausted_raises_api_error(self):
        """Four consecutive 429 errors exhaust the retries and raise ``ApiError``."""
        client, mock_openai, sleep_fn = _make_client()
        mock_openai.chat.completions.create.side_effect = [
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
        ]

        with pytest.raises(ApiError, match="速率限制重试耗尽") as exc_info:
            client.chat("sys", "user")

        # These checks must sit outside the ``with`` block; anything after the
        # raising call inside it would never execute.
        assert exc_info.value.status_code == 429
        # Ordered comparison: one sleep per retry, in increasing backoff order.
        delays = [c.args[0] for c in sleep_fn.call_args_list]
        assert delays == [1, 2, 4]

    def test_chat_non_429_error_raises_immediately(self):
        """A non-429 API error raises ``ApiError`` right away, without retrying."""
        client, mock_openai, sleep_fn = _make_client()
        mock_openai.chat.completions.create.side_effect = _make_api_error(500)

        with pytest.raises(ApiError) as exc_info:
            client.chat("sys", "user")

        assert exc_info.value.status_code == 500
        sleep_fn.assert_not_called()
class TestApiClientVision:
    """Tests for ``ApiClient.vision()``.

    The OpenAI client and sleep function are the mocks returned by
    ``_make_client()``.
    """

    def test_successful_vision(self):
        """A successful vision call returns the message content without sleeping."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        answer = "图片中包含一段文字"
        create.return_value = _make_completion_response(answer)

        result = client.vision("识别图片", "aW1hZ2VfZGF0YQ==")

        assert result == answer
        # The base64 payload is expected to be wrapped in a PNG data URL.
        image_part = {
            "type": "image_url",
            "image_url": {
                "url": "data:image/png;base64,aW1hZ2VfZGF0YQ==",
            },
        }
        expected_messages = [
            {"role": "system", "content": "识别图片"},
            {"role": "user", "content": [image_part]},
        ]
        create.assert_called_once_with(
            model="deepseek-chat",
            messages=expected_messages,
        )
        sleep_fn.assert_not_called()

    def test_vision_retry_on_429_then_success(self):
        """A single 429 error is retried once after a delay of 1."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        create.side_effect = [
            _make_rate_limit_error(),
            _make_completion_response("识别结果"),
        ]

        result = client.vision("sys", "base64data")

        assert result == "识别结果"
        assert sleep_fn.call_count == 1
        sleep_fn.assert_called_with(1)

    def test_vision_retry_exhausted_raises_api_error(self):
        """Four consecutive 429 errors exhaust the retries and raise ``ApiError``."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        create.side_effect = [_make_rate_limit_error() for _ in range(4)]

        with pytest.raises(ApiError, match="速率限制重试耗尽"):
            client.vision("sys", "base64data")

        assert sleep_fn.call_count == 3

    def test_vision_non_429_error_raises_immediately(self):
        """A non-429 API error raises ``ApiError`` right away, without sleeping."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        create.side_effect = _make_api_error(401, "Unauthorized")

        with pytest.raises(ApiError) as exc_info:
            client.vision("sys", "base64data")

        assert exc_info.value.status_code == 401
        sleep_fn.assert_not_called()
class TestRetryDelays:
    """Checks on the retry delay schedule."""

    def test_retry_delays_are_exponential(self):
        """The class constants describe three retries with delays of 1, 2, 4 seconds."""
        schedule = ApiClient.RETRY_DELAYS
        assert schedule == [1, 2, 4]
        retry_limit = ApiClient.MAX_RETRIES
        assert retry_limit == 3

    def test_single_retry_uses_correct_delay(self):
        """After a single 429 error the client sleeps exactly once, for 1 second."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        create.side_effect = [
            _make_rate_limit_error(),
            _make_completion_response("ok"),
        ]

        client.chat("sys", "user")

        sleep_fn.assert_called_once_with(1)

    def test_three_retries_use_correct_delays(self):
        """After three 429 errors the client has slept for 1, 2, 4 seconds, in order."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        outcomes = [_make_rate_limit_error() for _ in range(3)]
        outcomes.append(_make_completion_response("ok"))
        create.side_effect = outcomes

        result = client.chat("sys", "user")

        assert result == "ok"
        assert sleep_fn.call_count == 3
        # Collect the first positional argument of each recorded sleep call.
        observed = []
        for recorded in sleep_fn.call_args_list:
            observed.append(recorded.args[0])
        assert observed == [1, 2, 4]