355 lines
12 KiB
Python
355 lines
12 KiB
Python
"""简单任务 API:将聊天任务与 WebSocket 解耦,支持后台运行与轮询。"""
|
||
from __future__ import annotations
|
||
import time
|
||
import threading
|
||
import uuid
|
||
from collections import deque
|
||
from typing import Dict, Any, Optional, List
|
||
|
||
from flask import Blueprint, request, jsonify
|
||
from flask import current_app, session
|
||
|
||
from .auth_helpers import api_login_required, get_current_username
|
||
from .context import get_user_resources, ensure_conversation_loaded
|
||
from .chat_flow import run_chat_task_sync
|
||
from .state import stop_flags
|
||
from .utils_common import debug_log
|
||
|
||
|
||
class TaskRecord:
    """State of a single background chat task.

    Plain attribute container (``__slots__``, no per-instance ``__dict__``).
    Its mutable fields (``status``, ``events``, ``updated_at``, ``error``,
    ``thread``, ``session_data``, ``stop_requested``) are filled in and
    updated by ``TaskManager`` after construction.
    """

    __slots__ = (
        "task_id", "username", "status", "created_at", "updated_at",
        "message", "conversation_id", "events", "thread", "error",
        "model_key", "thinking_mode", "run_mode", "max_iterations",
        "session_data", "stop_requested",
    )

    def __init__(
        self,
        task_id: str,
        username: str,
        message: str,
        conversation_id: Optional[str],
        model_key: Optional[str],
        thinking_mode: Optional[bool],
        run_mode: Optional[str],
        max_iterations: Optional[int],
    ):
        now = time.time()

        # Identity / ownership.
        self.task_id = task_id
        self.username = username

        # Request parameters as supplied by the caller.
        self.message = message
        self.conversation_id = conversation_id
        self.model_key = model_key
        self.thinking_mode = thinking_mode
        self.run_mode = run_mode
        self.max_iterations = max_iterations

        # Lifecycle; creation and last-update timestamps start out equal.
        self.status = "pending"
        self.created_at = now
        self.updated_at = now
        self.error: Optional[str] = None
        self.stop_requested: bool = False
        self.thread: Optional[threading.Thread] = None

        # Bounded event log: once 1000 entries are held, appending drops the oldest.
        self.events: deque[Dict[str, Any]] = deque(maxlen=1000)
        # Snapshot of request-session values, filled in by the task manager.
        self.session_data: Dict[str, Any] = {}
|
||
|
||
|
||
class TaskManager:
    """Thread-based, in-memory task manager (could later be backed by Redis/DB).

    Tasks are kept in ``self._tasks`` keyed by task id. ``self._lock`` guards
    that dict and the mutable fields of each record (``status``, ``events``,
    ``updated_at``, ``error``). Those fields are shared between the HTTP
    handlers and the worker threads started by ``create_chat_task``.
    """

    # Statuses that occupy the user's single task slot. ``cancel_requested``
    # counts as active because the worker thread keeps running (and keeps
    # using the user's terminal) until it observes the stop flag and finishes.
    _ACTIVE_STATUSES = frozenset({"pending", "running", "cancel_requested"})

    def __init__(self):
        self._tasks: Dict[str, TaskRecord] = {}
        # threading.Lock is NOT reentrant: code that holds it must not call
        # get_task()/list_tasks(), which acquire it themselves.
        self._lock = threading.Lock()

    # ---- public APIs ----

    def create_chat_task(
        self,
        username: str,
        message: str,
        images: List[Any],
        conversation_id: Optional[str],
        model_key: Optional[str] = None,
        thinking_mode: Optional[bool] = None,
        run_mode: Optional[str] = None,
        max_iterations: Optional[int] = None,
    ) -> TaskRecord:
        """Register a chat task for ``username`` and start it on a daemon thread.

        Args:
            username: Owner of the task; at most one active task per user.
            message: User message forwarded to ``run_chat_task_sync``.
            images: Image payloads forwarded unchanged to the worker.
            conversation_id: Conversation to load; the worker resolves it via
                ``ensure_conversation_loaded`` and stores the result back.
            model_key, thinking_mode, max_iterations: Optional per-task
                overrides the worker applies to the user's terminal.
            run_mode: Optional ``fast``/``thinking``/``deep`` (case-insensitive).

        Returns:
            The new record, already in ``running`` state.

        Raises:
            ValueError: ``run_mode`` is not a supported mode.
            RuntimeError: the user already has an active task, or the worker
                thread could not be started.
        """
        if run_mode:
            normalized = str(run_mode).lower()
            if normalized not in {"fast", "thinking", "deep"}:
                raise ValueError("run_mode 只支持 fast/thinking/deep")
            run_mode = normalized

        with self._lock:
            # Check-and-insert in ONE lock acquisition: two concurrent requests
            # from the same user must not both see "no active task".
            if any(
                rec.username == username and rec.status in self._ACTIVE_STATUSES
                for rec in self._tasks.values()
            ):
                raise RuntimeError("已有运行中的任务,请稍后再试。")

            task_id = str(uuid.uuid4())
            record = TaskRecord(
                task_id, username, message, conversation_id,
                model_key, thinking_mode, run_mode, max_iterations,
            )
            # Snapshot the request session so the worker thread can re-install it.
            record.session_data = self._snapshot_session()
            record.thread = threading.Thread(
                target=self._run_chat_task, args=(record, images), daemon=True
            )
            record.status = "running"
            record.updated_at = time.time()
            self._tasks[task_id] = record

        try:
            record.thread.start()
        except Exception as exc:
            # Never leave a "running" record without a worker: it would block
            # this user's task slot forever.
            with self._lock:
                record.status = "failed"
                record.error = str(exc)
                record.updated_at = time.time()
            raise
        return record

    def get_task(self, username: str, task_id: str) -> Optional[TaskRecord]:
        """Return the task if it exists AND belongs to ``username``, else None."""
        with self._lock:
            rec = self._tasks.get(task_id)
            if not rec or rec.username != username:
                return None
            return rec

    def list_tasks(self, username: str) -> List[TaskRecord]:
        """Return all of ``username``'s tasks in insertion order."""
        with self._lock:
            return [rec for rec in self._tasks.values() if rec.username == username]

    def cancel_task(self, username: str, task_id: str) -> bool:
        """Request cancellation of one of ``username``'s tasks.

        Sets the record's ``stop_requested`` flag and the ``stop_flags`` entry
        keyed by task id (the id is passed to ``run_chat_task_sync`` as
        ``client_sid``). It also calls ``cancel()`` on the entry's ``task``
        object when there is one.

        Returns:
            False if no such task exists for this user. True otherwise,
            including for an already-finished task, whose final status is then
            left untouched.
        """
        with self._lock:
            rec = self._tasks.get(task_id)
            if rec is None or rec.username != username:
                return False
            if rec.status not in self._ACTIVE_STATUSES:
                # Already finished: keep its final status and do not create a
                # stop_flags entry that no worker would ever clean up.
                return True
            rec.stop_requested = True
            rec.status = "cancel_requested"
            rec.updated_at = time.time()
            # Done under the same lock the worker holds when popping the entry,
            # so a late cancel cannot re-create it after cleanup.
            entry = stop_flags.get(task_id)
            if not isinstance(entry, dict):
                entry = {'stop': False, 'task': None, 'terminal': None}
                stop_flags[task_id] = entry
            entry['stop'] = True
            inner_task = entry.get('task')

        if inner_task and hasattr(inner_task, "cancel"):
            try:
                inner_task.cancel()
            except Exception as exc:
                # Best effort: the stop flag above is the primary signal.
                debug_log(f"[Task] 取消底层任务失败 {task_id}: {exc}")
        return True

    # ---- internal helpers ----

    def _append_event(self, rec: TaskRecord, event_type: str, data: Dict[str, Any]):
        """Append an event with a monotonically increasing ``idx`` to ``rec.events``."""
        with self._lock:
            # idx continues from the newest retained event, so it keeps growing
            # even after a bounded deque has discarded older entries.
            idx = rec.events[-1]["idx"] + 1 if rec.events else 0
            rec.events.append({
                "idx": idx,
                "type": event_type,
                "data": data,
                "ts": time.time(),
            })
            rec.updated_at = time.time()

    @staticmethod
    def _snapshot_session() -> Dict[str, Any]:
        """Copy selected session keys; returns {} if the session is unavailable."""
        try:
            return {
                key: session.get(key)
                for key in ("username", "role", "is_api_user", "run_mode", "thinking_mode", "model_key")
            }
        except Exception:
            # e.g. called outside a request context.
            return {}

    @staticmethod
    def _restore_session(session_data: Optional[Dict[str, Any]]) -> None:
        """Install non-None snapshot values into the current (synthetic) session."""
        try:
            for key, value in (session_data or {}).items():
                if value is not None:
                    session[key] = value
        except Exception as exc:
            debug_log(f"[Task] 恢复 session 失败: {exc}")

    @staticmethod
    def _stop_flag_set(task_id: str) -> bool:
        """True if ``stop_flags`` holds a dict entry for ``task_id`` with a truthy ``stop``."""
        entry = stop_flags.get(task_id)
        return isinstance(entry, dict) and bool(entry.get("stop"))

    @staticmethod
    def _apply_overrides(terminal, rec: TaskRecord) -> None:
        """Apply the task's model / run-mode / iteration overrides to ``terminal``.

        Each override is best-effort: failures are logged and do not abort the task.
        An explicit ``run_mode`` takes precedence over ``thinking_mode``.
        """
        if rec.model_key:
            try:
                terminal.set_model(rec.model_key)
            except Exception as exc:
                debug_log(f"[Task] 设置模型失败 {rec.model_key}: {exc}")
        if rec.run_mode:
            try:
                terminal.set_run_mode(rec.run_mode)
            except Exception as exc:
                debug_log(f"[Task] 设置运行模式失败 {rec.run_mode}: {exc}")
        elif rec.thinking_mode is not None:
            try:
                terminal.set_run_mode("thinking" if rec.thinking_mode else "fast")
            except Exception as exc:
                debug_log(f"[Task] 设置思考模式失败: {exc}")
        if rec.max_iterations:
            try:
                terminal.max_iterations_override = int(rec.max_iterations)
            except Exception:
                terminal.max_iterations_override = None

    def _run_chat_task(self, rec: TaskRecord, images: List[Any]):
        """Worker-thread body: run one chat task and record its final status.

        Builds a synthetic Flask request context, restores the session
        snapshot, loads the user's terminal/workspace and conversation, then
        calls ``run_chat_task_sync``. Every event is recorded on ``rec`` and
        also emitted to the ``user_{username}`` socket.io room. The final
        status is ``succeeded``, ``canceled`` or ``failed``.
        """
        username = rec.username
        terminal = None
        stop_hint = False
        try:
            from server.app import app as flask_app

            with flask_app.test_request_context():
                self._restore_session(rec.session_data)

                terminal, workspace = get_user_resources(username)
                if not terminal or not workspace:
                    raise RuntimeError("系统未初始化")
                stop_hint = self._stop_flag_set(rec.task_id)

                self._apply_overrides(terminal, rec)

                try:
                    conversation_id, _ = ensure_conversation_loaded(terminal, rec.conversation_id)
                except Exception as exc:
                    raise RuntimeError(f"对话加载失败: {exc}") from exc
                rec.conversation_id = conversation_id

                def sender(event_type, data):
                    # Record the event for pollers of the task API...
                    self._append_event(rec, event_type, data)
                    # ...and push it live to any connected clients of this user.
                    try:
                        from .extensions import socketio

                        socketio.emit(event_type, data, room=f"user_{username}")
                    except Exception:
                        pass

                # The task id doubles as client_sid so the stop_flags check
                # inside the chat flow finds the entry cancel_task sets.
                run_chat_task_sync(
                    terminal=terminal,
                    message=rec.message,
                    images=images,
                    sender=sender,
                    client_sid=rec.task_id,
                    workspace=workspace,
                    username=username,
                )

            with self._lock:
                canceled = rec.stop_requested or stop_hint or self._stop_flag_set(rec.task_id)
                rec.status = "canceled" if canceled else "succeeded"
                rec.updated_at = time.time()
        except Exception as exc:
            debug_log(f"[Task] 后台任务失败: {exc}")
            self._append_event(rec, "error", {"message": str(exc)})
            with self._lock:
                rec.status = "failed"
                rec.error = str(exc)
                rec.updated_at = time.time()
        finally:
            with self._lock:
                if rec.status in self._ACTIVE_STATUSES:
                    # Only reachable when something not derived from Exception
                    # escaped above; never leave the user's task slot occupied.
                    rec.status = "canceled" if rec.stop_requested else "failed"
                    if rec.status == "failed" and rec.error is None:
                        rec.error = "任务异常终止"
                    rec.updated_at = time.time()
                # Popped under the same lock cancel_task uses to insert entries.
                stop_flags.pop(rec.task_id, None)
            # Drop the one-shot iteration override from the shared terminal.
            if terminal and hasattr(terminal, "max_iterations_override"):
                try:
                    delattr(terminal, "max_iterations_override")
                except Exception:
                    terminal.max_iterations_override = None
|
||
|
||
|
||
# Single module-level manager; TaskManager keeps all tasks in an in-memory dict
# owned by this process.
task_manager = TaskManager()

# Blueprint that the /api/tasks route handlers below register on.
tasks_bp = Blueprint("tasks", import_name=__name__)
|
||
|
||
|
||
@tasks_bp.route("/api/tasks", methods=["GET"])
@api_login_required
def list_tasks_api():
    """List the current user's tasks, newest first, without their event logs."""

    def summarize(task):
        # Public, event-free view of a task record.
        return {
            "task_id": task.task_id,
            "status": task.status,
            "created_at": task.created_at,
            "updated_at": task.updated_at,
            "message": task.message,
            "conversation_id": task.conversation_id,
            "error": task.error,
        }

    owned = task_manager.list_tasks(get_current_username())
    newest_first = sorted(owned, key=lambda task: task.created_at, reverse=True)
    return jsonify({"success": True, "data": [summarize(task) for task in newest_first]})
|
||
|
||
|
||
@tasks_bp.route("/api/tasks", methods=["POST"])
@api_login_required
def create_task_api():
    """Start a background chat task for the current user.

    JSON body fields:
        message (str), images (list): at least one must be non-empty.
        conversation_id, model_key, thinking_mode, run_mode, max_iterations:
            optional, forwarded to ``task_manager.create_chat_task``.

    Returns:
        202 with the new task id and status.
        400 when the input is invalid (empty message, bad ``run_mode``).
        409 when the user already has an active task.
    """
    username = get_current_username()
    # silent=True: a missing/invalid JSON body is treated as empty instead of
    # raising; a non-object JSON value (e.g. a list) is treated the same way.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    message = str(payload.get("message") or "").strip()
    images = payload.get("images") or []
    conversation_id = payload.get("conversation_id")
    if not message and not images:
        return jsonify({"success": False, "error": "消息不能为空"}), 400

    try:
        rec = task_manager.create_chat_task(
            username,
            message,
            images,
            conversation_id,
            model_key=payload.get("model_key"),
            thinking_mode=payload.get("thinking_mode"),
            run_mode=payload.get("run_mode"),
            max_iterations=payload.get("max_iterations"),
        )
    except ValueError as exc:
        # create_chat_task rejects unsupported run_mode values.
        return jsonify({"success": False, "error": str(exc)}), 400
    except RuntimeError as exc:
        return jsonify({"success": False, "error": str(exc)}), 409

    return jsonify({
        "success": True,
        "data": {
            "task_id": rec.task_id,
            "status": rec.status,
            "created_at": rec.created_at,
            "conversation_id": rec.conversation_id,
        },
    }), 202
|
||
|
||
|
||
@tasks_bp.route("/api/tasks/<task_id>", methods=["GET"])
@api_login_required
def get_task_api(task_id: str):
    """Return a task's status plus its events with ``idx >= from``.

    Query args:
        from: event index to start from (default 0; non-integers fall back to 0).

    The response's ``next_offset`` is the value to pass as ``from`` on the next
    poll. ``rec.events`` is a bounded deque, so events older than its capacity
    may already have been discarded.
    """
    username = get_current_username()
    rec = task_manager.get_task(username, task_id)
    if not rec:
        return jsonify({"success": False, "error": "任务不存在"}), 404

    try:
        offset = int(request.args.get("from", 0))
    except (TypeError, ValueError):
        offset = 0

    # The worker thread appends to rec.events concurrently. Iterating a deque
    # while it is mutated raises RuntimeError, so read it (and the status
    # fields) under the same lock that TaskManager._append_event holds.
    with task_manager._lock:
        events = [e for e in rec.events if e["idx"] >= offset]
        data = {
            "task_id": rec.task_id,
            "status": rec.status,
            "created_at": rec.created_at,
            "updated_at": rec.updated_at,
            "message": rec.message,
            "conversation_id": rec.conversation_id,
            "error": rec.error,
        }

    data["events"] = events
    data["next_offset"] = events[-1]["idx"] + 1 if events else offset
    return jsonify({"success": True, "data": data})
|
||
|
||
|
||
@tasks_bp.route("/api/tasks/<task_id>/cancel", methods=["POST"])
@api_login_required
def cancel_task_api(task_id: str):
    """Ask the task manager to cancel one of the current user's tasks.

    Responds 404 when the task does not exist or belongs to another user.
    """
    if task_manager.cancel_task(get_current_username(), task_id):
        return jsonify({"success": True})
    return jsonify({"success": False, "error": "任务不存在"}), 404
|