agent-Specialization/modules/sub_agent_manager.py
2025-12-14 04:22:00 +08:00

612 lines
25 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""子智能体任务管理。"""
import json
import shutil
import time
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import httpx
from config import (
OUTPUT_FORMATS,
SUB_AGENT_DEFAULT_TIMEOUT,
SUB_AGENT_MAX_ACTIVE,
SUB_AGENT_PROJECT_RESULTS_DIR,
SUB_AGENT_SERVICE_BASE_URL,
SUB_AGENT_STATE_FILE,
SUB_AGENT_STATUS_POLL_INTERVAL,
SUB_AGENT_TASKS_BASE_DIR,
)
from utils.logger import setup_logger
import logging
# Mute this module's logger completely; sub-agent progress is surfaced
# through front-end prompts / brief_log instead.
logger = setup_logger(__name__)
logger.setLevel(logging.CRITICAL)
logger.disabled = True
logger.propagate = False
# Iterate over a snapshot because removeHandler mutates logger.handlers.
for handler in tuple(logger.handlers):
    logger.removeHandler(handler)

# Statuses treated as "finished" throughout this module; the manager class
# additionally checks for "terminated" where it needs to.
TERMINAL_STATUSES = {"completed", "failed", "timeout"}
class SubAgentManager:
    """Schedule tasks between the main agent and the sub-agent service.

    Task records (``self.tasks``) and the agent ids already used by each
    conversation (``self.conversation_agents``) are kept in memory and
    persisted as JSON in ``SUB_AGENT_STATE_FILE``.  The service is reached
    over HTTP at ``SUB_AGENT_SERVICE_BASE_URL``.
    """

    def __init__(self, project_path: str, data_dir: str):
        """Resolve and create the working directories, then load saved state.

        Args:
            project_path: Root of the project; reference files are read from
                it and deliverables are copied into it.
            data_dir: Data directory; sub-agent conversations are stored in
                ``<data_dir>/conversations/sub_agent``.
        """
        self.project_path = Path(project_path).resolve()
        self.data_dir = Path(data_dir).resolve()
        self.base_dir = Path(SUB_AGENT_TASKS_BASE_DIR).resolve()
        self.results_dir = Path(SUB_AGENT_PROJECT_RESULTS_DIR).resolve()
        self.state_file = Path(SUB_AGENT_STATE_FILE).resolve()
        self.sub_agent_conversations_dir = (self.data_dir / "conversations" / "sub_agent").resolve()

        for directory in (
            self.base_dir,
            self.results_dir,
            self.state_file.parent,
            self.sub_agent_conversations_dir,
        ):
            directory.mkdir(parents=True, exist_ok=True)

        self.tasks: Dict[str, Dict] = {}
        self.conversation_agents: Dict[str, List[int]] = {}
        self._load_state()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def create_sub_agent(
        self,
        *,
        agent_id: int,
        summary: str,
        task: str,
        target_dir: str,
        reference_files: Optional[List[str]] = None,
        timeout_seconds: Optional[int] = None,
        conversation_id: Optional[str] = None,
    ) -> Dict:
        """Create a sub-agent task and start it on the remote service.

        Args:
            agent_id: Positive integer id, unique within the conversation.
            summary: Short task summary (must be non-blank).
            task: Full task description (must be non-blank).
            target_dir: Project-relative directory that will receive the
                deliverables; blank means ``sub_agent_results``.
            reference_files: Project-relative files copied into the task's
                ``references`` directory.
            timeout_seconds: Task timeout; falls back to
                ``SUB_AGENT_DEFAULT_TIMEOUT`` when falsy.
            conversation_id: Id of the parent conversation (required).

        Returns:
            A dict with ``success``.  On success it also carries ``task_id``,
            ``agent_id``, ``status``, ``message``, ``deliverables_dir``,
            ``copied_references`` and ``sub_conversation_id``; on failure it
            carries ``error`` (and ``details`` when the service refused).
        """
        reference_files = reference_files or []
        validation_error = self._validate_create_params(agent_id, summary, task, target_dir)
        if validation_error:
            return {"success": False, "error": validation_error}
        # Validation proved int() succeeds.  Normalise so that "3" and 3 are
        # the same slot when compared against the stored integer ids.
        agent_id = int(agent_id)
        if not conversation_id:
            return {"success": False, "error": "缺少对话ID无法创建子智能体"}
        if not self._ensure_agent_slot_available(conversation_id, agent_id):
            return {
                "success": False,
                "error": f"该对话已使用过编号 {agent_id},请更换新的子智能体代号。"
            }
        if self._active_task_count(conversation_id) >= SUB_AGENT_MAX_ACTIVE:
            return {
                "success": False,
                "error": f"该对话已存在 {SUB_AGENT_MAX_ACTIVE} 个运行中的子智能体,请稍后再试。",
            }

        task_id = self._generate_task_id(agent_id)
        task_root = self.base_dir / task_id
        workspace_dir = task_root / "workspace"
        references_dir = workspace_dir / "references"
        deliverables_dir = workspace_dir / "deliverables"
        for path in (task_root, references_dir, deliverables_dir, workspace_dir):
            path.mkdir(parents=True, exist_ok=True)

        # From here on the task folder exists on disk, so every failure path
        # must remove it again instead of leaking an orphaned workspace.
        copied_refs, copy_errors = self._copy_reference_files(reference_files, references_dir)
        if copy_errors:
            self._cleanup_task_folder(task_root)
            return {"success": False, "error": "; ".join(copy_errors)}
        try:
            target_project_dir = self._ensure_project_subdir(target_dir)
        except ValueError as exc:
            self._cleanup_task_folder(task_root)
            return {"success": False, "error": str(exc)}

        timeout_seconds = timeout_seconds or SUB_AGENT_DEFAULT_TIMEOUT
        payload = {
            "task_id": task_id,
            "agent_id": agent_id,
            "summary": summary,
            "task": task,
            "target_project_dir": str(target_project_dir),
            "workspace_dir": str(workspace_dir),
            "references_dir": str(references_dir),
            "deliverables_dir": str(deliverables_dir),
            "timeout_seconds": timeout_seconds,
            "parent_conversation_id": conversation_id,
            "data_dir": str(self.data_dir),
            "reference_manifest": copied_refs,
            "conversation_storage_dir": str(self.sub_agent_conversations_dir),
        }
        service_response = self._call_service("POST", "/tasks", payload, timeout_seconds + 5)
        if not service_response.get("success"):
            self._cleanup_task_folder(task_root)
            return {
                "success": False,
                "error": service_response.get("error", "子智能体服务调用失败"),
                "details": service_response,
            }

        status = service_response.get("status", "pending")
        sub_conversation_id = service_response.get("sub_conversation_id")
        self.tasks[task_id] = {
            "task_id": task_id,
            "agent_id": agent_id,
            "summary": summary,
            "task": task,
            "status": status,
            "target_project_dir": str(target_project_dir),
            "references_dir": str(references_dir),
            "deliverables_dir": str(deliverables_dir),
            "workspace_dir": str(workspace_dir),
            "copied_references": copied_refs,
            "timeout_seconds": timeout_seconds,
            "service_payload": payload,
            "created_at": time.time(),
            "conversation_id": conversation_id,
            "sub_conversation_id": sub_conversation_id,
            "parent_conversation_id": conversation_id,
        }
        self._mark_agent_id_used(conversation_id, agent_id)
        self._save_state()

        message = f"子智能体{agent_id} 已创建任务ID: {task_id},当前状态:{status}"
        print(f"{OUTPUT_FORMATS['info']} {message}")
        return {
            "success": True,
            "task_id": task_id,
            "agent_id": agent_id,
            "status": status,
            "message": message,
            "deliverables_dir": str(deliverables_dir),
            "copied_references": copied_refs,
            "sub_conversation_id": sub_conversation_id,
        }

    def wait_for_completion(
        self,
        *,
        task_id: Optional[str] = None,
        agent_id: Optional[int] = None,
        timeout_seconds: Optional[int] = None,
    ) -> Dict:
        """Block until the sub-agent finishes or the wait times out.

        The service is polled every ``SUB_AGENT_STATUS_POLL_INTERVAL``
        seconds.  When the deadline passes without a finished status the task
        is finalized locally with status ``"timeout"``.

        Args:
            task_id: Task to wait for; takes precedence over ``agent_id``.
            agent_id: Used to pick the newest pending/running task of that
                agent when ``task_id`` is not given.
            timeout_seconds: Maximum wait; falls back to the task's own
                timeout, then to ``SUB_AGENT_DEFAULT_TIMEOUT``.

        Returns:
            The task's final result dict, or an error dict when no matching
            task exists.
        """
        task = self._select_task(task_id, agent_id)
        if not task:
            return {"success": False, "error": "未找到对应的子智能体任务"}
        if self._is_finished(task.get("status")):
            if task.get("final_result"):
                return task["final_result"]
            return {"success": False, "status": task.get("status"), "message": "子智能体已结束。"}

        timeout_seconds = timeout_seconds or task.get("timeout_seconds") or SUB_AGENT_DEFAULT_TIMEOUT
        # Monotonic clock: the deadline must not move with wall-clock changes.
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            last_payload = self._call_service("GET", f"/tasks/{task['task_id']}", timeout=15)
            status = last_payload.get("status")
            # Any finished status ends the wait, including a remote
            # "terminated" reported together with success=False.
            if self._is_finished(status):
                break
            time.sleep(SUB_AGENT_STATUS_POLL_INTERVAL)
        else:
            # Loop ended without a break: the deadline passed (or the wait
            # budget was already exhausted before the first poll).
            status = "timeout"
            last_payload = {"success": False, "status": status, "message": "等待超时"}

        finalize_result = self._finalize_task(task, last_payload, status)
        self._save_state()
        return finalize_result

    def terminate_sub_agent(
        self,
        *,
        task_id: Optional[str] = None,
        agent_id: Optional[int] = None,
    ) -> Dict:
        """Force-stop the selected sub-agent.

        Returns:
            The service response, augmented with ``task_id`` and a
            ``system_message``; or an error dict when no task matches.
        """
        task = self._select_task(task_id, agent_id)
        if not task:
            return {"success": False, "error": "未找到对应的子智能体任务"}
        task_id = task["task_id"]
        response = self._call_service("POST", f"/tasks/{task_id}/terminate", timeout=10)
        response["task_id"] = task_id
        if response.get("success"):
            task["status"] = "terminated"
            task["final_result"] = {
                "success": False,
                "status": "terminated",
                "task_id": task_id,
                "agent_id": task.get("agent_id"),
                "message": response.get("message") or "子智能体已被强制关闭。",
            }
            self._save_state()
            if "system_message" not in response:
                response["system_message"] = response.get("message") or "🛑 子智能体已被手动关闭。"
        elif "system_message" not in response:
            response["system_message"] = response.get("message")
        return response

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _is_within(path: Path, root: Path) -> bool:
        """Return True when ``path`` equals ``root`` or lies below it.

        Compares whole path components; a plain string prefix test would
        accept ``/x/proj_evil`` as being inside ``/x/proj``.
        """
        try:
            path.relative_to(root)
        except ValueError:
            return False
        return True

    def _is_finished(self, status: Optional[str]) -> bool:
        """Return True for every status after which a task no longer runs."""
        return status in TERMINAL_STATUSES or status == "terminated"

    def _load_state(self):
        """Load tasks and used agent ids from the state file.

        A missing, unreadable or malformed file yields empty state.  Older
        records without ``parent_conversation_id`` are back-filled and the
        file is rewritten when that happens.
        """
        self.tasks = {}
        self.conversation_agents = {}
        if self.state_file.exists():
            try:
                data = json.loads(self.state_file.read_text(encoding="utf-8"))
            except (OSError, ValueError):
                # ValueError covers JSONDecodeError and UnicodeDecodeError.
                data = None
            if isinstance(data, dict):
                tasks = data.get("tasks")
                agents = data.get("conversation_agents")
                self.tasks = tasks if isinstance(tasks, dict) else {}
                self.conversation_agents = agents if isinstance(agents, dict) else {}
            else:
                logger.warning("子智能体状态文件损坏,已忽略。")

        migrated = False
        for task in self.tasks.values():
            if task.get("parent_conversation_id"):
                continue
            candidate = task.get("conversation_id") or (task.get("service_payload") or {}).get("parent_conversation_id")
            if candidate:
                task["parent_conversation_id"] = candidate
                migrated = True
        if migrated:
            self._save_state()

    def _save_state(self):
        """Persist tasks and used agent ids to the state file.

        The JSON is written to a sibling temp file and then moved over the
        real file, so an interrupted write cannot leave a truncated state
        file (which ``_load_state`` would discard as corrupt).
        """
        payload = {
            "tasks": self.tasks,
            "conversation_agents": self.conversation_agents
        }
        serialized = json.dumps(payload, ensure_ascii=False, indent=2)
        tmp_file = self.state_file.with_name(self.state_file.name + ".tmp")
        tmp_file.write_text(serialized, encoding="utf-8")
        tmp_file.replace(self.state_file)

    def _generate_task_id(self, agent_id: int) -> str:
        """Build a task id from the agent id, the current time and a random suffix."""
        suffix = uuid.uuid4().hex[:6]
        return f"sub_{agent_id}_{int(time.time())}_{suffix}"

    def _active_task_count(self, conversation_id: Optional[str] = None) -> int:
        """Count pending/running tasks, optionally only for one conversation."""
        return sum(
            1
            for record in self.tasks.values()
            if record.get("status") in {"pending", "running"}
            and (not conversation_id or record.get("conversation_id") == conversation_id)
        )

    def _copy_reference_files(self, references: List[str], dest_dir: Path) -> Tuple[List[str], List[str]]:
        """Copy project files into ``dest_dir``, keeping their relative paths.

        Args:
            references: Project-relative file paths; blank entries are skipped.
            dest_dir: Directory that receives the copies.

        Returns:
            ``(copied, errors)``: the paths copied successfully and one error
            message per path that could not be copied.
        """
        copied: List[str] = []
        errors: List[str] = []
        dest_root = dest_dir.resolve()
        for rel_path in references:
            rel_path = rel_path.strip()
            if not rel_path:
                continue
            try:
                source = self._resolve_project_file(rel_path)
            except ValueError as exc:
                errors.append(str(exc))
                continue
            if not source.exists():
                errors.append(f"参考文件不存在: {rel_path}")
                continue
            target_path = dest_dir / rel_path
            # An absolute path or one that climbs out and back into the
            # project (e.g. "../proj/x") is a valid source but would place
            # the copy outside dest_dir; refuse it.
            if not self._is_within(target_path.resolve(), dest_root):
                errors.append(f"非法的参考文件路径: {rel_path}")
                continue
            target_path.parent.mkdir(parents=True, exist_ok=True)
            try:
                shutil.copy2(source, target_path)
                copied.append(rel_path)
            except Exception as exc:
                errors.append(f"复制 {rel_path} 失败: {exc}")
        return copied, errors

    def _ensure_project_subdir(self, relative_dir: str) -> Path:
        """Resolve and create a directory inside the project.

        Args:
            relative_dir: Project-relative directory; blank or ``None`` means
                ``sub_agent_results``.

        Returns:
            The resolved directory path.

        Raises:
            ValueError: If the resolved path is outside the project root.
        """
        relative_dir = relative_dir.strip() if relative_dir else ""
        if not relative_dir:
            relative_dir = "sub_agent_results"
        target = (self.project_path / relative_dir).resolve()
        if not self._is_within(target, self.project_path):
            raise ValueError("指定文件夹必须位于项目目录内")
        target.mkdir(parents=True, exist_ok=True)
        return target

    def _resolve_project_file(self, relative_path: str) -> Path:
        """Resolve a project-relative file path.

        Raises:
            ValueError: If the resolved path is outside the project root.
        """
        relative_path = relative_path.strip()
        candidate = (self.project_path / relative_path).resolve()
        if not self._is_within(candidate, self.project_path):
            raise ValueError(f"非法的参考文件路径: {relative_path}")
        return candidate

    def _select_task(self, task_id: Optional[str], agent_id: Optional[int]) -> Optional[Dict]:
        """Find a task record by id, or the newest active task of an agent.

        ``task_id`` wins when given.  Otherwise only pending/running tasks of
        ``agent_id`` are considered and the most recently created is returned.
        """
        if task_id:
            return self.tasks.get(task_id)
        if agent_id is None:
            return None
        candidates = [
            record for record in self.tasks.values()
            if record.get("agent_id") == agent_id and record.get("status") in {"pending", "running"}
        ]
        if not candidates:
            return None
        return max(candidates, key=lambda item: item.get("created_at", 0))

    def lookup_task(self, *, task_id: Optional[str] = None, agent_id: Optional[int] = None) -> Optional[Dict]:
        """Read-only task lookup, used by wait_sub_agent to adjust its timeout."""
        task = self._select_task(task_id, agent_id)
        if not task:
            return None
        return {
            "task_id": task.get("task_id"),
            "agent_id": task.get("agent_id"),
            "status": task.get("status"),
            "timeout_seconds": task.get("timeout_seconds"),
            "conversation_id": task.get("conversation_id"),
        }

    def poll_updates(self) -> List[Dict]:
        """Query the service for every unfinished task.

        Returns:
            The final result of each task that the service now reports as
            finished (including ``"terminated"``); state is saved when at
            least one task was finalized.
        """
        updates: List[Dict] = []
        pending_tasks = [
            task for task in self.tasks.values()
            if not self._is_finished(task.get("status"))
        ]
        logger.debug(f"[SubAgentManager] 待检查任务: {len(pending_tasks)}")
        for task in pending_tasks:
            payload = self._call_service("GET", f"/tasks/{task['task_id']}", timeout=10)
            status = payload.get("status")
            logger.debug(f"[SubAgentManager] 任务 {task['task_id']} 服务状态: {status}")
            if not self._is_finished(status):
                continue
            updates.append(self._finalize_task(task, payload, status))
        if updates:
            self._save_state()
        return updates

    def _call_service(self, method: str, path: str, payload: Optional[Dict] = None, timeout: Optional[int] = None) -> Dict:
        """Send one HTTP request to the sub-agent service.

        Args:
            method: ``"POST"`` (case-insensitive) sends ``payload`` as JSON;
                anything else issues a GET.
            path: URL path appended to ``SUB_AGENT_SERVICE_BASE_URL``.
            payload: JSON body for POST requests.
            timeout: Request timeout in seconds; defaults to 10.

        Returns:
            The decoded JSON object.  Transport errors, HTTP error statuses
            without a JSON object body, and malformed or non-object bodies
            are reported as ``{"success": False, "error": ...}``.
        """
        url = f"{SUB_AGENT_SERVICE_BASE_URL.rstrip('/')}{path}"
        try:
            with httpx.Client(timeout=timeout or 10) as client:
                if method.upper() == "POST":
                    response = client.post(url, json=payload or {})
                else:
                    response = client.get(url)
                response.raise_for_status()
                data = response.json()
        except httpx.RequestError as exc:
            logger.error(f"子智能体服务请求失败: {exc}")
            return {"success": False, "error": f"无法连接子智能体服务: {exc}"}
        except httpx.HTTPStatusError as exc:
            logger.error(f"子智能体服务返回错误: {exc}")
            try:
                body = exc.response.json()
            except ValueError:
                body = None
            if isinstance(body, dict):
                return body
            return {"success": False, "error": f"服务端错误: {exc.response.text}"}
        except ValueError:
            # JSONDecodeError / UnicodeDecodeError from response.json().
            return {"success": False, "error": "子智能体服务返回格式错误"}
        if not isinstance(data, dict):
            # Callers use .get() on the result, so a JSON list or scalar is
            # as unusable as malformed JSON.
            return {"success": False, "error": "子智能体服务返回格式错误"}
        return data

    def _finalize_task(self, task: Dict, service_payload: Dict, status: str) -> Dict:
        """Record the final status of a task and build its result dict.

        For ``"completed"`` the deliverables directory must exist and contain
        ``result.md``; its content is then copied into the project.  If any
        of that fails the task is recorded as ``"failed"`` instead.  A task
        that already has a final result and a finished status is returned
        unchanged.

        Args:
            task: The task record (mutated in place).
            service_payload: Last payload received for the task.
            status: Status to record.

        Returns:
            The result dict, also stored as ``task["final_result"]``.
        """
        existing_result = task.get("final_result")
        if existing_result and self._is_finished(task.get("status")):
            return existing_result

        task["status"] = status
        task["updated_at"] = time.time()
        message = service_payload.get("message") or service_payload.get("error") or ""
        logger.debug(f"[SubAgentManager] finalize task={task['task_id']} status={status}")

        if status == "terminated":
            system_message = service_payload.get("system_message") or "🛑 子智能体已被手动关闭。"
            result = {
                "success": False,
                "task_id": task["task_id"],
                "agent_id": task["agent_id"],
                "status": "terminated",
                "message": message or "子智能体已被手动关闭。",
                "details": service_payload,
                "sub_conversation_id": task.get("sub_conversation_id"),
                "system_message": system_message,
            }
            task["final_result"] = result
            return result

        if status != "completed":
            result = {
                "success": False,
                "task_id": task["task_id"],
                "agent_id": task["agent_id"],
                "status": status,
                "message": message or f"子智能体状态:{status}",
                "details": service_payload,
                "sub_conversation_id": task.get("sub_conversation_id"),
                "system_message": self._build_system_message(task, status, None, message),
            }
            task["final_result"] = result
            return result

        raw_dir = service_payload.get("deliverables_dir") or task.get("deliverables_dir") or ""
        deliverables_dir = Path(raw_dir)
        # Path("") is the current directory and always "exists", so a missing
        # location has to be rejected explicitly.
        if not raw_dir or not deliverables_dir.exists():
            reason = f"未找到交付目录: {deliverables_dir if raw_dir else ''}"
            return self._fail_completed_task(task, reason, reason)
        if not (deliverables_dir / "result.md").exists():
            return self._fail_completed_task(
                task,
                "交付目录缺少 result.md无法完成任务。",
                "交付目录缺少 result.md",
            )
        try:
            copied_path = self._copy_deliverables_to_project(task, deliverables_dir)
        except OSError as exc:
            # shutil.Error is an OSError.  Without this the task would stay
            # "completed" with no final result and never be polled again.
            reason = f"复制交付文件失败: {exc}"
            return self._fail_completed_task(task, reason, reason)

        task["copied_path"] = str(copied_path)
        result = {
            "success": True,
            "task_id": task["task_id"],
            "agent_id": task["agent_id"],
            "status": status,
            "message": message or "子智能体已完成任务。",
            "deliverables_path": str(deliverables_dir),
            "copied_path": str(copied_path),
            "sub_conversation_id": task.get("sub_conversation_id"),
            "system_message": self._build_system_message(task, status, copied_path, message),
            "details": service_payload,
        }
        task["final_result"] = result
        return result

    def _fail_completed_task(self, task: Dict, error: str, detail: str) -> Dict:
        """Mark a task the service reported as completed as failed locally.

        Args:
            task: The task record (mutated in place).
            error: Text stored under ``error`` in the result.
            detail: Text appended to the failure system message.
        """
        result = {
            "success": False,
            "task_id": task["task_id"],
            "agent_id": task["agent_id"],
            "status": "failed",
            "error": error,
            "system_message": self._build_system_message(task, "failed", None, detail),
        }
        task["status"] = "failed"
        task["final_result"] = result
        return result

    def _copy_deliverables_to_project(self, task: Dict, source_dir: Path) -> Path:
        """Copy the deliverables into ``<target_project_dir>/<task_id>_deliverables``.

        An existing destination directory is removed first.

        Returns:
            The destination directory.
        """
        target_dir = Path(task["target_project_dir"])
        target_dir.mkdir(parents=True, exist_ok=True)
        dest_dir = target_dir / f"{task['task_id']}_deliverables"
        if dest_dir.exists():
            shutil.rmtree(dest_dir)
        shutil.copytree(source_dir, dest_dir)
        return dest_dir

    def _cleanup_task_folder(self, task_root: Path):
        """Best-effort removal of a task's working folder."""
        if task_root.exists():
            shutil.rmtree(task_root, ignore_errors=True)

    def _ensure_agent_slot_available(self, conversation_id: str, agent_id: int) -> bool:
        """Return True when ``agent_id`` has not been used in the conversation."""
        # Read-only lookup: a mere check must not add entries to the state.
        return agent_id not in self.conversation_agents.get(conversation_id, [])

    def _mark_agent_id_used(self, conversation_id: str, agent_id: int):
        """Record that ``agent_id`` has been used in the conversation."""
        used = self.conversation_agents.setdefault(conversation_id, [])
        if agent_id not in used:
            used.append(agent_id)

    def _validate_create_params(self, agent_id: Optional[int], summary: str, task: str, target_dir: str) -> Optional[str]:
        """Validate the arguments of ``create_sub_agent``.

        Returns:
            An error message for the first problem found, or ``None`` when
            all arguments are acceptable.
        """
        if agent_id is None:
            return "子智能体代号不能为空"
        try:
            agent_id = int(agent_id)
        except (TypeError, ValueError):
            # int() raises TypeError for e.g. lists and dicts.
            return "子智能体代号必须是整数"
        if agent_id <= 0:
            return "子智能体代号必须为正整数"
        if not summary or not summary.strip():
            return "任务摘要不能为空"
        if not task or not task.strip():
            return "任务详情不能为空"
        if target_dir is None:
            return "指定文件夹不能为空"
        return None

    def _build_system_message(
        self,
        task: Dict,
        status: str,
        copied_path: Optional[Path],
        extra_message: Optional[str] = None,
    ) -> str:
        """Compose the human-readable system message for a task outcome.

        Args:
            task: Task record; ``agent_id`` and ``summary`` are read.
            status: Outcome status.
            copied_path: Where deliverables were copied (completed tasks).
            extra_message: Optional detail appended to the message.
        """
        prefix = f"子智能体{task['agent_id']} 任务摘要:{task['summary']}"
        extra = (extra_message or "").strip()
        if status == "completed" and copied_path:
            msg = f"{prefix} 已完成,成果已复制到 {copied_path}"
            if extra:
                msg += f" ({extra})"
            return msg
        if status == "timeout":
            return f"{prefix} 超时未完成。" + (f" {extra}" if extra else "")
        if status == "failed":
            return f"{prefix} 执行失败:" + (extra if extra else "请检查交付目录或任务状态。")
        # Separate the detail from the status; without the space the two
        # run together (e.g. "...terminated<detail>").
        return f"{prefix} 状态:{status}" + (f" {extra}" if extra else "")

    def get_overview(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return task snapshots for front-end display, newest first.

        Unfinished tasks are refreshed from the service; finished tasks
        (including terminated ones) carry their final message and success
        flag instead.

        Args:
            conversation_id: When given, only that conversation's tasks.
        """
        overview: List[Dict[str, Any]] = []
        for task_id, task in self.tasks.items():
            if conversation_id and task.get("conversation_id") != conversation_id:
                continue
            snapshot = {
                "task_id": task_id,
                "agent_id": task.get("agent_id"),
                "summary": task.get("summary"),
                "status": task.get("status"),
                "created_at": task.get("created_at"),
                "updated_at": task.get("updated_at"),
                "target_dir": task.get("target_project_dir"),
                "last_tool": task.get("last_tool"),
                "deliverables_dir": task.get("deliverables_dir"),
                "copied_path": task.get("copied_path"),
                "conversation_id": task.get("conversation_id"),
                "sub_conversation_id": task.get("sub_conversation_id"),
            }
            if not self._is_finished(snapshot["status"]):
                # Still running: try to sync the latest remote status.
                remote = self._call_service("GET", f"/tasks/{task_id}", timeout=5)
                if remote.get("success"):
                    snapshot["status"] = remote.get("status", snapshot["status"])
                    snapshot["remote_message"] = remote.get("message")
                    snapshot["last_tool"] = remote.get("last_tool")
                    task["last_tool"] = snapshot["last_tool"]
            else:
                # Finished: attach the final result / system message.
                final_result = task.get("final_result") or {}
                snapshot["final_message"] = final_result.get("system_message") or final_result.get("message")
                snapshot["success"] = final_result.get("success")
            overview.append(snapshot)
        overview.sort(key=lambda item: item.get("created_at") or 0, reverse=True)
        return overview