# File location: app/services/task_manager.py
# File name: task_manager.py
"""
Task manager.

A lightweight, in-process task queue implementation that replaces Celery.
Tasks are run on a shared thread pool and tracked in memory by task id and,
when one can be detected, by session id.
"""
import functools
import logging
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor, Future
from typing import Dict, Any, Callable, Optional, List
from datetime import datetime
from enum import Enum

logger = logging.getLogger(__name__)


class TaskStatus(Enum):
    """Lifecycle states of a task."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class TaskInfo:
    """Bookkeeping record for one submitted task."""

    def __init__(self, task_id: str, func_name: str, args: tuple, kwargs: dict):
        """Create a record in the PENDING state, stamped with the current time.

        Args:
            task_id: Unique identifier of the task.
            func_name: Display name of the callable being executed.
            args: Positional arguments the callable is invoked with.
            kwargs: Keyword arguments the callable is invoked with.
        """
        self.id = task_id
        self.func_name = func_name
        self.args = args
        self.kwargs = kwargs
        self.status = TaskStatus.PENDING
        self.created_at = datetime.now()
        self.started_at: Optional[datetime] = None
        self.completed_at: Optional[datetime] = None
        self.result: Any = None
        self.error: Optional[str] = None
        self.future: Optional[Future] = None


class TaskManager:
    """Process-wide singleton that runs callables on a thread pool."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialise the executor and registries exactly once.

        ``__init__`` runs on every ``TaskManager()`` call, so the check and
        the initialisation are done under the class lock; otherwise two
        concurrent first calls could each create an executor and reset the
        registries.
        """
        with self._lock:
            if hasattr(self, 'initialized'):
                return
            self.executor = ThreadPoolExecutor(max_workers=10)
            self.tasks: Dict[str, TaskInfo] = {}
            self.session_tasks: Dict[str, List[str]] = {}  # session_id -> task_ids
            # Guards structural changes to the two registries above.
            # Re-entrant because cleanup_old_tasks calls _forget_tasks while
            # already holding it.
            self._registry_lock = threading.RLock()
            self.initialized = True
        logger.info("任务管理器初始化完成")

    def submit_task(self, func: Callable, *args, **kwargs) -> str:
        """Register a task and schedule ``func(*args, **kwargs)`` on the pool.

        The task is associated with a session when the first positional
        argument is a string containing ``-`` (assumed to be a UUID-style
        session id), or otherwise when a ``session_id`` keyword is given.

        Args:
            func: The callable to execute.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            The generated task id.

        Raises:
            Exception: Whatever the executor raises when it refuses the task
                (e.g. ``RuntimeError`` after shutdown). The task registration
                is rolled back in that case.
        """
        task_id = str(uuid.uuid4())
        # Callables such as functools.partial objects have no __name__.
        func_name = getattr(func, "__name__", None) or repr(func)
        task_info = TaskInfo(task_id, func_name, args, kwargs)

        # Extract the session id, if there is one.
        session_id = None
        if args and isinstance(args[0], str) and '-' in args[0]:
            # Assume the first argument is a session id (UUID format).
            session_id = args[0]
        elif 'session_id' in kwargs:
            session_id = kwargs['session_id']

        with self._registry_lock:
            self.tasks[task_id] = task_info
            if session_id:
                self.session_tasks.setdefault(session_id, []).append(task_id)

        try:
            future = self.executor.submit(
                self._execute_task, task_info, func, *args, **kwargs
            )
        except Exception:
            # Do not leave a task that will never run stuck in PENDING.
            self._forget_tasks([task_id])
            raise
        task_info.future = future

        logger.info(f"任务提交成功: {task_id} - {func_name}")
        return task_id

    def _execute_task(self, task_info: TaskInfo, func: Callable, *args, **kwargs):
        """Run ``func`` on a worker thread and record the outcome.

        Returns:
            The value returned by ``func``.

        Raises:
            Exception: Re-raises whatever ``func`` raised, after marking the
                task FAILED, so the exception is also stored on the future.
        """
        try:
            task_info.status = TaskStatus.RUNNING
            task_info.started_at = datetime.now()
            logger.info(f"任务开始执行: {task_info.id} - {task_info.func_name}")

            result = func(*args, **kwargs)

            task_info.status = TaskStatus.COMPLETED
            task_info.completed_at = datetime.now()
            task_info.result = result
            logger.info(f"任务执行成功: {task_info.id}")
            return result
        except Exception as e:
            task_info.status = TaskStatus.FAILED
            task_info.completed_at = datetime.now()
            task_info.error = str(e)
            # logger.exception logs at ERROR level and keeps the traceback.
            logger.exception(f"任务执行失败: {task_info.id} - {e}")
            raise

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a serialisable snapshot of a task, or ``None`` if unknown."""
        task_info = self.tasks.get(task_id)
        if task_info is None:
            return None

        started_at = task_info.started_at
        completed_at = task_info.completed_at
        return {
            "task_id": task_info.id,
            "status": task_info.status.value,
            "func_name": task_info.func_name,
            "created_at": task_info.created_at.isoformat(),
            "started_at": started_at.isoformat() if started_at else None,
            "completed_at": completed_at.isoformat() if completed_at else None,
            "error": task_info.error
        }

    def get_session_tasks(self, session_id: str) -> List[Dict[str, Any]]:
        """Return status snapshots for every known task of a session.

        Task ids whose record no longer exists are skipped.
        """
        with self._registry_lock:
            task_ids = list(self.session_tasks.get(session_id, []))
        snapshots = (self.get_task_status(task_id) for task_id in task_ids)
        return [snapshot for snapshot in snapshots if snapshot is not None]

    def cancel_task(self, task_id: str) -> bool:
        """Try to cancel a task.

        Returns:
            ``True`` only if the task's future accepted the cancellation;
            ``False`` for unknown ids, tasks whose future is already done,
            and tasks whose future refuses to cancel.
        """
        task_info = self.tasks.get(task_id)
        if task_info is None:
            return False

        future = task_info.future
        if future and not future.done():
            if future.cancel():
                task_info.status = TaskStatus.CANCELLED
                task_info.completed_at = datetime.now()
                logger.info(f"任务已取消: {task_id}")
                return True

        return False

    def cancel_session_tasks(self, session_id: str) -> int:
        """Cancel every task of a session; return how many were cancelled."""
        with self._registry_lock:
            task_ids = list(self.session_tasks.get(session_id, []))
        return sum(1 for task_id in task_ids if self.cancel_task(task_id))

    def cleanup_old_tasks(self, hours: int = 24):
        """Forget tasks that finished more than ``hours`` hours ago.

        Tasks without a completion time are kept. Sessions left without any
        task are removed from the session index as well.

        Returns:
            The number of task records removed.
        """
        cutoff_time = datetime.now().timestamp() - (hours * 3600)

        # Hold the lock across scan + removal so a concurrent submit_task
        # cannot resize the dict while it is being iterated.
        with self._registry_lock:
            tasks_to_remove = [
                task_id
                for task_id, task_info in self.tasks.items()
                if task_info.completed_at
                and task_info.completed_at.timestamp() < cutoff_time
            ]
            self._forget_tasks(tasks_to_remove)

        logger.info(f"清理了 {len(tasks_to_remove)} 个旧任务")
        return len(tasks_to_remove)

    def _forget_tasks(self, task_ids) -> None:
        """Drop the given task ids from both registries."""
        removed = set(task_ids)
        if not removed:
            return

        with self._registry_lock:
            for task_id in removed:
                self.tasks.pop(task_id, None)

            # Iterate over a copy of the keys: empty sessions are deleted.
            for session_id in list(self.session_tasks):
                session_ids = self.session_tasks[session_id]
                remaining = [tid for tid in session_ids if tid not in removed]
                if remaining:
                    # Update in place so existing references stay valid.
                    session_ids[:] = remaining
                else:
                    del self.session_tasks[session_id]

    def shutdown(self):
        """Shut the executor down, waiting for it to finish."""
        self.executor.shutdown(wait=True)
        logger.info("任务管理器已关闭")


# Global task manager instance.
task_manager = TaskManager()


# Decorator: turn a plain function into an asynchronously submitted task.
def async_task(func):
    """Wrap ``func`` so that calling it submits a task and returns its id.

    The wrapper also exposes ``.delay(...)`` for compatibility with Celery's
    calling convention; it behaves exactly like calling the wrapper.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return task_manager.submit_task(func, *args, **kwargs)

    wrapper.delay = wrapper  # Compatible with Celery's .delay() call style.
    return wrapper