deepresearch/所有文件/task_manager.py
2025-07-02 15:35:36 +08:00

202 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 文件位置: app/services/task_manager.py
# 文件名: task_manager.py
"""
任务管理器
替代 Celery 的轻量级任务队列实现
"""
import uuid
import logging
import threading
from concurrent.futures import ThreadPoolExecutor, Future
from typing import Dict, Any, Callable, Optional, List
from datetime import datetime
from enum import Enum
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class TaskStatus(Enum):
    """Lifecycle state of a submitted task; the string value is what status queries report."""
    PENDING = "pending"  # initial state when the task record is created
    RUNNING = "running"  # a worker has started executing the task function
    COMPLETED = "completed"  # the task function returned normally
    FAILED = "failed"  # the task function raised an exception
    CANCELLED = "cancelled"  # the task was cancelled before it started running
class TaskInfo:
    """Bookkeeping record for one submitted task.

    Stores the submitted call (function name plus positional and keyword
    arguments), its current status, lifecycle timestamps, and its outcome
    (return value or error message). The executor ``Future`` is attached
    after the task has been handed to the thread pool.
    """

    def __init__(self, task_id: str, func_name: str, args: tuple, kwargs: dict):
        # What is being run, and under which id.
        self.id, self.func_name = task_id, func_name
        self.args, self.kwargs = args, kwargs

        # Every record begins life as pending, stamped with its creation time.
        self.status, self.created_at = TaskStatus.PENDING, datetime.now()

        # Filled in as the task starts and finishes.
        self.started_at: Optional[datetime] = None
        self.completed_at: Optional[datetime] = None

        # Outcome and the executor handle, set later by the manager.
        self.result: Any = None
        self.error: Optional[str] = None
        self.future: Optional[Future] = None
class TaskManager:
"""任务管理器单例"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not hasattr(self, 'initialized'):
self.executor = ThreadPoolExecutor(max_workers=10)
self.tasks: Dict[str, TaskInfo] = {}
self.session_tasks: Dict[str, List[str]] = {} # session_id -> task_ids
self.initialized = True
logger.info("任务管理器初始化完成")
def submit_task(self, func: Callable, *args, **kwargs) -> str:
"""提交任务"""
task_id = str(uuid.uuid4())
task_info = TaskInfo(task_id, func.__name__, args, kwargs)
# 提取session_id如果存在
session_id = None
if args and isinstance(args[0], str) and '-' in args[0]:
# 假设第一个参数是session_idUUID格式
session_id = args[0]
elif 'session_id' in kwargs:
session_id = kwargs['session_id']
# 记录任务
self.tasks[task_id] = task_info
# 关联到session
if session_id:
if session_id not in self.session_tasks:
self.session_tasks[session_id] = []
self.session_tasks[session_id].append(task_id)
# 提交执行
future = self.executor.submit(self._execute_task, task_info, func, *args, **kwargs)
task_info.future = future
logger.info(f"任务提交成功: {task_id} - {func.__name__}")
return task_id
def _execute_task(self, task_info: TaskInfo, func: Callable, *args, **kwargs):
"""执行任务"""
try:
task_info.status = TaskStatus.RUNNING
task_info.started_at = datetime.now()
logger.info(f"任务开始执行: {task_info.id} - {task_info.func_name}")
# 执行任务
result = func(*args, **kwargs)
# 更新任务信息
task_info.status = TaskStatus.COMPLETED
task_info.completed_at = datetime.now()
task_info.result = result
logger.info(f"任务执行成功: {task_info.id}")
return result
except Exception as e:
task_info.status = TaskStatus.FAILED
task_info.completed_at = datetime.now()
task_info.error = str(e)
logger.error(f"任务执行失败: {task_info.id} - {e}")
raise
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
"""获取任务状态"""
if task_id not in self.tasks:
return None
task_info = self.tasks[task_id]
return {
"task_id": task_info.id,
"status": task_info.status.value,
"func_name": task_info.func_name,
"created_at": task_info.created_at.isoformat(),
"started_at": task_info.started_at.isoformat() if task_info.started_at else None,
"completed_at": task_info.completed_at.isoformat() if task_info.completed_at else None,
"error": task_info.error
}
def get_session_tasks(self, session_id: str) -> List[Dict[str, Any]]:
"""获取会话的所有任务"""
task_ids = self.session_tasks.get(session_id, [])
return [self.get_task_status(task_id) for task_id in task_ids if self.get_task_status(task_id)]
def cancel_task(self, task_id: str) -> bool:
"""取消任务"""
if task_id not in self.tasks:
return False
task_info = self.tasks[task_id]
if task_info.future and not task_info.future.done():
cancelled = task_info.future.cancel()
if cancelled:
task_info.status = TaskStatus.CANCELLED
task_info.completed_at = datetime.now()
logger.info(f"任务已取消: {task_id}")
return True
return False
def cancel_session_tasks(self, session_id: str) -> int:
"""取消会话的所有任务"""
task_ids = self.session_tasks.get(session_id, [])
cancelled_count = 0
for task_id in task_ids:
if self.cancel_task(task_id):
cancelled_count += 1
return cancelled_count
def cleanup_old_tasks(self, hours: int = 24):
"""清理旧任务"""
cutoff_time = datetime.now().timestamp() - (hours * 3600)
tasks_to_remove = []
for task_id, task_info in self.tasks.items():
if task_info.completed_at and task_info.completed_at.timestamp() < cutoff_time:
tasks_to_remove.append(task_id)
for task_id in tasks_to_remove:
del self.tasks[task_id]
# 从session_tasks中移除
for session_id, task_ids in self.session_tasks.items():
if task_id in task_ids:
task_ids.remove(task_id)
logger.info(f"清理了 {len(tasks_to_remove)} 个旧任务")
return len(tasks_to_remove)
def shutdown(self):
"""关闭任务管理器"""
self.executor.shutdown(wait=True)
logger.info("任务管理器已关闭")
# Global task manager instance, created when this module is imported.
# TaskManager is a singleton, so any later TaskManager() call returns this same object.
task_manager = TaskManager()
# Decorator: turn a plain function into an asynchronous task.
def async_task(func):
    """Decorate ``func`` so that calling it submits it to the global task manager.

    Calling the decorated function, or its Celery-style ``.delay`` alias,
    does not run ``func`` inline. It submits ``func(*args, **kwargs)`` to
    ``task_manager`` and returns the new task id string (not a Celery
    AsyncResult). ``functools.wraps`` keeps the wrapped function's metadata
    (``__name__``, ``__qualname__``, ``__doc__``, ``__module__``,
    ``__wrapped__``) for logging and introspection.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        return task_manager.submit_task(func, *args, **kwargs)

    # Compatibility with Celery's ``task.delay(...)`` calling convention.
    wrapper.delay = wrapper
    return wrapper