130 lines
4.7 KiB
Python
130 lines
4.7 KiB
Python
"""子智能体专用终端,实现工具白名单与完成工具。"""
|
||
|
||
import json
import logging
from pathlib import Path
from typing import Dict, Optional, List

from core.web_terminal import WebTerminal
from config import PROMPTS_DIR
|
||
|
||
# Tool names hidden from sub-agents: SubAgentTerminal.define_tools drops any
# tool whose function name is a member of this set.
FORBIDDEN_SUB_AGENT_TOOLS = {
    "update_memory",
    "todo_create",
    "todo_update_task",
    "todo_finish",
    "todo_finish_confirm",
}
|
||
|
||
|
||
class SubAgentTerminal(WebTerminal):
    """Sub-agent web terminal: restricts the tool set and adds ``finish_sub_agent``.

    Differences from :class:`WebTerminal`:

    * ``load_prompt("main_system")`` renders ``sub_agent_system.txt`` from
      ``PROMPTS_DIR`` with the sub-agent metadata (falling back to the parent's
      prompt when that template file does not exist);
    * tools named in ``FORBIDDEN_SUB_AGENT_TOOLS`` are removed from the tool list;
    * an extra ``finish_sub_agent`` tool lets the model end its task once
      ``result.md`` has been written to the deliverables directory.
    """

    def __init__(
        self,
        *,
        workspace_dir: str,
        data_dir: str,
        metadata: Dict,
        message_callback=None,
    ):
        """Create the terminal.

        Args:
            workspace_dir: Passed to the parent as ``project_path``.
            data_dir: Passed through to the parent unchanged.
            metadata: Sub-agent description. Keys read by this class: summary,
                task, workspace_dir, references_dir, deliverables_dir,
                target_project_dir, agent_id, task_id. ``None`` is treated as
                an empty mapping.
            message_callback: Passed through to the parent unchanged.
        """
        super().__init__(
            project_path=workspace_dir,
            thinking_mode=True,  # always enabled for sub-agents
            message_callback=message_callback,
            data_dir=data_dir,
        )
        # Normalise None so every later .get() lookup is safe.
        self.sub_agent_meta: Dict = metadata if metadata is not None else {}
        # Called with the success payload when finish_sub_agent succeeds.
        self.finish_callback = None
        # Rendered sub-agent system prompt, filled on the first render.
        self._system_prompt_cache: Optional[str] = None

    def set_finish_callback(self, callback):
        """Register *callback*; it receives the success dict from finish_sub_agent."""
        self.finish_callback = callback

    def load_prompt(self, name: str) -> str:
        """Return the prompt called *name*.

        Only ``"main_system"`` is overridden: it is rendered from
        ``PROMPTS_DIR/sub_agent_system.txt`` with the sub-agent metadata and
        cached on the instance. Any other name, or a missing template file,
        defers to the parent's ``load_prompt``.
        """
        if name != "main_system":
            return super().load_prompt(name)

        # "is not None" so that a template rendering to "" is cached too.
        if self._system_prompt_cache is not None:
            return self._system_prompt_cache

        template_path = Path(PROMPTS_DIR) / "sub_agent_system.txt"
        if not template_path.exists():
            return super().load_prompt(name)

        template = template_path.read_text(encoding="utf-8")
        meta = self.sub_agent_meta
        # Template placeholder -> metadata key it is filled from.
        placeholders = {
            "summary": "summary",
            "task": "task",
            "workspace": "workspace_dir",
            "references": "references_dir",
            "deliverables": "deliverables_dir",
            "target_project_dir": "target_project_dir",
            "agent_id": "agent_id",
            "task_id": "task_id",
        }
        values = {}
        for placeholder, key in placeholders.items():
            value = meta.get(key)
            # Render absent and explicit-None entries as "" rather than "None".
            values[placeholder] = "" if value is None else value
        self._system_prompt_cache = template.format(**values)
        return self._system_prompt_cache

    def define_tools(self) -> List[Dict]:
        """Return the parent's tools minus FORBIDDEN_SUB_AGENT_TOOLS, plus finish_sub_agent."""
        filtered: List[Dict] = []
        for tool in super().define_tools():
            # Tolerate entries whose "function" is missing or None.
            name = (tool.get("function") or {}).get("name")
            if name not in FORBIDDEN_SUB_AGENT_TOOLS:
                filtered.append(tool)
        filtered.append(self._finish_tool_spec())
        return filtered

    @staticmethod
    def _finish_tool_spec() -> Dict:
        """Build the function-tool schema for ``finish_sub_agent``."""
        return {
            "type": "function",
            "function": {
                "name": "finish_sub_agent",
                "description": (
                    "当你确定交付成果已准备完毕时调用此工具。调用前请确认 deliverables 文件夹存在 result.md,"
                    "其中包含交付说明。参数 reason 用于向主智能体总结本轮完成情况。"
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "reason": {
                            "type": "string",
                            "description": "向主智能体说明任务完成情况、交付内容、下一步建议。"
                        }
                    },
                    "required": ["reason"]
                }
            }
        }

    async def handle_tool_call(self, tool_name: str, arguments: Dict) -> str:
        """Handle ``finish_sub_agent`` locally (JSON result); delegate all other tools."""
        if tool_name == "finish_sub_agent":
            result = self._finalize_sub_agent(arguments or {})
            return json.dumps(result, ensure_ascii=False)
        return await super().handle_tool_call(tool_name, arguments)

    def _finalize_sub_agent(self, arguments: Dict) -> Dict:
        """Validate a finish request and, on success, notify ``finish_callback``.

        Checks, in order, that ``<deliverables_dir>/result.md`` is an existing
        file, that it can be read and is not blank, and that
        ``arguments["reason"]`` is not blank. ``deliverables_dir`` comes from
        the metadata and falls back to ``self.project_path`` when absent or empty.

        Returns:
            ``{"success": False, "error": ...}`` for the first failed check,
            otherwise ``{"success": True, "message": ..., "reason": ...}``.
        """
        deliverables_dir = Path(self.sub_agent_meta.get("deliverables_dir") or self.project_path)
        result_md = deliverables_dir / "result.md"
        # is_file(), not exists(): a directory named result.md cannot be read.
        if not result_md.is_file():
            return {
                "success": False,
                "error": "deliverables 目录缺少 result.md,无法结束任务。"
            }
        try:
            content = result_md.read_text(encoding="utf-8").strip()
        except (OSError, UnicodeDecodeError) as exc:
            return {
                "success": False,
                "error": f"无法读取 result.md: {exc}"
            }
        if not content:
            return {
                "success": False,
                "error": "result.md 为空,请写入任务总结与交付说明后再结束任务。"
            }

        raw_reason = arguments.get("reason")
        # The schema says string, but model output may not follow it.
        reason = str(raw_reason).strip() if raw_reason else ""
        if not reason:
            return {
                "success": False,
                "error": "缺少 reason 字段,请说明完成情况。"
            }

        result = {
            "success": True,
            "message": "子智能体任务已标记为完成。",
            "reason": reason,
        }
        if self.finish_callback is not None:
            try:
                self.finish_callback(result)
            except Exception:
                # Best-effort: a failing callback must not undo a valid finish,
                # but the failure should not vanish silently either.
                logging.getLogger(__name__).exception(
                    "finish_callback raised while finalizing sub-agent"
                )
        return result