agent-Specialization/server/tasks.py

367 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""简单任务 API将聊天任务与 WebSocket 解耦,支持后台运行与轮询。"""
from __future__ import annotations
import time
import threading
import uuid
from collections import deque
from typing import Dict, Any, Optional, List
from flask import Blueprint, request, jsonify
from flask import current_app, session
from .auth_helpers import api_login_required, get_current_username
from .context import get_user_resources, ensure_conversation_loaded
from .chat_flow import run_chat_task_sync
from .state import stop_flags
from .utils_common import debug_log
class TaskRecord:
    """State of a single background chat task.

    Fields fall into three groups: identity (task_id, username, workspace_id),
    request inputs and optional per-task overrides (message, conversation_id,
    model_key, thinking_mode, run_mode, max_iterations), and mutable runtime
    state (status, timestamps, events, thread, error, session_data,
    stop_requested). ``__slots__`` keeps instances free of a ``__dict__``.
    """

    __slots__ = (
        # identity
        "task_id",
        "username",
        "workspace_id",
        # lifecycle
        "status",
        "created_at",
        "updated_at",
        # request input
        "message",
        "conversation_id",
        # runtime
        "events",
        "thread",
        "error",
        # optional per-task overrides
        "model_key",
        "thinking_mode",
        "run_mode",
        "max_iterations",
        # worker context
        "session_data",
        "stop_requested",
    )

    def __init__(
        self,
        task_id: str,
        username: str,
        workspace_id: str,
        message: str,
        conversation_id: Optional[str],
        model_key: Optional[str],
        thinking_mode: Optional[bool],
        run_mode: Optional[str],
        max_iterations: Optional[int],
    ):
        # Identity and request inputs, stored exactly as given.
        self.task_id = task_id
        self.username = username
        self.workspace_id = workspace_id
        self.message = message
        self.conversation_id = conversation_id

        # Optional overrides; the worker only applies the ones that are set.
        self.model_key = model_key
        self.thinking_mode = thinking_mode
        self.run_mode = run_mode
        self.max_iterations = max_iterations

        # A new record starts "pending" with both timestamps equal.
        now = time.time()
        self.status = "pending"
        self.created_at = now
        self.updated_at = now

        # Bounded event log: once 1000 events are held, the oldest are dropped.
        self.events: deque[Dict[str, Any]] = deque(maxlen=1000)
        self.thread: Optional[threading.Thread] = None
        self.error: Optional[str] = None
        self.session_data: Dict[str, Any] = {}
        self.stop_requested: bool = False
class TaskManager:
    """In-process task manager that runs each chat task on a daemon thread.

    Records live only in this process's memory (per the original note, this may
    later be replaced by Redis/DB). One non-reentrant ``threading.Lock`` guards
    the task table and every status/event mutation, so readers holding the same
    lock see a consistent snapshot.

    NOTE(review): records are never removed from ``_tasks``, so memory grows
    with the number of tasks created over the process lifetime.
    """

    # Statuses in which the worker thread may still be executing. A task in
    # "cancel_requested" keeps running until chat_flow notices its stop flag,
    # so it must still block new tasks on the same workspace.
    _ACTIVE_STATUSES = frozenset({"pending", "running", "cancel_requested"})

    def __init__(self):
        self._tasks: Dict[str, TaskRecord] = {}
        self._lock = threading.Lock()

    # ---- public APIs ----
    def create_chat_task(
        self,
        username: str,
        workspace_id: str,
        message: str,
        images: List[Any],
        conversation_id: Optional[str],
        model_key: Optional[str] = None,
        thinking_mode: Optional[bool] = None,
        run_mode: Optional[str] = None,
        max_iterations: Optional[int] = None,
    ) -> TaskRecord:
        """Register a chat task and start it on a background daemon thread.

        Args:
            username, workspace_id: task owner; at most one active task is
                allowed per (username, workspace_id) pair.
            message, images, conversation_id: chat input passed to the worker.
            model_key, thinking_mode, run_mode, max_iterations: optional
                per-task overrides applied to the terminal before running.

        Returns:
            The new ``TaskRecord``, already in status "running".

        Raises:
            ValueError: ``run_mode`` is not one of fast/thinking/deep.
            RuntimeError: the workspace already has an active task, or the
                worker thread could not be started.
        """
        if run_mode:
            normalized = str(run_mode).lower()
            if normalized not in {"fast", "thinking", "deep"}:
                raise ValueError("run_mode 只支持 fast/thinking/deep")
            run_mode = normalized

        # Snapshot the caller's session so the worker thread (which runs
        # outside this request) can restore it. Best effort: if it cannot be
        # read, the worker simply gets an empty snapshot.
        try:
            session_snapshot = {
                "username": session.get("username"),
                "role": session.get("role"),
                "is_api_user": session.get("is_api_user"),
                "workspace_id": workspace_id,
                "run_mode": session.get("run_mode"),
                "thinking_mode": session.get("thinking_mode"),
                "model_key": session.get("model_key"),
            }
        except Exception:
            session_snapshot = {}

        with self._lock:
            # Per-workspace mutual exclusion. Check and insert happen under one
            # lock acquisition so two concurrent requests cannot both pass.
            busy = any(
                t.username == username
                and t.workspace_id == workspace_id
                and t.status in self._ACTIVE_STATUSES
                for t in self._tasks.values()
            )
            if busy:
                raise RuntimeError("已有运行中的任务,请稍后再试。")
            task_id = str(uuid.uuid4())
            record = TaskRecord(
                task_id,
                username,
                workspace_id,
                message,
                conversation_id,
                model_key,
                thinking_mode,
                run_mode,
                max_iterations,
            )
            record.session_data = session_snapshot
            thread = threading.Thread(target=self._run_chat_task, args=(record, images), daemon=True)
            record.thread = thread
            record.status = "running"
            record.updated_at = time.time()
            self._tasks[task_id] = record

        try:
            thread.start()
        except Exception as exc:
            # Without this the record would stay "running" forever and block
            # the workspace even though nothing is executing it.
            with self._lock:
                record.status = "failed"
                record.error = str(exc)
                record.updated_at = time.time()
            raise
        return record

    def get_task(self, username: str, task_id: str) -> Optional[TaskRecord]:
        """Return the task if it exists and belongs to *username*, else None."""
        with self._lock:
            rec = self._tasks.get(task_id)
        if not rec or rec.username != username:
            return None
        return rec

    def list_tasks(self, username: str, workspace_id: Optional[str] = None) -> List[TaskRecord]:
        """Return *username*'s tasks, optionally limited to one workspace."""
        with self._lock:
            return [
                rec
                for rec in self._tasks.values()
                if rec.username == username and (workspace_id is None or rec.workspace_id == workspace_id)
            ]

    def cancel_task(self, username: str, task_id: str) -> bool:
        """Ask an active task to stop.

        Returns False if *task_id* does not exist for *username*, else True.
        For an active task this sets ``rec.stop_requested``, sets
        ``stop_flags[task_id]["stop"]`` (the flag chat_flow checks, per the
        original comment), calls ``cancel()`` on the entry's ``task`` when it
        has one, and marks the record "cancel_requested". A task that has
        already finished is left untouched.
        """
        rec = self.get_task(username, task_id)
        if not rec:
            return False
        with self._lock:
            if rec.status not in self._ACTIVE_STATUSES:
                # Already finished. The worker has already popped its stop flag,
                # so a new one would leak. Overwriting the final status with
                # "cancel_requested" would leave it stuck (and block the workspace).
                return True
            # The worker records its final status under this same lock and pops
            # the flag only afterwards. So a flag raised here is always cleaned up.
            rec.stop_requested = True
            entry = stop_flags.get(task_id)
            if not isinstance(entry, dict):
                entry = {'stop': False, 'task': None, 'terminal': None}
                stop_flags[task_id] = entry
            entry['stop'] = True
            rec.status = "cancel_requested"
            rec.updated_at = time.time()
            running = entry.get('task')
        # Best effort, outside the lock: a failed cancel still leaves the stop
        # flag set for chat_flow to observe.
        if running and hasattr(running, "cancel"):
            try:
                running.cancel()
            except Exception as exc:
                debug_log(f"[Task] 取消任务失败 {task_id}: {exc}")
        return True

    # ---- internal helpers ----
    def _append_event(self, rec: TaskRecord, event_type: str, data: Dict[str, Any]):
        """Append an event with the next sequential ``idx`` and bump ``updated_at``.

        ``idx`` continues from the last held event, so indices stay increasing
        even after the bounded deque drops its oldest entries.
        """
        with self._lock:
            idx = rec.events[-1]["idx"] + 1 if rec.events else 0
            rec.events.append({
                "idx": idx,
                "type": event_type,
                "data": data,
                "ts": time.time(),
            })
            rec.updated_at = time.time()

    @staticmethod
    def _is_stop_flagged(task_id: str) -> bool:
        """Return True if ``stop_flags`` holds a dict entry for *task_id* with a truthy "stop".

        Non-dict entries count as "not flagged", matching how ``cancel_task``
        treats them.
        """
        entry = stop_flags.get(task_id)
        return isinstance(entry, dict) and bool(entry.get("stop"))

    @staticmethod
    def _apply_terminal_overrides(rec: TaskRecord, terminal: Any) -> None:
        """Apply the task's model / run-mode / iteration overrides to *terminal*.

        Each override is best effort: a failure is logged (or, for
        max_iterations, reset to None) and the remaining overrides still apply.
        An explicit ``run_mode`` takes precedence over ``thinking_mode``.
        """
        if rec.model_key:
            try:
                terminal.set_model(rec.model_key)
            except Exception as exc:
                debug_log(f"[Task] 设置模型失败 {rec.model_key}: {exc}")
        if rec.run_mode:
            try:
                terminal.set_run_mode(rec.run_mode)
            except Exception as exc:
                debug_log(f"[Task] 设置运行模式失败 {rec.run_mode}: {exc}")
        elif rec.thinking_mode is not None:
            try:
                terminal.set_run_mode("thinking" if rec.thinking_mode else "fast")
            except Exception as exc:
                debug_log(f"[Task] 设置思考模式失败: {exc}")
        if rec.max_iterations:
            try:
                terminal.max_iterations_override = int(rec.max_iterations)
            except Exception:
                terminal.max_iterations_override = None

    def _run_chat_task(self, rec: TaskRecord, images: List[Any]):
        """Worker-thread body: run one chat turn and record its final status.

        The final status is "succeeded", "canceled" (a stop was requested), or
        "failed" (any exception; also emitted as an "error" event). Always
        removes the task's ``stop_flags`` entry and any
        ``max_iterations_override`` left on the terminal.
        """
        username = rec.username
        workspace_id = rec.workspace_id
        terminal = None
        stop_hint = False
        try:
            # Imported lazily; presumably to avoid an import cycle with the app
            # module — TODO confirm.
            from server.app import app as flask_app
            # Build a minimal request context for this thread and restore the
            # caller's session snapshot into it.
            with flask_app.test_request_context():
                try:
                    for k, v in (rec.session_data or {}).items():
                        if v is not None:
                            session[k] = v
                except Exception as exc:
                    debug_log(f"[Task] 恢复 session 失败: {exc}")
                terminal, workspace = get_user_resources(username, workspace_id=workspace_id)
                if not terminal or not workspace:
                    raise RuntimeError("系统未初始化")
                stop_hint = self._is_stop_flagged(rec.task_id)
                self._apply_terminal_overrides(rec, terminal)
                # Make sure the target conversation is loaded into the terminal.
                try:
                    conversation_id, _ = ensure_conversation_loaded(terminal, rec.conversation_id, workspace=workspace)
                    rec.conversation_id = conversation_id
                except Exception as exc:
                    raise RuntimeError(f"对话加载失败: {exc}") from exc

                def sender(event_type, data):
                    # Record the event for pollers, then push it live to the
                    # user's room user_{username}. The live push is best effort
                    # and failures are ignored.
                    self._append_event(rec, event_type, data)
                    try:
                        from .extensions import socketio
                        socketio.emit(event_type, data, room=f"user_{username}")
                    except Exception:
                        pass

                # The task id doubles as client_sid so the stop_flags entry
                # raised by cancel_task is found by chat_flow.
                run_chat_task_sync(
                    terminal=terminal,
                    message=rec.message,
                    images=images,
                    sender=sender,
                    client_sid=rec.task_id,
                    workspace=workspace,
                    username=username,
                )
                canceled_flag = rec.stop_requested or stop_hint or self._is_stop_flagged(rec.task_id)
                with self._lock:
                    rec.status = "canceled" if canceled_flag else "succeeded"
                    rec.updated_at = time.time()
        except Exception as exc:
            debug_log(f"[Task] 后台任务失败: {exc}")
            self._append_event(rec, "error", {"message": str(exc)})
            with self._lock:
                rec.status = "failed"
                rec.error = str(exc)
                rec.updated_at = time.time()
        finally:
            # Runs after the final status is recorded, so no cancel_task call
            # can raise a new flag for this task after it is popped.
            stop_flags.pop(rec.task_id, None)
            # max_iterations_override is a one-shot setting; drop it from the terminal.
            if terminal and hasattr(terminal, "max_iterations_override"):
                try:
                    delattr(terminal, "max_iterations_override")
                except Exception:
                    terminal.max_iterations_override = None
# Module-level TaskManager shared by every /api/tasks route below; task state
# lives in this process's memory only.
task_manager = TaskManager()
# Blueprint that the /api/tasks routes below are attached to.
tasks_bp = Blueprint("tasks", __name__)
@tasks_bp.route("/api/tasks", methods=["GET"])
@api_login_required
def list_tasks_api():
    """List every task of the current user (across all workspaces), newest first."""

    def summarize(task):
        # Public summary of one task; events are only returned by the detail endpoint.
        return {
            "task_id": task.task_id,
            "status": task.status,
            "created_at": task.created_at,
            "updated_at": task.updated_at,
            "message": task.message,
            "conversation_id": task.conversation_id,
            "error": task.error,
        }

    owner = get_current_username()
    newest_first = sorted(
        task_manager.list_tasks(owner),
        key=lambda task: task.created_at,
        reverse=True,
    )
    return jsonify({"success": True, "data": [summarize(t) for t in newest_first]})
@tasks_bp.route("/api/tasks", methods=["POST"])
@api_login_required
def create_task_api():
    """Create a background chat task in the current user's session workspace.

    JSON body fields: message, images, conversation_id, model_key,
    thinking_mode, run_mode ("fast" / "thinking" / "deep"), max_iterations.

    Responses:
        202 with task_id/status/created_at/conversation_id on success.
        400 when both message and images are empty, or run_mode is invalid.
        409 when the task cannot be started (e.g. the workspace already has an
            active task).
    """
    username = get_current_username()
    workspace_id = session.get("workspace_id") or "default"
    # silent=True: a missing or malformed JSON body is treated like an empty
    # one instead of aborting the request. Non-object JSON is ignored too.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    message = str(payload.get("message") or "").strip()
    images = payload.get("images") or []
    conversation_id = payload.get("conversation_id")
    if not message and not images:
        return jsonify({"success": False, "error": "消息不能为空"}), 400
    model_key = payload.get("model_key")
    thinking_mode = payload.get("thinking_mode")
    # Previously dropped, which made create_chat_task's run_mode support unreachable.
    run_mode = payload.get("run_mode")
    max_iterations = payload.get("max_iterations")
    try:
        rec = task_manager.create_chat_task(
            username,
            workspace_id,
            message,
            images,
            conversation_id,
            model_key=model_key,
            thinking_mode=thinking_mode,
            run_mode=run_mode,
            max_iterations=max_iterations,
        )
    except ValueError as exc:
        # Invalid run_mode: a client error, not a server failure.
        return jsonify({"success": False, "error": str(exc)}), 400
    except RuntimeError as exc:
        return jsonify({"success": False, "error": str(exc)}), 409
    return jsonify({
        "success": True,
        "data": {
            "task_id": rec.task_id,
            "status": rec.status,
            "created_at": rec.created_at,
            "conversation_id": rec.conversation_id,
        }
    }), 202
@tasks_bp.route("/api/tasks/<task_id>", methods=["GET"])
@api_login_required
def get_task_api(task_id: str):
    """Return one of the current user's tasks plus its events with ``idx >= from``.

    Query params:
        from: first event index to return (default 0; non-integer -> 0).

    ``next_offset`` is the ``from`` value that returns only events newer than
    this response. Responds 404 if the task does not exist for this user.
    """
    username = get_current_username()
    rec = task_manager.get_task(username, task_id)
    if not rec:
        return jsonify({"success": False, "error": "任务不存在"}), 404
    try:
        offset = int(request.args.get("from", 0))
    except (TypeError, ValueError):
        offset = 0
    # Snapshot under the manager's lock. The worker thread appends to
    # rec.events (a deque, which raises RuntimeError if mutated while being
    # iterated) and updates the status under this same lock. Reading both
    # together also means a terminal status is never reported with an event
    # list that is missing the events appended before it.
    with task_manager._lock:
        events = [e for e in rec.events if e["idx"] >= offset]
        status = rec.status
        updated_at = rec.updated_at
        conversation_id = rec.conversation_id
        error = rec.error
    next_offset = events[-1]["idx"] + 1 if events else offset
    return jsonify({
        "success": True,
        "data": {
            "task_id": rec.task_id,
            "status": status,
            "created_at": rec.created_at,
            "updated_at": updated_at,
            "message": rec.message,
            "conversation_id": conversation_id,
            "error": error,
            "events": events,
            "next_offset": next_offset,
        }
    })
@tasks_bp.route("/api/tasks/<task_id>/cancel", methods=["POST"])
@api_login_required
def cancel_task_api(task_id: str):
    """Request cancellation of one of the current user's tasks.

    Responds 404 when the task does not exist for this user. Otherwise it
    responds {"success": true} right away; the stop is signalled, not awaited.
    """
    found = task_manager.cancel_task(get_current_username(), task_id)
    if found:
        return jsonify({"success": True})
    return jsonify({"success": False, "error": "任务不存在"}), 404