202 lines
6.9 KiB
Python
202 lines
6.9 KiB
Python
# 文件位置: app/services/task_manager.py
|
||
# 文件名: task_manager.py
|
||
|
||
"""
|
||
任务管理器
|
||
替代 Celery 的轻量级任务队列实现
|
||
"""
|
||
import uuid
|
||
import logging
|
||
import threading
|
||
from concurrent.futures import ThreadPoolExecutor, Future
|
||
from typing import Dict, Any, Callable, Optional, List
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
|
||
# Module-level logger, named after this module so handlers/levels can target it.
logger = logging.getLogger(name=__name__)
|
||
|
||
class TaskStatus(Enum):
    """Lifecycle state of a task; each value is the lowercase member name."""

    # States before / during execution.
    PENDING = "pending"
    RUNNING = "running"

    # States a task ends in.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||
|
||
class TaskInfo:
    """Bookkeeping record for a single submitted task.

    Stores the task's identity and the call it represents, its lifecycle
    status with timestamps, and slots for its outcome (result, error text)
    and the Future it runs under.
    """

    def __init__(self, task_id: str, func_name: str, args: tuple, kwargs: dict):
        # Identity and the submitted call.
        self.id = task_id
        self.func_name = func_name
        self.args, self.kwargs = args, kwargs

        # Lifecycle: a new record is PENDING and stamped with its creation time;
        # start/completion times are unknown until the task progresses.
        self.status = TaskStatus.PENDING
        self.created_at = datetime.now()
        self.started_at: Optional[datetime] = None
        self.completed_at: Optional[datetime] = None

        # Outcome slots, all empty at construction.
        self.result: Any = None
        self.error: Optional[str] = None
        self.future: Optional[Future] = None
|
||
|
||
class TaskManager:
    """Process-wide singleton that runs callables on a thread pool and tracks them.

    A lightweight replacement for Celery: every submitted callable gets a
    UUID task id and a ``TaskInfo`` record, and tasks can be grouped by a
    session id for listing and bulk cancellation.

    ``tasks`` and ``session_tasks`` are read and mutated from several threads
    (submitters, status queries, cancellation, cleanup), so all access to them
    goes through ``self._state_lock``. Without it, ``cleanup_old_tasks`` could
    raise ``RuntimeError: dictionary changed size during iteration`` while a
    concurrent ``submit_task`` inserts a task.
    """

    _instance = None
    # Guards singleton creation and one-time initialization.
    _lock = threading.Lock()

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        # Double-checked locking: skip taking the lock once the instance exists.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize shared state once; later ``TaskManager()`` calls are no-ops."""
        # Python runs __init__ on every TaskManager() call even though __new__
        # returns the same object; check-and-set under the class lock so two
        # racing first calls cannot each build an executor.
        with self._lock:
            if hasattr(self, 'initialized'):
                return
            self.executor = ThreadPoolExecutor(max_workers=10)
            self.tasks: Dict[str, TaskInfo] = {}
            self.session_tasks: Dict[str, List[str]] = {}  # session_id -> task_ids
            # Re-entrant because some public methods call others while holding it.
            self._state_lock = threading.RLock()
            self.initialized = True
        logger.info("任务管理器初始化完成")

    def submit_task(self, func: Callable, *args, **kwargs) -> str:
        """Schedule ``func(*args, **kwargs)`` on the thread pool.

        The task is grouped under a session when one can be identified: an
        explicit ``session_id`` keyword argument wins; otherwise a first
        positional ``str`` argument containing ``'-'`` is assumed to be a
        UUID-style session id (heuristic kept for existing callers).

        Args:
            func: Callable to run in a worker thread.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            The new task id (a UUID4 string).

        Raises:
            Whatever ``executor.submit`` raises (e.g. ``RuntimeError`` after
            ``shutdown()``); the task is not left registered in that case.
        """
        import functools

        task_id = str(uuid.uuid4())
        task_info = TaskInfo(task_id, func.__name__, args, kwargs)
        session_id = self._extract_session_id(args, kwargs)
        # Bind the user's arguments up front so keyword names such as
        # ``func``/``task_info`` cannot collide with _execute_task's parameters.
        call = functools.partial(func, *args, **kwargs)

        with self._state_lock:
            self.tasks[task_id] = task_info
            if session_id:
                self.session_tasks.setdefault(session_id, []).append(task_id)
            try:
                # Submitting while holding the lock guarantees the future is
                # attached before another thread can see the task and cancel it.
                task_info.future = self.executor.submit(self._execute_task, task_info, call)
            except Exception:
                # Don't leave a PENDING record that will never run.
                self._forget_task(task_id)
                raise

        logger.info(f"任务提交成功: {task_id} - {func.__name__}")
        return task_id

    @staticmethod
    def _extract_session_id(args: tuple, kwargs: dict) -> Optional[str]:
        """Return the session id a task should be grouped under, or ``None``."""
        # An explicit keyword takes precedence, so a dashed first argument
        # (e.g. a file name) cannot hijack the grouping.
        explicit = kwargs.get('session_id')
        if explicit:
            return explicit
        if args and isinstance(args[0], str) and '-' in args[0]:
            # Heuristic: assume a UUID-formatted first argument is the session id.
            return args[0]
        return None

    def _forget_task(self, task_id: str) -> None:
        """Drop ``task_id`` from both registries; caller must hold ``_state_lock``."""
        self.tasks.pop(task_id, None)
        for session_id in list(self.session_tasks):
            ids = self.session_tasks[session_id]
            if task_id in ids:
                ids.remove(task_id)
            if not ids:
                del self.session_tasks[session_id]

    def _execute_task(self, task_info: TaskInfo, func: Callable, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` in a worker thread, recording its lifecycle.

        Returns the callable's result, which also becomes the Future's result.
        Any ``Exception`` is recorded on ``task_info`` (status FAILED, message in
        ``error``), logged with its traceback, and re-raised so the Future
        carries it as well.
        """
        task_info.status = TaskStatus.RUNNING
        task_info.started_at = datetime.now()
        logger.info(f"任务开始执行: {task_info.id} - {task_info.func_name}")

        try:
            result = func(*args, **kwargs)
        except Exception as e:
            task_info.status = TaskStatus.FAILED
            task_info.completed_at = datetime.now()
            task_info.error = str(e)
            # logger.exception logs at ERROR level and keeps the traceback.
            logger.exception(f"任务执行失败: {task_info.id} - {e}")
            raise

        # Fill in the outcome before flipping the status, so a reader that
        # observes COMPLETED also observes the result and completion time.
        task_info.result = result
        task_info.completed_at = datetime.now()
        task_info.status = TaskStatus.COMPLETED
        logger.info(f"任务执行成功: {task_info.id}")
        return result

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a JSON-friendly status snapshot of ``task_id``, or ``None`` if unknown.

        The snapshot has keys ``task_id``, ``status``, ``func_name``,
        ``created_at``, ``started_at``, ``completed_at`` (ISO-8601 strings or
        ``None``) and ``error``.
        """
        with self._state_lock:
            task_info = self.tasks.get(task_id)
        if task_info is None:
            return None

        return {
            "task_id": task_info.id,
            "status": task_info.status.value,
            "func_name": task_info.func_name,
            "created_at": task_info.created_at.isoformat(),
            "started_at": task_info.started_at.isoformat() if task_info.started_at else None,
            "completed_at": task_info.completed_at.isoformat() if task_info.completed_at else None,
            "error": task_info.error,
        }

    def get_session_tasks(self, session_id: str) -> List[Dict[str, Any]]:
        """Return status snapshots for every still-registered task of ``session_id``."""
        with self._state_lock:
            statuses = []
            for task_id in self.session_tasks.get(session_id, []):
                # Look up once (the original looked up twice per task).
                status = self.get_task_status(task_id)
                if status is not None:
                    statuses.append(status)
            return statuses

    def cancel_task(self, task_id: str) -> bool:
        """Cancel ``task_id`` if it is still queued.

        Returns ``True`` only when this call cancelled the task. Unknown
        tasks, tasks that are running or finished, and tasks already
        cancelled all return ``False``.
        """
        with self._state_lock:
            task_info = self.tasks.get(task_id)
            if task_info is None:
                return False
            future = task_info.future
            # done() is also True for an already-cancelled future, so a repeat
            # cancel is not counted twice; cancel() fails once the task runs.
            if future is None or future.done() or not future.cancel():
                return False
            task_info.status = TaskStatus.CANCELLED
            task_info.completed_at = datetime.now()

        logger.info(f"任务已取消: {task_id}")
        return True

    def cancel_session_tasks(self, session_id: str) -> int:
        """Cancel every queued task of ``session_id``; return how many were cancelled."""
        # Snapshot the ids so cleanup running concurrently cannot mutate the
        # list we iterate.
        with self._state_lock:
            task_ids = list(self.session_tasks.get(session_id, []))
        return sum(1 for task_id in task_ids if self.cancel_task(task_id))

    def cleanup_old_tasks(self, hours: int = 24):
        """Forget tasks that completed more than ``hours`` hours ago.

        Tasks without a completion time (pending or running) are kept.
        Sessions left with no tasks are removed as well.

        Returns:
            The number of tasks removed.
        """
        cutoff_time = datetime.now().timestamp() - hours * 3600

        with self._state_lock:
            expired = {
                task_id
                for task_id, task_info in self.tasks.items()
                if task_info.completed_at and task_info.completed_at.timestamp() < cutoff_time
            }
            for task_id in expired:
                del self.tasks[task_id]

            # One pass over the sessions (instead of one per removed task);
            # drop sessions that became empty so session_tasks stays bounded.
            for session_id in list(self.session_tasks):
                ids = self.session_tasks[session_id]
                ids[:] = [task_id for task_id in ids if task_id not in expired]
                if not ids:
                    del self.session_tasks[session_id]

        logger.info(f"清理了 {len(expired)} 个旧任务")
        return len(expired)

    def shutdown(self):
        """Shut the thread pool down, waiting for already-submitted tasks to finish."""
        self.executor.shutdown(wait=True)
        logger.info("任务管理器已关闭")
|
||
|
||
# Global, module-level handle on the TaskManager singleton.
task_manager: TaskManager = TaskManager()
|
||
|
||
# Decorator: turn a plain function into an asynchronous (background) task.
def async_task(func):
    """Make calls to ``func`` run in the background via ``task_manager``.

    Calling the decorated function does not run ``func`` inline. Instead it
    submits ``func(*args, **kwargs)`` to the module-level ``task_manager`` and
    returns the new task id string. ``.delay(...)`` is an alias for the same
    call, so Celery-style call sites keep working.

    ``functools.wraps`` makes the wrapper carry ``func``'s metadata:
    ``__name__``, ``__qualname__``, ``__doc__``, ``__module__`` and
    ``__dict__``, plus ``__wrapped__``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return task_manager.submit_task(func, *args, **kwargs)

    # Celery compatibility: ``fn.delay(...)`` behaves exactly like ``fn(...)``.
    wrapper.delay = wrapper
    return wrapper