152 lines
5.1 KiB
Python
152 lines
5.1 KiB
Python
"""Agents Web API 客户端"""
|
||
|
||
import asyncio
|
||
import logging
|
||
from typing import Optional, Dict, Any, List
|
||
|
||
import aiohttp
|
||
|
||
from .config import AGENTS_HOST, AGENTS_EMAIL, AGENTS_PASSWORD
|
||
|
||
# Logger for this module, named after the module's fully qualified import name.
logger: logging.Logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class WebAPIClient:
    """Client for the Agents Web API (session cookie + CSRF token based).

    Intended to be used as an async context manager::

        async with WebAPIClient() as client:
            conv = await client.create_conversation()

    Entering the context opens an ``aiohttp.ClientSession`` and logs in;
    leaving it closes the session.
    """

    # HTTP methods sent without an X-CSRF-Token header.
    _SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE"})

    def __init__(
        self,
        host: str = AGENTS_HOST,
        email: str = AGENTS_EMAIL,
        password: str = AGENTS_PASSWORD,
    ):
        """Store connection settings; no network I/O happens here.

        Args:
            host: Base URL of the Agents server; a trailing "/" is stripped.
            email: Login email.
            password: Login password.
        """
        self.host = host.rstrip("/")
        self.email = email
        self.password = password
        # Opened in __aenter__, reset to None when the session is closed.
        self.session: Optional[aiohttp.ClientSession] = None
        # Current CSRF token; (re)fetched by login().
        self.csrf_token: Optional[str] = None

    async def __aenter__(self) -> "WebAPIClient":
        """Open the HTTP session (with a cookie jar) and log in.

        ``__aexit__`` is not invoked when ``__aenter__`` raises, so a failed
        login closes the freshly created session here before re-raising.
        """
        # unsafe=True makes the cookie jar also accept cookies from hosts given
        # as IP addresses (rejected by aiohttp's default jar) — presumably
        # because AGENTS_HOST may be an IP; TODO confirm.
        self.session = aiohttp.ClientSession(
            cookie_jar=aiohttp.CookieJar(unsafe=True)
        )
        try:
            await self.login()
        except BaseException:
            await self._close_session()
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session; exceptions from the body are not suppressed."""
        await self._close_session()

    async def _close_session(self) -> None:
        """Close and forget the session so later calls report 'not initialized'."""
        session, self.session = self.session, None
        if session is not None:
            await session.close()

    def _require_session(self) -> aiohttp.ClientSession:
        """Return the open session, or raise if the client was not entered."""
        if self.session is None:
            raise RuntimeError("Client not initialized")
        return self.session

    def _csrf_headers(self, needs_csrf: bool) -> Dict[str, str]:
        """Build request headers, adding the CSRF token only when required."""
        if needs_csrf and self.csrf_token is not None:
            return {"X-CSRF-Token": self.csrf_token}
        return {}

    async def _get_csrf_token(self) -> str:
        """Fetch a CSRF token via ``GET {host}/api/csrf-token``.

        Returns:
            The ``token`` field of the JSON response.

        Raises:
            RuntimeError: if the client has no open session, or the response's
                ``success`` field is falsy.
        """
        session = self._require_session()
        url = f"{self.host}/api/csrf-token"
        async with session.get(url) as resp:
            payload = await resp.json()
        if not payload.get("success"):
            raise RuntimeError("获取 CSRF Token 失败")
        return payload["token"]

    async def login(self):
        """Log in with the configured email/password, then refresh the token.

        Flow: fetch a CSRF token, POST the credentials as JSON to ``/login``
        with that token, then fetch a new token (presumably the server rotates
        it on login — TODO confirm).

        Raises:
            RuntimeError: if the session is not open, a token cannot be
                obtained, or the login response's ``success`` field is falsy.
        """
        self.csrf_token = await self._get_csrf_token()
        logger.info("已获取 CSRF Token")

        session = self._require_session()
        url = f"{self.host}/login"
        headers = {"X-CSRF-Token": self.csrf_token}
        credentials = {"email": self.email, "password": self.password}
        async with session.post(url, json=credentials, headers=headers) as resp:
            result = await resp.json()
        if not result.get("success"):
            raise RuntimeError(f"登录失败: {result.get('error')}")
        logger.info("已登录为用户: %s", self.email)

        self.csrf_token = await self._get_csrf_token()

    async def _request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Send a request to the Web API and return the decoded JSON body.

        Methods other than GET/HEAD/OPTIONS/TRACE carry the current CSRF token.
        On HTTP 401 the client logs in again and retries once; the retry's JSON
        body is returned regardless of its status.

        Args:
            method: HTTP method (case-insensitive for the CSRF decision).
            endpoint: Path appended to ``host``; expected to start with "/".
            data: Optional JSON request body.
            params: Optional query-string parameters.

        Raises:
            RuntimeError: if the client was not entered, or re-login fails.
            aiohttp.ClientError: on transport errors or non-JSON responses.
        """
        session = self._require_session()
        url = f"{self.host}{endpoint}"
        needs_csrf = method.upper() not in self._SAFE_METHODS

        try:
            async with session.request(
                method,
                url,
                json=data,
                params=params,
                headers=self._csrf_headers(needs_csrf),
            ) as resp:
                if resp.status != 401:
                    return await resp.json()

            # Session expired: the 401 response has been released above; log in
            # again (which refreshes the CSRF token) and retry exactly once.
            logger.warning("Session 过期,重新登录")
            await self.login()
            async with session.request(
                method,
                url,
                json=data,
                params=params,
                headers=self._csrf_headers(needs_csrf),
            ) as retry_resp:
                return await retry_resp.json()
        except Exception as e:
            logger.error("Web API 请求失败: %s", e)
            raise

    async def create_conversation(
        self, thinking_mode: bool = False, mode: str = "fast"
    ) -> Dict[str, Any]:
        """Create a new conversation via ``POST /api/conversations``.

        Args:
            thinking_mode: Sent as the ``thinking_mode`` field.
            mode: Sent as the ``mode`` field.

        Returns:
            The decoded JSON response.
        """
        body = {
            "preserve_mode": True,
            "mode": mode,
            "thinking_mode": thinking_mode,
        }
        return await self._request("POST", "/api/conversations", body)

    async def list_conversations(self, limit: int = 20, offset: int = 0) -> Dict[str, Any]:
        """List conversations via ``GET /api/conversations?limit=..&offset=..``."""
        params = {"limit": limit, "offset": offset}
        return await self._request("GET", "/api/conversations", params=params)

    async def load_conversation(self, conversation_id: str) -> Dict[str, Any]:
        """Load/switch to a conversation via ``PUT /api/conversations/{id}/load``."""
        # NOTE(review): conversation_id is interpolated into the path unescaped.
        return await self._request("PUT", f"/api/conversations/{conversation_id}/load")

    async def send_message(
        self,
        message: str,
        conversation_id: str,
        max_iterations: int = 100,
    ) -> Dict[str, Any]:
        """Send a message by creating a task via ``POST /api/tasks``.

        No images or videos are attached; ``model_key``, ``thinking_mode`` and
        ``run_mode`` are sent as null (presumably server defaults apply —
        TODO confirm).

        Returns:
            The decoded JSON response.
        """
        body = {
            "message": message,
            "conversation_id": conversation_id,
            "images": [],
            "videos": [],
            "model_key": None,
            "thinking_mode": None,
            "run_mode": None,
            "max_iterations": max_iterations,
        }
        return await self._request("POST", "/api/tasks", body)

    async def poll_task(self, task_id: str, from_offset: int = 0) -> Dict[str, Any]:
        """Poll task events via ``GET /api/tasks/{task_id}?from=<from_offset>``."""
        params = {"from": from_offset}
        return await self._request("GET", f"/api/tasks/{task_id}", params=params)

    async def cancel_task(self, task_id: str) -> Dict[str, Any]:
        """Cancel a task via ``POST /api/tasks/{task_id}/cancel``."""
        return await self._request("POST", f"/api/tasks/{task_id}/cancel")
|