"""简单任务 API:将聊天任务与 WebSocket 解耦,支持后台运行与轮询。""" from __future__ import annotations import time import threading import uuid from collections import deque from typing import Dict, Any, Optional, List from flask import Blueprint, request, jsonify from flask import current_app, session from .auth_helpers import api_login_required, get_current_username from .context import get_user_resources, ensure_conversation_loaded from .chat_flow import run_chat_task_sync from .state import stop_flags from .utils_common import debug_log class TaskRecord: __slots__ = ( "task_id", "username", "workspace_id", "status", "created_at", "updated_at", "message", "conversation_id", "events", "thread", "error", "model_key", "thinking_mode", "run_mode", "max_iterations", "session_data", "stop_requested", ) def __init__( self, task_id: str, username: str, workspace_id: str, message: str, conversation_id: Optional[str], model_key: Optional[str], thinking_mode: Optional[bool], run_mode: Optional[str], max_iterations: Optional[int], ): self.task_id = task_id self.username = username self.workspace_id = workspace_id self.status = "pending" self.created_at = time.time() self.updated_at = self.created_at self.message = message self.conversation_id = conversation_id self.events: deque[Dict[str, Any]] = deque(maxlen=1000) self.thread: Optional[threading.Thread] = None self.error: Optional[str] = None self.model_key = model_key self.thinking_mode = thinking_mode self.run_mode = run_mode self.max_iterations = max_iterations self.session_data: Dict[str, Any] = {} self.stop_requested: bool = False class TaskManager: """线程内存版任务管理器,后续可替换为 Redis/DB。""" def __init__(self): self._tasks: Dict[str, TaskRecord] = {} self._lock = threading.Lock() # ---- public APIs ---- def create_chat_task( self, username: str, workspace_id: str, message: str, images: List[Any], conversation_id: Optional[str], model_key: Optional[str] = None, thinking_mode: Optional[bool] = None, run_mode: Optional[str] = None, max_iterations: Optional[int] = None, ) -> TaskRecord: if run_mode: normalized = str(run_mode).lower() if normalized not in {"fast", "thinking", "deep"}: raise ValueError("run_mode 只支持 fast/thinking/deep") run_mode = normalized # 单工作区互斥:禁止同一用户同一工作区并发任务 existing = [t for t in self.list_tasks(username, workspace_id) if t.status in {"pending", "running"}] if existing: raise RuntimeError("已有运行中的任务,请稍后再试。") task_id = str(uuid.uuid4()) record = TaskRecord(task_id, username, workspace_id, message, conversation_id, model_key, thinking_mode, run_mode, max_iterations) # 记录当前 session 快照,便于后台线程内使用 try: record.session_data = { "username": session.get("username"), "role": session.get("role"), "is_api_user": session.get("is_api_user"), "workspace_id": workspace_id, "run_mode": session.get("run_mode"), "thinking_mode": session.get("thinking_mode"), "model_key": session.get("model_key"), } except Exception: record.session_data = {} with self._lock: self._tasks[task_id] = record thread = threading.Thread(target=self._run_chat_task, args=(record, images), daemon=True) record.thread = thread record.status = "running" record.updated_at = time.time() thread.start() return record def get_task(self, username: str, task_id: str) -> Optional[TaskRecord]: with self._lock: rec = self._tasks.get(task_id) if not rec or rec.username != username: return None return rec def list_tasks(self, username: str, workspace_id: Optional[str] = None) -> List[TaskRecord]: with self._lock: return [ rec for rec in self._tasks.values() if rec.username == username and (workspace_id is None or rec.workspace_id == workspace_id) ] def cancel_task(self, username: str, task_id: str) -> bool: rec = self.get_task(username, task_id) if not rec: return False rec.stop_requested = True # 标记停止标志;chat_flow 会检测 stop_flags entry = stop_flags.get(task_id) if not isinstance(entry, dict): entry = {'stop': False, 'task': None, 'terminal': None} stop_flags[task_id] = entry entry['stop'] = True try: if entry.get('task') and hasattr(entry['task'], "cancel"): entry['task'].cancel() except Exception: pass with self._lock: rec.status = "cancel_requested" rec.updated_at = time.time() return True # ---- internal helpers ---- def _append_event(self, rec: TaskRecord, event_type: str, data: Dict[str, Any]): with self._lock: idx = rec.events[-1]["idx"] + 1 if rec.events else 0 rec.events.append({ "idx": idx, "type": event_type, "data": data, "ts": time.time(), }) rec.updated_at = time.time() def _run_chat_task(self, rec: TaskRecord, images: List[Any]): username = rec.username workspace_id = rec.workspace_id terminal = None workspace = None stop_hint = False try: # 为后台线程构造最小请求上下文,填充 session from server.app import app as flask_app with flask_app.test_request_context(): try: for k, v in (rec.session_data or {}).items(): if v is not None: session[k] = v except Exception: pass terminal, workspace = get_user_resources(username, workspace_id=workspace_id) if not terminal or not workspace: raise RuntimeError("系统未初始化") stop_hint = bool(stop_flags.get(rec.task_id, {}).get("stop")) # API 传入的模型/模式配置 if rec.model_key: try: terminal.set_model(rec.model_key) except Exception as exc: debug_log(f"[Task] 设置模型失败 {rec.model_key}: {exc}") if rec.run_mode: try: terminal.set_run_mode(rec.run_mode) except Exception as exc: debug_log(f"[Task] 设置运行模式失败 {rec.run_mode}: {exc}") elif rec.thinking_mode is not None: try: terminal.set_run_mode("thinking" if rec.thinking_mode else "fast") except Exception as exc: debug_log(f"[Task] 设置思考模式失败: {exc}") if rec.max_iterations: try: terminal.max_iterations_override = int(rec.max_iterations) except Exception: terminal.max_iterations_override = None # 确保会话加载 conversation_id = rec.conversation_id try: conversation_id, _ = ensure_conversation_loaded(terminal, conversation_id, workspace=workspace) rec.conversation_id = conversation_id except Exception as exc: raise RuntimeError(f"对话加载失败: {exc}") from exc def sender(event_type, data): # 记录事件 self._append_event(rec, event_type, data) # 在线用户仍然收到实时推送(房间 user_{username}) try: from .extensions import socketio socketio.emit(event_type, data, room=f"user_{username}") except Exception: pass # 将 task_id 作为 client_sid,供 stop_flags 检测 run_chat_task_sync( terminal=terminal, message=rec.message, images=images, sender=sender, client_sid=rec.task_id, workspace=workspace, username=username, ) # 结束状态 canceled_flag = rec.stop_requested or stop_hint or bool(stop_flags.get(rec.task_id, {}).get("stop")) with self._lock: rec.status = "canceled" if canceled_flag else "succeeded" rec.updated_at = time.time() except Exception as exc: debug_log(f"[Task] 后台任务失败: {exc}") self._append_event(rec, "error", {"message": str(exc)}) with self._lock: rec.status = "failed" rec.error = str(exc) rec.updated_at = time.time() finally: # 清理 stop_flags stop_flags.pop(rec.task_id, None) # 清理一次性配置 if terminal and hasattr(terminal, "max_iterations_override"): try: delattr(terminal, "max_iterations_override") except Exception: terminal.max_iterations_override = None task_manager = TaskManager() tasks_bp = Blueprint("tasks", __name__) @tasks_bp.route("/api/tasks", methods=["GET"]) @api_login_required def list_tasks_api(): username = get_current_username() recs = task_manager.list_tasks(username) return jsonify({ "success": True, "data": [ { "task_id": r.task_id, "status": r.status, "created_at": r.created_at, "updated_at": r.updated_at, "message": r.message, "conversation_id": r.conversation_id, "error": r.error, } for r in sorted(recs, key=lambda x: x.created_at, reverse=True) ] }) @tasks_bp.route("/api/tasks", methods=["POST"]) @api_login_required def create_task_api(): username = get_current_username() workspace_id = session.get("workspace_id") or "default" payload = request.get_json() or {} message = (payload.get("message") or "").strip() images = payload.get("images") or [] videos = payload.get("videos") or [] conversation_id = payload.get("conversation_id") if not message and not images and not videos: return jsonify({"success": False, "error": "消息不能为空"}), 400 model_key = payload.get("model_key") thinking_mode = payload.get("thinking_mode") run_mode = payload.get("run_mode") max_iterations = payload.get("max_iterations") # 合并 images 和 videos 到 images 参数(后端统一处理) all_media = images + videos try: rec = task_manager.create_chat_task( username, workspace_id, message, all_media, conversation_id, model_key=model_key, thinking_mode=thinking_mode, run_mode=run_mode, max_iterations=max_iterations, ) except RuntimeError as exc: return jsonify({"success": False, "error": str(exc)}), 409 return jsonify({ "success": True, "data": { "task_id": rec.task_id, "status": rec.status, "created_at": rec.created_at, "conversation_id": rec.conversation_id, } }), 202 @tasks_bp.route("/api/tasks/", methods=["GET"]) @api_login_required def get_task_api(task_id: str): username = get_current_username() rec = task_manager.get_task(username, task_id) if not rec: return jsonify({"success": False, "error": "任务不存在"}), 404 try: offset = int(request.args.get("from", 0)) except Exception: offset = 0 events = [e for e in rec.events if e["idx"] >= offset] next_offset = events[-1]["idx"] + 1 if events else offset return jsonify({ "success": True, "data": { "task_id": rec.task_id, "status": rec.status, "created_at": rec.created_at, "updated_at": rec.updated_at, "message": rec.message, "conversation_id": rec.conversation_id, "error": rec.error, "events": events, "next_offset": next_offset, } }) @tasks_bp.route("/api/tasks//cancel", methods=["POST"]) @api_login_required def cancel_task_api(task_id: str): username = get_current_username() ok = task_manager.cancel_task(username, task_id) if not ok: return jsonify({"success": False, "error": "任务不存在"}), 404 return jsonify({"success": True})