Files
bigwo/api_client.py

123 lines
4.0 KiB
Python
Raw Normal View History

"""DeepSeek API 客户端封装,含指数退避重试逻辑"""
import time
from typing import Callable
from openai import OpenAI
import openai
from exceptions import ApiError
# 文件扩展名 → MIME 类型映射
EXTENSION_MIME_MAP = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".bmp": "image/bmp",
".webp": "image/webp",
}
class ApiClient:
"""封装 DeepSeek API 调用,含重试和错误处理"""
RETRY_DELAYS = [1, 2, 4] # 指数退避延迟(秒)
MAX_RETRIES = 3
def __init__(
self,
api_key: str,
base_url: str = "https://api.deepseek.com",
_sleep: Callable[[float], None] = time.sleep,
_client: "OpenAI | None" = None,
):
self._client = _client or OpenAI(api_key=api_key, base_url=base_url)
self._sleep = _sleep
def chat(self, system_prompt: str, user_content: str, model: str = "deepseek-chat") -> str:
"""
调用 Chat Completion API
速率限制和网络异常时自动指数退避重试最多 3
Raises:
ApiError: API 调用失败非可重试错误或重试耗尽
"""
def _call():
response = self._client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content},
],
)
return response.choices[0].message.content
return self._retry(_call)
def vision(
self,
system_prompt: str,
image_base64: str,
mime_type: str = "image/png",
model: str = "deepseek-chat",
) -> str:
"""
调用 Vision API 识别图片内容
速率限制和网络异常时自动指数退避重试最多 3
Args:
system_prompt: 系统提示词
image_base64: 图片的 base64 编码
mime_type: 图片 MIME 类型 image/jpegimage/png
model: 模型名称
Raises:
ApiError: API 调用失败
"""
def _call():
response = self._client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{image_base64}",
},
},
],
},
],
)
return response.choices[0].message.content
return self._retry(_call)
def _retry(self, call: Callable[[], str]) -> str:
"""执行带指数退避重试的 API 调用,对速率限制和网络异常重试"""
for attempt in range(self.MAX_RETRIES + 1):
try:
return call()
except openai.RateLimitError:
if attempt < self.MAX_RETRIES:
self._sleep(self.RETRY_DELAYS[attempt])
else:
raise ApiError("速率限制重试耗尽", status_code=429)
except openai.APIConnectionError:
if attempt < self.MAX_RETRIES:
self._sleep(self.RETRY_DELAYS[attempt])
else:
raise ApiError("网络连接失败,重试耗尽")
except openai.APITimeoutError:
if attempt < self.MAX_RETRIES:
self._sleep(self.RETRY_DELAYS[attempt])
else:
raise ApiError("API 请求超时,重试耗尽")
except openai.APIError as e:
raise ApiError(str(e), status_code=getattr(e, "status_code", None))