"""ApiClient 单元测试""" import pytest from unittest.mock import MagicMock import openai from api_client import ApiClient from exceptions import ApiError def _make_completion_response(content: str): """构造模拟的 ChatCompletion 响应""" message = MagicMock() message.content = content choice = MagicMock() choice.message = message response = MagicMock() response.choices = [choice] return response def _make_rate_limit_error(): """构造 openai.RateLimitError""" return openai.RateLimitError( message="Rate limit exceeded", response=MagicMock(status_code=429), body=None, ) def _make_api_error(status_code=500, message="Internal server error"): """构造非速率限制的 openai.APIStatusError""" return openai.APIStatusError( message=message, response=MagicMock(status_code=status_code), body=None, ) def _make_client(**kwargs): """创建注入 mock OpenAI client 的 ApiClient""" mock_openai = MagicMock() sleep_fn = kwargs.get("sleep_fn", MagicMock()) return ApiClient( api_key="test-key", _client=mock_openai, _sleep=sleep_fn, ), mock_openai, sleep_fn class TestApiClientChat: """chat() 方法测试""" def test_successful_chat(self): """成功调用 chat 返回内容""" client, mock_openai, sleep_fn = _make_client() expected = "这是 AI 的回复" mock_openai.chat.completions.create.return_value = _make_completion_response(expected) result = client.chat("你是助手", "你好") assert result == expected mock_openai.chat.completions.create.assert_called_once_with( model="deepseek-chat", messages=[ {"role": "system", "content": "你是助手"}, {"role": "user", "content": "你好"}, ], ) sleep_fn.assert_not_called() def test_chat_custom_model(self): """chat 支持自定义模型""" client, mock_openai, _ = _make_client() mock_openai.chat.completions.create.return_value = _make_completion_response("ok") client.chat("sys", "user", model="deepseek-reasoner") mock_openai.chat.completions.create.assert_called_once_with( model="deepseek-reasoner", messages=[ {"role": "system", "content": "sys"}, {"role": "user", "content": "user"}, ], ) def test_chat_retry_on_429_then_success(self): """chat 遇到 429 后重试成功""" client, mock_openai, sleep_fn = _make_client() mock_openai.chat.completions.create.side_effect = [ _make_rate_limit_error(), _make_rate_limit_error(), _make_completion_response("成功"), ] result = client.chat("sys", "user") assert result == "成功" assert sleep_fn.call_count == 2 sleep_fn.assert_any_call(1) sleep_fn.assert_any_call(2) def test_chat_retry_exhausted_raises_api_error(self): """chat 重试耗尽抛出 ApiError""" client, mock_openai, sleep_fn = _make_client() mock_openai.chat.completions.create.side_effect = [ _make_rate_limit_error(), _make_rate_limit_error(), _make_rate_limit_error(), _make_rate_limit_error(), ] with pytest.raises(ApiError, match="速率限制重试耗尽") as exc_info: client.chat("sys", "user") assert exc_info.value.status_code == 429 assert sleep_fn.call_count == 3 sleep_fn.assert_any_call(1) sleep_fn.assert_any_call(2) sleep_fn.assert_any_call(4) def test_chat_non_429_error_raises_immediately(self): """chat 遇到非 429 错误立即抛出 ApiError,不重试""" client, mock_openai, sleep_fn = _make_client() mock_openai.chat.completions.create.side_effect = _make_api_error(500) with pytest.raises(ApiError) as exc_info: client.chat("sys", "user") assert exc_info.value.status_code == 500 sleep_fn.assert_not_called() class TestApiClientVision: """vision() 方法测试""" def test_successful_vision(self): """成功调用 vision 返回内容""" client, mock_openai, sleep_fn = _make_client() expected = "图片中包含一段文字" mock_openai.chat.completions.create.return_value = _make_completion_response(expected) result = client.vision("识别图片", "aW1hZ2VfZGF0YQ==") assert result == expected mock_openai.chat.completions.create.assert_called_once_with( model="deepseek-chat", messages=[ {"role": "system", "content": "识别图片"}, { "role": "user", "content": [ { "type": "image_url", "image_url": { "url": "data:image/png;base64,aW1hZ2VfZGF0YQ==", }, }, ], }, ], ) sleep_fn.assert_not_called() def test_vision_retry_on_429_then_success(self): """vision 遇到 429 后重试成功""" client, mock_openai, sleep_fn = _make_client() mock_openai.chat.completions.create.side_effect = [ _make_rate_limit_error(), _make_completion_response("识别结果"), ] result = client.vision("sys", "base64data") assert result == "识别结果" assert sleep_fn.call_count == 1 sleep_fn.assert_called_with(1) def test_vision_retry_exhausted_raises_api_error(self): """vision 重试耗尽抛出 ApiError""" client, mock_openai, sleep_fn = _make_client() mock_openai.chat.completions.create.side_effect = [ _make_rate_limit_error(), _make_rate_limit_error(), _make_rate_limit_error(), _make_rate_limit_error(), ] with pytest.raises(ApiError, match="速率限制重试耗尽"): client.vision("sys", "base64data") assert sleep_fn.call_count == 3 def test_vision_non_429_error_raises_immediately(self): """vision 遇到非 429 错误立即抛出""" client, mock_openai, sleep_fn = _make_client() mock_openai.chat.completions.create.side_effect = _make_api_error(401, "Unauthorized") with pytest.raises(ApiError) as exc_info: client.vision("sys", "base64data") assert exc_info.value.status_code == 401 sleep_fn.assert_not_called() class TestRetryDelays: """重试延迟验证""" def test_retry_delays_are_exponential(self): """验证重试延迟为 1, 2, 4 秒""" assert ApiClient.RETRY_DELAYS == [1, 2, 4] assert ApiClient.MAX_RETRIES == 3 def test_single_retry_uses_correct_delay(self): """单次 429 后使用 1 秒延迟""" client, mock_openai, sleep_fn = _make_client() mock_openai.chat.completions.create.side_effect = [ _make_rate_limit_error(), _make_completion_response("ok"), ] client.chat("sys", "user") sleep_fn.assert_called_once_with(1) def test_three_retries_use_correct_delays(self): """三次 429 后使用 1, 2, 4 秒延迟""" client, mock_openai, sleep_fn = _make_client() mock_openai.chat.completions.create.side_effect = [ _make_rate_limit_error(), _make_rate_limit_error(), _make_rate_limit_error(), _make_completion_response("ok"), ] result = client.chat("sys", "user") assert result == "ok" assert sleep_fn.call_count == 3 calls = [c.args[0] for c in sleep_fn.call_args_list] assert calls == [1, 2, 4]