123 lines
4.0 KiB
Python
123 lines
4.0 KiB
Python
"""DeepSeek API 客户端封装,含指数退避重试逻辑"""
|
||
|
||
import time
|
||
from typing import Callable
|
||
|
||
from openai import OpenAI
|
||
import openai
|
||
|
||
from exceptions import ApiError
|
||
|
||
# 文件扩展名 → MIME 类型映射
|
||
EXTENSION_MIME_MAP = {
|
||
".png": "image/png",
|
||
".jpg": "image/jpeg",
|
||
".jpeg": "image/jpeg",
|
||
".gif": "image/gif",
|
||
".bmp": "image/bmp",
|
||
".webp": "image/webp",
|
||
}
|
||
|
||
|
||
class ApiClient:
|
||
"""封装 DeepSeek API 调用,含重试和错误处理"""
|
||
|
||
RETRY_DELAYS = [1, 2, 4] # 指数退避延迟(秒)
|
||
MAX_RETRIES = 3
|
||
|
||
def __init__(
|
||
self,
|
||
api_key: str,
|
||
base_url: str = "https://api.deepseek.com",
|
||
_sleep: Callable[[float], None] = time.sleep,
|
||
_client: "OpenAI | None" = None,
|
||
):
|
||
self._client = _client or OpenAI(api_key=api_key, base_url=base_url)
|
||
self._sleep = _sleep
|
||
|
||
|
||
def chat(self, system_prompt: str, user_content: str, model: str = "deepseek-chat") -> str:
|
||
"""
|
||
调用 Chat Completion API。
|
||
速率限制和网络异常时自动指数退避重试,最多 3 次。
|
||
|
||
Raises:
|
||
ApiError: API 调用失败(非可重试错误或重试耗尽)
|
||
"""
|
||
def _call():
|
||
response = self._client.chat.completions.create(
|
||
model=model,
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_content},
|
||
],
|
||
)
|
||
return response.choices[0].message.content
|
||
|
||
return self._retry(_call)
|
||
|
||
def vision(
|
||
self,
|
||
system_prompt: str,
|
||
image_base64: str,
|
||
mime_type: str = "image/png",
|
||
model: str = "deepseek-chat",
|
||
) -> str:
|
||
"""
|
||
调用 Vision API 识别图片内容。
|
||
速率限制和网络异常时自动指数退避重试,最多 3 次。
|
||
|
||
Args:
|
||
system_prompt: 系统提示词
|
||
image_base64: 图片的 base64 编码
|
||
mime_type: 图片 MIME 类型(如 image/jpeg、image/png)
|
||
model: 模型名称
|
||
|
||
Raises:
|
||
ApiError: API 调用失败
|
||
"""
|
||
def _call():
|
||
response = self._client.chat.completions.create(
|
||
model=model,
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:{mime_type};base64,{image_base64}",
|
||
},
|
||
},
|
||
],
|
||
},
|
||
],
|
||
)
|
||
return response.choices[0].message.content
|
||
|
||
return self._retry(_call)
|
||
|
||
def _retry(self, call: Callable[[], str]) -> str:
|
||
"""执行带指数退避重试的 API 调用,对速率限制和网络异常重试"""
|
||
for attempt in range(self.MAX_RETRIES + 1):
|
||
try:
|
||
return call()
|
||
except openai.RateLimitError:
|
||
if attempt < self.MAX_RETRIES:
|
||
self._sleep(self.RETRY_DELAYS[attempt])
|
||
else:
|
||
raise ApiError("速率限制重试耗尽", status_code=429)
|
||
except openai.APIConnectionError:
|
||
if attempt < self.MAX_RETRIES:
|
||
self._sleep(self.RETRY_DELAYS[attempt])
|
||
else:
|
||
raise ApiError("网络连接失败,重试耗尽")
|
||
except openai.APITimeoutError:
|
||
if attempt < self.MAX_RETRIES:
|
||
self._sleep(self.RETRY_DELAYS[attempt])
|
||
else:
|
||
raise ApiError("API 请求超时,重试耗尽")
|
||
except openai.APIError as e:
|
||
raise ApiError(str(e), status_code=getattr(e, "status_code", None))
|