Files
bigwo/tests/test_api_client.py
2026-03-02 17:38:28 +08:00

252 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ApiClient 单元测试"""
import pytest
from unittest.mock import MagicMock
import openai
from api_client import ApiClient
from exceptions import ApiError
def _make_completion_response(content: str):
"""构造模拟的 ChatCompletion 响应"""
message = MagicMock()
message.content = content
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
return response
def _make_rate_limit_error():
    """Build an ``openai.RateLimitError`` backed by a fake HTTP 429 response."""
    fake_http_response = MagicMock(status_code=429)
    rate_limit_error = openai.RateLimitError(
        message="Rate limit exceeded",
        body=None,
        response=fake_http_response,
    )
    return rate_limit_error
def _make_api_error(status_code=500, message="Internal server error"):
    """Build an ``openai.APIStatusError`` for non-rate-limit failures.

    Args:
        status_code: HTTP status reported by the fake response.
        message: Error message carried by the exception.
    """
    fake_http_response = MagicMock(status_code=status_code)
    status_error = openai.APIStatusError(
        message=message,
        body=None,
        response=fake_http_response,
    )
    return status_error
def _make_client(sleep_fn=None, **client_kwargs):
    """Create an ApiClient wired to a mock OpenAI client and an injected sleep.

    Args:
        sleep_fn: Callable injected as the client's ``_sleep``. When omitted
            (or None), a fresh ``MagicMock`` is used so tests can record the
            backoff delays the client requests.
        **client_kwargs: Extra keyword arguments forwarded to ``ApiClient``.
            ``api_key`` defaults to ``"test-key"`` when not given.

    Returns:
        A ``(client, mock_openai, sleep_fn)`` tuple.
    """
    # Only create the default mock when the caller did not supply one.
    if sleep_fn is None:
        sleep_fn = MagicMock()
    mock_openai = MagicMock()
    # Forward remaining kwargs instead of silently dropping them, so a typo
    # in a keyword name surfaces as a TypeError rather than being ignored.
    client_kwargs.setdefault("api_key", "test-key")
    client = ApiClient(_client=mock_openai, _sleep=sleep_fn, **client_kwargs)
    return client, mock_openai, sleep_fn
class TestApiClientChat:
    """Tests for ``ApiClient.chat()``."""

    def test_successful_chat(self):
        """A successful call returns the reply content without backing off."""
        client, mock_openai, sleep_fn = _make_client()
        expected = "这是 AI 的回复"
        mock_openai.chat.completions.create.return_value = _make_completion_response(expected)

        result = client.chat("你是助手", "你好")

        assert result == expected
        mock_openai.chat.completions.create.assert_called_once_with(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "你是助手"},
                {"role": "user", "content": "你好"},
            ],
        )
        sleep_fn.assert_not_called()

    def test_chat_custom_model(self):
        """chat() forwards a caller-supplied model name."""
        client, mock_openai, _ = _make_client()
        mock_openai.chat.completions.create.return_value = _make_completion_response("ok")

        client.chat("sys", "user", model="deepseek-reasoner")

        mock_openai.chat.completions.create.assert_called_once_with(
            model="deepseek-reasoner",
            messages=[
                {"role": "system", "content": "sys"},
                {"role": "user", "content": "user"},
            ],
        )

    def test_chat_retry_on_429_then_success(self):
        """chat() retries after 429 responses and returns the eventual reply."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        create.side_effect = [
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_completion_response("成功"),
        ]

        result = client.chat("sys", "user")

        assert result == "成功"
        assert create.call_count == 3
        # Check the exact order of delays, not just their presence.
        assert [c.args[0] for c in sleep_fn.call_args_list] == [1, 2]

    def test_chat_retry_exhausted_raises_api_error(self):
        """chat() raises ApiError(429) once all retries are used up."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        create.side_effect = [
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
        ]

        with pytest.raises(ApiError, match="速率限制重试耗尽") as exc_info:
            client.chat("sys", "user")

        assert exc_info.value.status_code == 429
        # One initial attempt plus MAX_RETRIES (3) retries.
        assert create.call_count == 4
        assert [c.args[0] for c in sleep_fn.call_args_list] == [1, 2, 4]

    def test_chat_non_429_error_raises_immediately(self):
        """chat() raises ApiError on a non-429 error without retrying."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        create.side_effect = _make_api_error(500)

        with pytest.raises(ApiError) as exc_info:
            client.chat("sys", "user")

        assert exc_info.value.status_code == 500
        create.assert_called_once()
        sleep_fn.assert_not_called()
class TestApiClientVision:
    """Tests for ``ApiClient.vision()``."""

    def test_successful_vision(self):
        """A successful call sends the image as a data URL and returns the reply."""
        client, mock_openai, sleep_fn = _make_client()
        expected = "图片中包含一段文字"
        mock_openai.chat.completions.create.return_value = _make_completion_response(expected)

        result = client.vision("识别图片", "aW1hZ2VfZGF0YQ==")

        assert result == expected
        mock_openai.chat.completions.create.assert_called_once_with(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "识别图片"},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "data:image/png;base64,aW1hZ2VfZGF0YQ==",
                            },
                        },
                    ],
                },
            ],
        )
        sleep_fn.assert_not_called()

    def test_vision_retry_on_429_then_success(self):
        """vision() retries after a 429 response and returns the eventual reply."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        create.side_effect = [
            _make_rate_limit_error(),
            _make_completion_response("识别结果"),
        ]

        result = client.vision("sys", "base64data")

        assert result == "识别结果"
        assert create.call_count == 2
        sleep_fn.assert_called_once_with(1)

    def test_vision_retry_exhausted_raises_api_error(self):
        """vision() raises ApiError(429) once all retries are used up."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        create.side_effect = [
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
            _make_rate_limit_error(),
        ]

        with pytest.raises(ApiError, match="速率限制重试耗尽") as exc_info:
            client.vision("sys", "base64data")

        # Mirror the chat() assertions so both entry points are held to the
        # same retry contract.
        assert exc_info.value.status_code == 429
        assert create.call_count == 4
        assert [c.args[0] for c in sleep_fn.call_args_list] == [1, 2, 4]

    def test_vision_non_429_error_raises_immediately(self):
        """vision() raises ApiError on a non-429 error without retrying."""
        client, mock_openai, sleep_fn = _make_client()
        create = mock_openai.chat.completions.create
        create.side_effect = _make_api_error(401, "Unauthorized")

        with pytest.raises(ApiError) as exc_info:
            client.vision("sys", "base64data")

        assert exc_info.value.status_code == 401
        create.assert_called_once()
        sleep_fn.assert_not_called()
class TestRetryDelays:
    """Checks on the retry backoff schedule used by ApiClient."""

    def test_retry_delays_are_exponential(self):
        """The class-level schedule is 1, 2, 4 seconds across 3 retries."""
        expected_schedule = [1, 2, 4]
        assert ApiClient.RETRY_DELAYS == expected_schedule
        assert ApiClient.MAX_RETRIES == 3

    def test_single_retry_uses_correct_delay(self):
        """A single 429 is followed by exactly one 1-second wait."""
        api, fake_openai, fake_sleep = _make_client()
        completions = fake_openai.chat.completions
        completions.create.side_effect = [
            _make_rate_limit_error(),
            _make_completion_response("ok"),
        ]

        api.chat("sys", "user")

        fake_sleep.assert_called_once_with(1)

    def test_three_retries_use_correct_delays(self):
        """Three consecutive 429s produce waits of 1, 2 and 4 seconds, in order."""
        api, fake_openai, fake_sleep = _make_client()
        scripted = []
        for _ in range(3):
            scripted.append(_make_rate_limit_error())
        scripted.append(_make_completion_response("ok"))
        fake_openai.chat.completions.create.side_effect = scripted

        reply = api.chat("sys", "user")

        assert reply == "ok"
        assert fake_sleep.call_count == 3
        waited = [entry.args[0] for entry in fake_sleep.call_args_list]
        assert waited == [1, 2, 4]