Files
bigwo/coze_api/coze_client.py
2026-03-12 12:47:56 +08:00

177 lines
5.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Coze 智能体 API 客户端封装
支持调用 Coze 平台发布的 Bot 进行对话
"""
import os
import json
import uuid
import requests
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment at import
# time, so the os.getenv() lookups in CozeClient.__init__ can see them.
load_dotenv()
class CozeClient:
    """Thin HTTP client for chatting with a Bot published on the Coze platform.

    Credentials and endpoint are taken from the constructor arguments first and
    fall back to the ``COZE_API_TOKEN``, ``COZE_BOT_ID`` and ``COZE_API_BASE``
    environment variables.

    All request methods return the decoded JSON body on success, or a dict of
    the form ``{"error": <message>, "status_code": <int or None>}`` on any
    transport / HTTP / decoding failure. They do not raise for network errors.
    """

    # Chat statuses from which a chat can never reach "completed" through
    # polling alone: "failed" and "canceled" are final, and "requires_action"
    # waits for tool output that this client has no way to submit.
    _UNRECOVERABLE_STATUSES = ("failed", "canceled", "requires_action")

    def __init__(self, api_token=None, bot_id=None, api_base=None):
        """Store credentials and build the default request headers.

        Args:
            api_token: Coze API token; defaults to env ``COZE_API_TOKEN``.
            bot_id: ID of the Bot to talk to; defaults to env ``COZE_BOT_ID``.
            api_base: API base URL; defaults to env ``COZE_API_BASE`` and then
                to ``https://api.coze.cn``.

        Raises:
            ValueError: If the token or the bot id cannot be resolved.
        """
        self.api_token = api_token or os.getenv("COZE_API_TOKEN")
        self.bot_id = bot_id or os.getenv("COZE_BOT_ID")
        # An empty env value falls through to the default, and a trailing
        # slash is dropped so f"{api_base}/v3/..." never yields "//v3/...".
        base = api_base or os.getenv("COZE_API_BASE") or "https://api.coze.cn"
        self.api_base = base.rstrip("/")
        if not self.api_token:
            raise ValueError("缺少 COZE_API_TOKEN请在 .env 文件中配置")
        if not self.bot_id:
            raise ValueError("缺少 COZE_BOT_ID请在 .env 文件中配置")
        self.headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }

    def _request(self, method, url, timeout, **kwargs):
        """Send one authenticated request and return the decoded JSON body.

        Args:
            method: HTTP method name, e.g. ``"GET"`` or ``"POST"``.
            url: Absolute request URL.
            timeout: Request timeout in seconds.
            **kwargs: Passed through to ``requests.request`` (``params``,
                ``json``, ...).

        Returns:
            dict: The JSON body, or ``{"error": ..., "status_code": ...}`` when
            the request fails, the server answers with an HTTP error status,
            or the body is not valid JSON.
        """
        try:
            response = requests.request(
                method, url, headers=self.headers, timeout=timeout, **kwargs
            )
            response.raise_for_status()
            return response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers JSON decoding failures on requests versions
            # whose JSONDecodeError is not a RequestException. Not every
            # exception carries a response, hence the nested getattr.
            failed_response = getattr(e, "response", None)
            return {
                "error": str(e),
                "status_code": getattr(failed_response, "status_code", None),
            }

    def chat(self, message, user_id=None, conversation_id=None, stream=False, custom_variables=None):
        """Send a user message to the Bot (``POST /v3/chat``).

        Args:
            message: Text of the user message.
            user_id: Caller-defined user identifier; a random 16-character hex
                id is generated when omitted.
            conversation_id: Existing conversation to continue (multi-turn).
                Sent as a URL query parameter, which is where the v3 chat
                endpoint expects it.
            stream: Value of the ``stream`` flag in the request body.
                NOTE(review): the response is always parsed as a single JSON
                document, so a streamed reply will likely come back as an
                error dict -- confirm before relying on ``stream=True``.
            custom_variables: Optional dict of Bot prompt variables, e.g.
                ``{"zishu": "200字以内", "fengge": "轻松活泼"}``.

        Returns:
            dict: The API response, or an error dict (see ``_request``).
        """
        url = f"{self.api_base}/v3/chat"
        if user_id is None:
            user_id = uuid.uuid4().hex[:16]
        payload = {
            "bot_id": self.bot_id,
            "user_id": user_id,
            "stream": stream,
            "auto_save_history": True,
            "additional_messages": [
                {
                    "role": "user",
                    "content": message,
                    "content_type": "text",
                }
            ],
        }
        if custom_variables:
            payload["custom_variables"] = custom_variables
        params = {"conversation_id": conversation_id} if conversation_id else None
        return self._request("POST", url, timeout=60, params=params, json=payload)

    def retrieve_chat(self, conversation_id, chat_id):
        """Query the state of a chat (``GET /v3/chat/retrieve``).

        Used for polling in non-streaming mode.

        Args:
            conversation_id: Conversation ID.
            chat_id: Chat ID.

        Returns:
            dict: The API response, or an error dict (see ``_request``).
        """
        url = f"{self.api_base}/v3/chat/retrieve"
        params = {
            "conversation_id": conversation_id,
            "chat_id": chat_id,
        }
        return self._request("GET", url, timeout=30, params=params)

    def get_chat_messages(self, conversation_id, chat_id):
        """Fetch the message list of a chat (``GET /v3/chat/message/list``).

        Args:
            conversation_id: Conversation ID.
            chat_id: Chat ID.

        Returns:
            dict: The API response, or an error dict (see ``_request``).
        """
        url = f"{self.api_base}/v3/chat/message/list"
        params = {
            "conversation_id": conversation_id,
            "chat_id": chat_id,
        }
        return self._request("GET", url, timeout=30, params=params)

    def _extract_answer(self, conversation_id, chat_id):
        """Return the assistant's answer text of a completed chat, or a diagnostic string."""
        messages_result = self.get_chat_messages(conversation_id, chat_id)
        if "error" in messages_result:
            return f"请求失败: {messages_result['error']}"
        # "data" may be missing or null; treat both as "no messages".
        messages = messages_result.get("data") or []
        for msg in messages:
            if msg.get("role") == "assistant" and msg.get("type") == "answer":
                return msg.get("content", "")
        return f"未找到回复消息: {json.dumps(messages, ensure_ascii=False, indent=2)}"

    def chat_and_poll(self, message, user_id=None, custom_variables=None, poll_interval=2, max_wait=120,
                      conversation_id=None):
        """Send a message and poll until the Bot has answered (non-streaming flow).

        Args:
            message: Text of the user message.
            user_id: Caller-defined user identifier (see ``chat``).
            custom_variables: Optional dict of Bot prompt variables.
            poll_interval: Seconds to sleep between status polls.
            max_wait: Upper bound, in wall-clock seconds, for the polling
                phase. Time spent inside the poll requests counts towards it.
            conversation_id: Existing conversation to continue; a new
                conversation is started when omitted.

        Returns:
            str: The Bot's reply text, or a human-readable message describing
            the failure (request error, malformed response, failed chat or
            timeout). This method does not raise for API/network problems.
        """
        import time

        result = self.chat(
            message,
            user_id=user_id,
            conversation_id=conversation_id,
            stream=False,
            custom_variables=custom_variables,
        )
        if "error" in result:
            return f"请求失败: {result['error']}"
        data = result.get("data") or {}
        conversation_id = data.get("conversation_id")
        chat_id = data.get("id")
        if not conversation_id or not chat_id:
            return f"响应异常: {json.dumps(result, ensure_ascii=False, indent=2)}"

        # Poll against a monotonic deadline so that slow requests and
        # poll_interval <= 0 cannot stretch the wait beyond max_wait.
        deadline = time.monotonic() + max_wait
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(max(0, min(poll_interval, remaining)))
            status_result = self.retrieve_chat(conversation_id, chat_id)
            # A failed poll (error dict / missing data) yields status None and
            # is retried on the next iteration: transient errors are tolerated.
            status = (status_result.get("data") or {}).get("status")
            if status == "completed":
                return self._extract_answer(conversation_id, chat_id)
            if status in self._UNRECOVERABLE_STATUSES:
                return f"对话失败: {json.dumps(status_result, ensure_ascii=False, indent=2)}"
        return "等待超时,请稍后重试"