agent-Specialization/scripts/qq_bot/web_api_client.py

152 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Agents Web API 客户端"""
import asyncio
import logging
from typing import Optional, Dict, Any, List
import aiohttp
from .config import AGENTS_HOST, AGENTS_EMAIL, AGENTS_PASSWORD
# Module-level logger; its name is this module's import path (``__name__``).
logger = logging.getLogger(name=__name__)
class WebAPIClient:
    """Async client for the Agents Web API, authenticated by session cookie + CSRF token.

    Intended to be used as an async context manager::

        async with WebAPIClient() as client:
            conv = await client.create_conversation()

    Entering creates an ``aiohttp.ClientSession`` with its own cookie jar and
    logs in; leaving (or calling :meth:`close`) closes that session.

    Args:
        host: Base URL of the Agents server; a trailing ``/`` is stripped.
        email: Account e-mail sent by :meth:`login`.
        password: Account password sent by :meth:`login`.
    """

    # Requests with these methods are sent without the X-CSRF-Token header.
    _CSRF_EXEMPT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE"})

    def __init__(
        self,
        host: str = AGENTS_HOST,
        email: str = AGENTS_EMAIL,
        password: str = AGENTS_PASSWORD,
    ):
        self.host = host.rstrip("/")
        self.email = email
        self.password = password
        self.session: Optional[aiohttp.ClientSession] = None
        self.csrf_token: Optional[str] = None

    async def __aenter__(self) -> "WebAPIClient":
        """Open an HTTP session with a cookie jar and log in."""
        # unsafe=True makes the jar accept cookies from IP-address hosts
        # (e.g. http://127.0.0.1:8000), which aiohttp rejects by default.
        self.session = aiohttp.ClientSession(
            cookie_jar=aiohttp.CookieJar(unsafe=True)
        )
        try:
            await self.login()
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so the session
            # must be closed here or it leaks.
            await self.close()
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self) -> None:
        """Close the HTTP session (if any) and forget the CSRF token.

        Safe to call more than once. Afterwards, requests raise
        ``RuntimeError("Client not initialized")``.
        """
        session, self.session = self.session, None
        self.csrf_token = None
        if session is not None:
            await session.close()

    def _require_session(self) -> aiohttp.ClientSession:
        """Return the open session, or raise if the client was not entered."""
        if self.session is None:
            raise RuntimeError("Client not initialized")
        return self.session

    async def _get_csrf_token(self) -> str:
        """Fetch a CSRF token from ``/api/csrf-token``.

        Returns:
            The ``token`` field of the JSON response.

        Raises:
            RuntimeError: If the client is not initialized, or the response
                does not report ``success``.
        """
        session = self._require_session()
        url = f"{self.host}/api/csrf-token"
        async with session.get(url) as resp:
            data = await resp.json()
        if not data.get("success"):
            raise RuntimeError("获取 CSRF Token 失败")
        return data["token"]

    async def login(self) -> None:
        """Log in to the Agents system with the configured credentials.

        Fetches a CSRF token, posts the credentials to ``/login`` with that
        token, then fetches a token again. The second fetch presumably exists
        because the server issues a new token for the authenticated session
        (confirm against the server).

        Raises:
            RuntimeError: If the client is not initialized, a token cannot be
                obtained, or the login response does not report ``success``.
        """
        session = self._require_session()

        self.csrf_token = await self._get_csrf_token()
        logger.info("已获取 CSRF Token")

        url = f"{self.host}/login"
        headers = {"X-CSRF-Token": self.csrf_token}
        payload = {"email": self.email, "password": self.password}
        async with session.post(url, json=payload, headers=headers) as resp:
            result = await resp.json()
        if not result.get("success"):
            raise RuntimeError(f"登录失败: {result.get('error')}")
        logger.info("已登录为用户: %s", self.email)

        self.csrf_token = await self._get_csrf_token()

    async def _request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Send a request to the Web API and return the decoded JSON body.

        Requests whose method is not GET/HEAD/OPTIONS/TRACE carry the current
        ``X-CSRF-Token``. On HTTP 401, the client logs in again once and
        retries; the retry's JSON body is returned whatever its status.

        Args:
            method: HTTP method (case-insensitive for the CSRF decision).
            endpoint: Path appended to ``self.host``, e.g. ``"/api/tasks"``.
            data: Optional JSON body.
            params: Optional query-string parameters.

        Raises:
            RuntimeError: If the client is not initialized (not logged).
            Exception: Any error from the HTTP call, JSON decoding or the
                re-login is logged and re-raised unchanged.
        """
        session = self._require_session()
        url = f"{self.host}{endpoint}"
        needs_csrf = method.upper() not in self._CSRF_EXEMPT_METHODS

        def build_headers() -> Dict[str, str]:
            # Read self.csrf_token at call time: login() replaces it.
            return {"X-CSRF-Token": self.csrf_token} if needs_csrf else {}

        try:
            async with session.request(
                method, url, json=data, params=params, headers=build_headers()
            ) as resp:
                if resp.status != 401:
                    return await resp.json()
            # The 401 response is released (its context has exited) before
            # logging in again and retrying.
            logger.warning("Session 过期,重新登录")
            await self.login()
            async with session.request(
                method, url, json=data, params=params, headers=build_headers()
            ) as retry_resp:
                return await retry_resp.json()
        except Exception as e:
            logger.error("Web API 请求失败: %s", e)
            raise

    async def create_conversation(
        self, thinking_mode: bool = False, mode: str = "fast"
    ) -> Dict[str, Any]:
        """Create a new conversation via ``POST /api/conversations``.

        ``preserve_mode`` is always sent as ``True``; its server-side meaning
        is not visible here.

        Returns:
            The decoded JSON response.
        """
        payload = {
            "preserve_mode": True,
            "mode": mode,
            "thinking_mode": thinking_mode,
        }
        return await self._request("POST", "/api/conversations", payload)

    async def list_conversations(self, limit: int = 20, offset: int = 0) -> Dict[str, Any]:
        """List conversations via ``GET /api/conversations?limit=&offset=``.

        Returns:
            The decoded JSON response.
        """
        params = {"limit": limit, "offset": offset}
        return await self._request("GET", "/api/conversations", params=params)

    async def load_conversation(self, conversation_id: str) -> Dict[str, Any]:
        """Load / switch to a conversation via ``PUT /api/conversations/{id}/load``.

        Returns:
            The decoded JSON response.
        """
        return await self._request("PUT", f"/api/conversations/{conversation_id}/load")

    async def send_message(
        self,
        message: str,
        conversation_id: str,
        max_iterations: int = 100,
    ) -> Dict[str, Any]:
        """Send a message by creating a task via ``POST /api/tasks``.

        No images or videos are attached. ``model_key``, ``thinking_mode``
        and ``run_mode`` are sent as ``None`` (presumably "use the server
        default"; confirm against the server).

        Returns:
            The decoded JSON response.
        """
        payload = {
            "message": message,
            "conversation_id": conversation_id,
            "images": [],
            "videos": [],
            "model_key": None,
            "thinking_mode": None,
            "run_mode": None,
            "max_iterations": max_iterations,
        }
        return await self._request("POST", "/api/tasks", payload)

    async def poll_task(self, task_id: str, from_offset: int = 0) -> Dict[str, Any]:
        """Poll task events via ``GET /api/tasks/{id}?from=<from_offset>``.

        ``from_offset`` is presumably the index of the first event to return;
        confirm against the server.

        Returns:
            The decoded JSON response.
        """
        params = {"from": from_offset}
        return await self._request("GET", f"/api/tasks/{task_id}", params=params)

    async def cancel_task(self, task_id: str) -> Dict[str, Any]:
        """Cancel a task via ``POST /api/tasks/{id}/cancel``.

        Returns:
            The decoded JSON response.
        """
        return await self._request("POST", f"/api/tasks/{task_id}/cancel")