agent-Specialization/server/tasks.py
JOJO 801d20591c feat: 实现 REST API + 轮询模式,支持页面刷新后任务继续执行
主要改进:
- 新增 REST API 任务管理接口 (/api/tasks)
- 实现 150ms 轮询机制,提供流畅的流式输出体验
- 支持页面刷新后自动恢复任务状态
- WebSocket 断开时检测 REST API 任务,避免误停止
- 修复堆叠块融合问题,刷新后内容正确合并
- 修复思考块展开/折叠逻辑,只展开正在流式输出的块
- 修复工具块重复显示问题,通过注册机制实现状态更新
- 修复历史不完整导致内容丢失的问题
- 新增 tool_intent 事件处理,支持打字机效果显示
- 修复对话列表排序时 None 值比较错误

技术细节:
- 前端:新增 taskPolling.ts 和 task store 处理轮询逻辑
- 后端:TaskManager 管理任务生命周期和事件存储
- 状态恢复:智能判断是否需要从头重建,避免内容重复
- 工具块注册:恢复时注册到 toolActionIndex,支持状态更新
- Intent 显示:0.5-1秒打字机效果,历史加载直接显示

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-03-08 03:12:46 +08:00

374 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Simple task API: decouples chat tasks from the WebSocket so they can run in the background and be polled."""
from __future__ import annotations
import time
import threading
import uuid
from collections import deque
from typing import Dict, Any, Optional, List
from flask import Blueprint, request, jsonify
from flask import current_app, session
from .auth_helpers import api_login_required, get_current_username
from .context import get_user_resources, ensure_conversation_loaded
from .chat_flow import run_chat_task_sync
from .state import stop_flags
from .utils_common import debug_log
class TaskRecord:
    """Mutable state of one background chat task.

    A record starts in the ``"pending"`` status with an empty, bounded event
    buffer (at most 1000 entries; older entries are dropped by the deque).
    ``created_at`` and ``updated_at`` are both initialised to the same
    ``time.time()`` reading.
    """

    __slots__ = (
        "task_id", "username", "workspace_id",
        "status", "created_at", "updated_at",
        "message", "conversation_id",
        "events", "thread", "error",
        "model_key", "thinking_mode", "run_mode", "max_iterations",
        "session_data", "stop_requested",
    )

    def __init__(
        self,
        task_id: str,
        username: str,
        workspace_id: str,
        message: str,
        conversation_id: Optional[str],
        model_key: Optional[str],
        thinking_mode: Optional[bool],
        run_mode: Optional[str],
        max_iterations: Optional[int],
    ):
        now = time.time()

        # Identity / ownership.
        self.task_id = task_id
        self.username = username
        self.workspace_id = workspace_id

        # Lifecycle bookkeeping.
        self.status = "pending"
        self.created_at = now
        self.updated_at = now
        self.error: Optional[str] = None
        self.stop_requested: bool = False
        self.thread: Optional[threading.Thread] = None

        # Request payload.
        self.message = message
        self.conversation_id = conversation_id

        # Per-task model / mode overrides supplied by the caller.
        self.model_key = model_key
        self.thinking_mode = thinking_mode
        self.run_mode = run_mode
        self.max_iterations = max_iterations

        # Bounded event log and the session snapshot (filled in by the creator).
        self.events: deque[Dict[str, Any]] = deque(maxlen=1000)
        self.session_data: Dict[str, Any] = {}
class TaskManager:
    """In-memory, thread-based task manager (could later be swapped for Redis/DB).

    Every task runs in its own daemon thread. ``self._lock`` guards the task
    registry, each record's event buffer, and the status transitions that are
    written through this class.
    """

    # Statuses that occupy a workspace for the one-task-per-workspace rule.
    # NOTE(review): "cancel_requested" is intentionally not listed, matching the
    # previous behaviour; a task being cancelled therefore does not block a new
    # one even though its thread may still be running -- confirm this is wanted.
    _ACTIVE_STATUSES = frozenset({"pending", "running"})
    # Statuses written once the worker thread is done; a late cancel must not
    # overwrite them.
    _FINISHED_STATUSES = frozenset({"succeeded", "failed", "canceled"})

    def __init__(self):
        self._tasks: Dict[str, TaskRecord] = {}
        self._lock = threading.Lock()

    # ---- public APIs ----
    def create_chat_task(
        self,
        username: str,
        workspace_id: str,
        message: str,
        images: List[Any],
        conversation_id: Optional[str],
        model_key: Optional[str] = None,
        thinking_mode: Optional[bool] = None,
        run_mode: Optional[str] = None,
        max_iterations: Optional[int] = None,
    ) -> TaskRecord:
        """Register a new chat task and start its worker thread.

        Args:
            username: Owner of the task.
            workspace_id: Workspace the task runs in.
            message: User message handed to the chat flow.
            images: Media attachments handed to the chat flow.
            conversation_id: Conversation to continue, or ``None``.
            model_key: Optional model override.
            thinking_mode: Optional thinking toggle (used only when
                ``run_mode`` is not given).
            run_mode: Optional run mode; one of fast/thinking/deep
                (case-insensitive).
            max_iterations: Optional iteration cap override.

        Returns:
            The created record, already in the ``"running"`` status.

        Raises:
            ValueError: ``run_mode`` is not one of the supported values.
            RuntimeError: The user already has a pending/running task in this
                workspace (or the worker thread could not be started).
        """
        if run_mode:
            normalized = str(run_mode).lower()
            if normalized not in {"fast", "thinking", "deep"}:
                raise ValueError("run_mode 只支持 fast/thinking/deep")
            run_mode = normalized

        task_id = str(uuid.uuid4())
        record = TaskRecord(
            task_id,
            username,
            workspace_id,
            message,
            conversation_id,
            model_key,
            thinking_mode,
            run_mode,
            max_iterations,
        )
        # Snapshot the current session so the worker thread can replay it.
        # Accessing ``session`` can fail (e.g. no request context); in that
        # case the task simply runs with an empty snapshot.
        try:
            record.session_data = {
                "username": session.get("username"),
                "role": session.get("role"),
                "is_api_user": session.get("is_api_user"),
                "workspace_id": workspace_id,
                "run_mode": session.get("run_mode"),
                "thinking_mode": session.get("thinking_mode"),
                "model_key": session.get("model_key"),
            }
        except Exception:
            record.session_data = {}

        # Per-workspace mutual exclusion. The check and the insertion happen
        # under ONE lock acquisition so two concurrent requests cannot both
        # pass the check.
        with self._lock:
            busy = any(
                other.username == username
                and other.workspace_id == workspace_id
                and other.status in self._ACTIVE_STATUSES
                for other in self._tasks.values()
            )
            if busy:
                raise RuntimeError("已有运行中的任务,请稍后再试。")
            self._tasks[task_id] = record

        thread = threading.Thread(target=self._run_chat_task, args=(record, images), daemon=True)
        record.thread = thread
        record.status = "running"
        record.updated_at = time.time()
        try:
            thread.start()
        except Exception as exc:
            # Without a running thread nobody would ever move the record out
            # of "running", which would block the workspace indefinitely.
            with self._lock:
                record.status = "failed"
                record.error = str(exc)
                record.updated_at = time.time()
            raise
        return record

    def get_task(self, username: str, task_id: str) -> Optional[TaskRecord]:
        """Return the task with ``task_id`` if it belongs to ``username``, else ``None``."""
        with self._lock:
            rec = self._tasks.get(task_id)
            if not rec or rec.username != username:
                return None
            return rec

    def list_tasks(self, username: str, workspace_id: Optional[str] = None) -> List[TaskRecord]:
        """Return a new list of the user's tasks, optionally limited to one workspace."""
        with self._lock:
            return [
                rec
                for rec in self._tasks.values()
                if rec.username == username and (workspace_id is None or rec.workspace_id == workspace_id)
            ]

    def cancel_task(self, username: str, task_id: str) -> bool:
        """Request cancellation of a task.

        Returns:
            ``False`` when the task does not exist or belongs to another user,
            ``True`` otherwise. Cancelling a task that has already finished is
            a no-op that still returns ``True``.
        """
        with self._lock:
            rec = self._tasks.get(task_id)
            if not rec or rec.username != username:
                return False
            if rec.status in self._FINISHED_STATUSES:
                # Keep the final status, and do not re-create a stop_flags
                # entry that the (already finished) worker would never pop.
                return True
            rec.stop_requested = True
            # Raise the stop flag; the worker passes task_id as client_sid so
            # that the chat flow can look this entry up.
            entry = stop_flags.get(task_id)
            if not isinstance(entry, dict):
                entry = {'stop': False, 'task': None, 'terminal': None}
                stop_flags[task_id] = entry
            entry['stop'] = True
            running = entry.get('task')
            rec.status = "cancel_requested"
            rec.updated_at = time.time()
        # Best effort, outside the lock so a slow cancel() cannot stall event
        # recording in the worker thread.
        try:
            if running and hasattr(running, "cancel"):
                running.cancel()
        except Exception as exc:
            debug_log(f"[Task] cancel() failed for {task_id}: {exc}")
        return True

    # ---- internal helpers ----
    def _append_event(self, rec: TaskRecord, event_type: str, data: Dict[str, Any]):
        """Append one event to ``rec.events`` with the next sequential ``idx``."""
        with self._lock:
            # Derive idx from the newest entry: the deque is bounded, so its
            # length stops growing once old entries start being dropped.
            idx = rec.events[-1]["idx"] + 1 if rec.events else 0
            rec.events.append({
                "idx": idx,
                "type": event_type,
                "data": data,
                "ts": time.time(),
            })
            rec.updated_at = time.time()

    def _run_chat_task(self, rec: TaskRecord, images: List[Any]):
        """Thread target: run the chat flow for ``rec`` and record its outcome.

        Sets ``rec.status`` to ``"succeeded"``, ``"canceled"`` or ``"failed"``;
        on failure an ``"error"`` event is appended and ``rec.error`` is set.
        """
        username = rec.username
        workspace_id = rec.workspace_id
        terminal = None
        workspace = None
        stop_hint = False
        try:
            # Build a minimal request context for this thread and replay the
            # session snapshot into it.
            from server.app import app as flask_app

            # NOTE(review): the context is assumed to span everything up to and
            # including run_chat_task_sync -- confirm against the original layout.
            with flask_app.test_request_context():
                try:
                    for k, v in (rec.session_data or {}).items():
                        if v is not None:
                            session[k] = v
                except Exception as exc:
                    debug_log(f"[Task] 恢复 session 失败: {exc}")
                terminal, workspace = get_user_resources(username, workspace_id=workspace_id)
                if not terminal or not workspace:
                    raise RuntimeError("系统未初始化")
                stop_hint = bool(stop_flags.get(rec.task_id, {}).get("stop"))
                # Model / mode configuration passed in through the API.
                if rec.model_key:
                    try:
                        terminal.set_model(rec.model_key)
                    except Exception as exc:
                        debug_log(f"[Task] 设置模型失败 {rec.model_key}: {exc}")
                if rec.run_mode:
                    try:
                        terminal.set_run_mode(rec.run_mode)
                    except Exception as exc:
                        debug_log(f"[Task] 设置运行模式失败 {rec.run_mode}: {exc}")
                elif rec.thinking_mode is not None:
                    try:
                        terminal.set_run_mode("thinking" if rec.thinking_mode else "fast")
                    except Exception as exc:
                        debug_log(f"[Task] 设置思考模式失败: {exc}")
                if rec.max_iterations:
                    try:
                        terminal.max_iterations_override = int(rec.max_iterations)
                    except Exception:
                        terminal.max_iterations_override = None
                # Make sure the conversation is loaded.
                conversation_id = rec.conversation_id
                try:
                    conversation_id, _ = ensure_conversation_loaded(terminal, conversation_id, workspace=workspace)
                    rec.conversation_id = conversation_id
                except Exception as exc:
                    raise RuntimeError(f"对话加载失败: {exc}") from exc

                def sender(event_type, data):
                    # Record the event for polling clients.
                    self._append_event(rec, event_type, data)
                    # Also emit to the room user_{username}; emit failures are
                    # logged and must not abort the task.
                    try:
                        from .extensions import socketio
                        socketio.emit(event_type, data, room=f"user_{username}")
                    except Exception as exc:
                        debug_log(f"[Task] socketio emit failed: {exc}")

                # task_id doubles as client_sid so stop_flags lookups match.
                run_chat_task_sync(
                    terminal=terminal,
                    message=rec.message,
                    images=images,
                    sender=sender,
                    client_sid=rec.task_id,
                    workspace=workspace,
                    username=username,
                )
            # Final status.
            canceled_flag = rec.stop_requested or stop_hint or bool(stop_flags.get(rec.task_id, {}).get("stop"))
            with self._lock:
                rec.status = "canceled" if canceled_flag else "succeeded"
                rec.updated_at = time.time()
        except Exception as exc:
            debug_log(f"[Task] 后台任务失败: {exc}")
            self._append_event(rec, "error", {"message": str(exc)})
            with self._lock:
                rec.status = "failed"
                rec.error = str(exc)
                rec.updated_at = time.time()
        finally:
            # Drop the stop flag under the lock so it cannot interleave with
            # cancel_task, which inspects the status and writes the flag under
            # the same lock.
            with self._lock:
                stop_flags.pop(rec.task_id, None)
            # Remove the one-shot iteration override.
            if terminal and hasattr(terminal, "max_iterations_override"):
                try:
                    delattr(terminal, "max_iterations_override")
                except Exception:
                    terminal.max_iterations_override = None
# Module-level TaskManager instance used by the route handlers below.
task_manager = TaskManager()
# Blueprint on which the /api/tasks routes below are registered.
tasks_bp = Blueprint("tasks", __name__)
@tasks_bp.route("/api/tasks", methods=["GET"])
@api_login_required
def list_tasks_api():
    """List the current user's tasks, newest first (by ``created_at``)."""
    records = task_manager.list_tasks(get_current_username())
    # list_tasks returns a fresh list, so sorting it in place is safe.
    records.sort(key=lambda rec: rec.created_at, reverse=True)

    summaries = []
    for rec in records:
        summaries.append({
            "task_id": rec.task_id,
            "status": rec.status,
            "created_at": rec.created_at,
            "updated_at": rec.updated_at,
            "message": rec.message,
            "conversation_id": rec.conversation_id,
            "error": rec.error,
        })
    return jsonify({"success": True, "data": summaries})
@tasks_bp.route("/api/tasks", methods=["POST"])
@api_login_required
def create_task_api():
    """Create a background chat task from a JSON body.

    Recognised body keys: ``message``, ``images``, ``videos``,
    ``conversation_id``, ``model_key``, ``thinking_mode``, ``run_mode``,
    ``max_iterations``.

    Responses:
        202 with the task summary on success;
        400 when the payload is empty or malformed (including an unsupported
        ``run_mode``);
        409 when the task manager refuses the task (e.g. one is already
        running in this workspace).
    """
    username = get_current_username()
    workspace_id = session.get("workspace_id") or "default"

    # silent=True: a missing or non-JSON body is treated like an empty payload
    # and answered with the JSON 400 below instead of a framework error page.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    raw_message = payload.get("message")
    if raw_message is not None and not isinstance(raw_message, str):
        return jsonify({"success": False, "error": "message 必须是字符串"}), 400
    message = (raw_message or "").strip()

    images = payload.get("images") or []
    videos = payload.get("videos") or []
    if not isinstance(images, list) or not isinstance(videos, list):
        return jsonify({"success": False, "error": "images/videos 必须是数组"}), 400

    conversation_id = payload.get("conversation_id")
    if not message and not images and not videos:
        return jsonify({"success": False, "error": "消息不能为空"}), 400

    model_key = payload.get("model_key")
    thinking_mode = payload.get("thinking_mode")
    run_mode = payload.get("run_mode")
    max_iterations = payload.get("max_iterations")

    # Images and videos travel together in the single ``images`` argument.
    all_media = images + videos
    try:
        rec = task_manager.create_chat_task(
            username,
            workspace_id,
            message,
            all_media,
            conversation_id,
            model_key=model_key,
            thinking_mode=thinking_mode,
            run_mode=run_mode,
            max_iterations=max_iterations,
        )
    except ValueError as exc:
        # Raised by create_chat_task for an unsupported run_mode: a client error.
        return jsonify({"success": False, "error": str(exc)}), 400
    except RuntimeError as exc:
        return jsonify({"success": False, "error": str(exc)}), 409
    return jsonify({
        "success": True,
        "data": {
            "task_id": rec.task_id,
            "status": rec.status,
            "created_at": rec.created_at,
            "conversation_id": rec.conversation_id,
        }
    }), 202
@tasks_bp.route("/api/tasks/<task_id>", methods=["GET"])
@api_login_required
def get_task_api(task_id: str):
    """Return a task's state plus its events with ``idx >= from``.

    Query args:
        from: First event index wanted; invalid or negative values are
            treated as 0.

    Responses:
        200 with the task state, the matching ``events`` and ``next_offset``
        (the value to pass as ``from`` on the next poll);
        404 when the task does not exist for the current user.
    """
    username = get_current_username()
    rec = task_manager.get_task(username, task_id)
    if not rec:
        return jsonify({"success": False, "error": "任务不存在"}), 404
    try:
        offset = int(request.args.get("from", 0))
    except (TypeError, ValueError):
        offset = 0
    offset = max(offset, 0)

    # Copy events and status in one critical section, using the same lock the
    # worker thread holds while appending events and writing the status:
    #  * iterating the live deque during an append can raise
    #    "RuntimeError: deque mutated during iteration";
    #  * reading the status after the events could report a finished task
    #    while the last events are missing from the response.
    with task_manager._lock:
        snapshot = list(rec.events)
        status = rec.status
        updated_at = rec.updated_at
        error = rec.error

    events = [e for e in snapshot if e["idx"] >= offset]
    next_offset = events[-1]["idx"] + 1 if events else offset
    return jsonify({
        "success": True,
        "data": {
            "task_id": rec.task_id,
            "status": status,
            "created_at": rec.created_at,
            "updated_at": updated_at,
            "message": rec.message,
            "conversation_id": rec.conversation_id,
            "error": error,
            "events": events,
            "next_offset": next_offset,
        }
    })
@tasks_bp.route("/api/tasks/<task_id>/cancel", methods=["POST"])
@api_login_required
def cancel_task_api(task_id: str):
    """Ask the task manager to cancel a task owned by the current user.

    Responds with 404 when the manager reports the task as unknown.
    """
    current_user = get_current_username()
    if task_manager.cancel_task(current_user, task_id):
        return jsonify({"success": True})
    return jsonify({"success": False, "error": "任务不存在"}), 404