agent/sub_agent/core/sub_agent_terminal.py

130 lines
4.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""子智能体专用终端,实现工具白名单与完成工具。"""
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional

from config import PROMPTS_DIR
from core.web_terminal import WebTerminal
# Tool names that SubAgentTerminal.define_tools strips from the inherited tool list.
# NOTE(review): presumably the memory/todo tools stay reserved for the main agent;
# confirm with the tool owners before extending this set.
FORBIDDEN_SUB_AGENT_TOOLS = {
    "update_memory",
    "todo_create",
    "todo_update_task",
    "todo_finish",
    "todo_finish_confirm",
}
class SubAgentTerminal(WebTerminal):
    """Web terminal for a sub-agent: hides forbidden tools and adds ``finish_sub_agent``.

    Compared with ``WebTerminal`` this subclass:

    * renders the ``main_system`` prompt from ``sub_agent_system.txt`` using the
      sub-agent's metadata (falling back to the parent's prompt when that
      template file does not exist);
    * removes every tool named in ``FORBIDDEN_SUB_AGENT_TOOLS`` from the
      inherited tool list (a deny-list, not an allow-list);
    * exposes a ``finish_sub_agent`` tool that verifies ``result.md`` in the
      deliverables directory before reporting completion.
    """

    def __init__(
        self,
        *,
        workspace_dir: str,
        data_dir: str,
        metadata: Dict,
        message_callback=None,
    ):
        """Create the terminal.

        Args:
            workspace_dir: Passed to ``WebTerminal`` as ``project_path``.
            data_dir: Passed through to ``WebTerminal``.
            metadata: Sub-agent description. Keys read by this class:
                ``summary``, ``task``, ``workspace_dir``, ``references_dir``,
                ``deliverables_dir``, ``target_project_dir``, ``agent_id``,
                ``task_id``.
            message_callback: Passed through to ``WebTerminal``.
        """
        super().__init__(
            project_path=workspace_dir,
            thinking_mode=True,
            message_callback=message_callback,
            data_dir=data_dir,
        )
        self.sub_agent_meta = metadata
        # Called with the success payload once finish_sub_agent succeeds.
        self.finish_callback = None
        # Rendered sub-agent system prompt; None until first rendered.
        self._system_prompt_cache: Optional[str] = None

    def set_finish_callback(self, callback):
        """Register ``callback(result_dict)`` to run when the sub-agent finishes."""
        self.finish_callback = callback

    def load_prompt(self, name: str) -> str:
        """Return prompt ``name``, overriding ``main_system`` with the sub-agent template.

        The rendered template is cached for the lifetime of the instance.
        Prompts other than ``main_system``, and ``main_system`` itself when
        ``sub_agent_system.txt`` is absent, are delegated to ``WebTerminal``.
        """
        if name != "main_system":
            return super().load_prompt(name)
        # `is not None` so a template that renders to "" is still served from cache.
        if self._system_prompt_cache is not None:
            return self._system_prompt_cache
        template_path = Path(PROMPTS_DIR) / "sub_agent_system.txt"
        if not template_path.exists():
            return super().load_prompt(name)
        template = template_path.read_text(encoding="utf-8")
        meta = self.sub_agent_meta
        data = {
            "summary": meta.get("summary", ""),
            "task": meta.get("task", ""),
            "workspace": meta.get("workspace_dir", ""),
            "references": meta.get("references_dir", ""),
            "deliverables": meta.get("deliverables_dir", ""),
            "target_project_dir": meta.get("target_project_dir", ""),
            "agent_id": meta.get("agent_id", ""),
            "task_id": meta.get("task_id", ""),
        }
        # NOTE: str.format treats every `{`/`}` in the template as a placeholder;
        # literal braces in sub_agent_system.txt must be written as `{{` / `}}`.
        self._system_prompt_cache = template.format(**data)
        return self._system_prompt_cache

    def define_tools(self) -> List[Dict]:
        """Return the parent's tools minus ``FORBIDDEN_SUB_AGENT_TOOLS``, plus ``finish_sub_agent``."""
        tools = super().define_tools()
        # `or {}` also covers a tool whose "function" entry is present but None.
        filtered: List[Dict] = [
            tool
            for tool in tools
            if (tool.get("function") or {}).get("name") not in FORBIDDEN_SUB_AGENT_TOOLS
        ]
        filtered.append({
            "type": "function",
            "function": {
                "name": "finish_sub_agent",
                "description": (
                    "当你确定交付成果已准备完毕时调用此工具。调用前请确认 deliverables 文件夹存在 result.md"
                    "其中包含交付说明。参数 reason 用于向主智能体总结本轮完成情况。"
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "reason": {
                            "type": "string",
                            "description": "向主智能体说明任务完成情况、交付内容、下一步建议。"
                        }
                    },
                    "required": ["reason"]
                }
            }
        })
        return filtered

    async def handle_tool_call(self, tool_name: str, arguments: Dict) -> str:
        """Handle ``finish_sub_agent`` locally; delegate every other tool to ``WebTerminal``.

        The ``finish_sub_agent`` result dict is returned as a JSON string with
        non-ASCII characters kept as-is (``ensure_ascii=False``).
        """
        if tool_name == "finish_sub_agent":
            result = self._finalize_sub_agent(arguments or {})
            return json.dumps(result, ensure_ascii=False)
        return await super().handle_tool_call(tool_name, arguments)

    def _finalize_sub_agent(self, arguments: Dict) -> Dict:
        """Validate the deliverables and mark the sub-agent as finished.

        Checks, in order: ``result.md`` is a file in the deliverables directory,
        it can be read as UTF-8 and is not blank, and ``arguments["reason"]``
        is not blank. The first failing check returns
        ``{"success": False, "error": <message>}`` and the finish callback is
        not invoked.

        Returns:
            ``{"success": True, "message": ..., "reason": ...}`` on success,
            after passing that same dict to the registered finish callback.
        """
        # `or` rather than a .get() default: an empty or None deliverables_dir must
        # fall back to the project path, not to Path("") (the process CWD) or a TypeError.
        deliverables_dir = Path(self.sub_agent_meta.get("deliverables_dir") or self.project_path)
        result_md = deliverables_dir / "result.md"
        # is_file() also rejects a directory named result.md, which read_text cannot read.
        if not result_md.is_file():
            return {
                "success": False,
                "error": "deliverables 目录缺少 result.md无法结束任务。"
            }
        try:
            content = result_md.read_text(encoding="utf-8").strip()
        except (OSError, UnicodeDecodeError) as exc:
            # Report unreadable files back to the model instead of failing the tool call.
            return {
                "success": False,
                "error": f"读取 result.md 失败:{exc}",
            }
        if not content:
            return {
                "success": False,
                "error": "result.md 为空,请写入任务总结与交付说明后再结束任务。"
            }
        raw_reason = arguments.get("reason") or ""
        # Tolerate a non-string reason (e.g. a number) instead of crashing on .strip().
        reason = (raw_reason if isinstance(raw_reason, str) else str(raw_reason)).strip()
        if not reason:
            return {
                "success": False,
                "error": "缺少 reason 字段,请说明完成情况。"
            }
        result = {
            "success": True,
            "message": "子智能体任务已标记为完成。",
            "reason": reason,
        }
        if self.finish_callback:
            try:
                self.finish_callback(result)
            except Exception:
                # Best-effort: a failing callback must not undo a successful finish,
                # but it should not vanish silently either.
                logging.getLogger(__name__).exception("finish_sub_agent callback failed")
        return result