112 lines
4.0 KiB
Python
112 lines
4.0 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import time
|
|
from typing import Dict, List, Optional
|
|
|
|
|
|
def _sub_agent_insert_index(messages: List[Dict], after_tool_call_id: Optional[str]) -> int:
    """Return the index just after the first tool message for *after_tool_call_id*.

    Falls back to ``len(messages)`` (i.e. append) when no id is given or no
    matching ``role == "tool"`` message exists.
    """
    if after_tool_call_id:
        for idx, msg in enumerate(messages):
            if msg.get("role") == "tool" and msg.get("tool_call_id") == after_tool_call_id:
                return idx + 1
    return len(messages)


async def process_sub_agent_updates(*, messages: List[Dict], inline: bool = False, after_tool_call_id: Optional[str] = None, web_terminal, sender, debug_log, maybe_mark_failure_from_message) -> None:
    """Poll sub-agent tasks, notify the frontend, and insert results into the conversation.

    Each polled update that carries a ``system_message`` and whose
    ``task_id`` has not been announced before is:

    * inserted into *messages* as a ``role="system"`` entry — right after the
      tool message whose ``tool_call_id`` equals *after_tool_call_id* when
      that message exists, otherwise appended at the end. Multiple notices
      keep the order in which ``poll_updates()`` returned them;
    * sent to the frontend via ``sender('system_message', ...)``;
    * passed to ``maybe_mark_failure_from_message(web_terminal, message)``.

    Args:
        messages: conversation list, mutated in place.
        inline: forwarded into the notice metadata and the frontend payload.
        after_tool_call_id: tool call id after whose result notices are placed.
        web_terminal: object exposing ``sub_agent_manager``; it also stores the
            ``_announced_sub_agent_tasks`` set used for de-duplication.
        sender: callable ``sender(event_name, payload)``.
        debug_log: callable taking a single string.
        maybe_mark_failure_from_message: callable ``(web_terminal, message)``.

    Errors raised by ``manager.poll_updates()`` are logged via *debug_log*
    and swallowed (best effort); nothing is inserted in that case.
    """
    manager = getattr(web_terminal, "sub_agent_manager", None)
    if not manager:
        return

    # Task ids already announced; stored on the terminal so later polls skip them.
    if not hasattr(web_terminal, '_announced_sub_agent_tasks'):
        web_terminal._announced_sub_agent_tasks = set()

    try:
        updates = manager.poll_updates()
        debug_log(f"[SubAgent] poll inline={inline} updates={len(updates)}")
    except Exception as exc:
        debug_log(f"子智能体状态检查失败: {exc}")
        return

    # Computed lazily on the first notice, then advanced after every insert so
    # consecutive notices stay in poll order instead of stacking in reverse.
    insert_index: Optional[int] = None

    for update in updates:
        task_id = update.get("task_id")

        # Skip tasks that were already announced.
        if task_id and task_id in web_terminal._announced_sub_agent_tasks:
            debug_log(f"[SubAgent] 任务 {task_id} 已通知过,跳过")
            continue

        message = update.get("system_message")
        if not message:
            continue

        debug_log(f"[SubAgent] update task={task_id} inline={inline} msg={message}")

        # Mark the task as announced.
        if task_id:
            web_terminal._announced_sub_agent_tasks.add(task_id)

        if insert_index is None:
            debug_log("[SubAgent] 计算插入位置")
            insert_index = _sub_agent_insert_index(messages, after_tool_call_id)

        messages.insert(insert_index, {
            "role": "system",
            "content": message,
            "metadata": {"sub_agent_notice": True, "inline": inline, "task_id": task_id}
        })
        debug_log(f"[SubAgent] 插入系统消息位置: {insert_index}")
        insert_index += 1

        sender('system_message', {
            'content': message,
            'inline': inline
        })
        maybe_mark_failure_from_message(web_terminal, message)
|
|
|
|
|
|
|
|
async def wait_retry_delay(*, delay_seconds: int, client_sid: str, username: str, sender, get_stop_flag, clear_stop_flag, poll_interval: float = 0.2) -> bool:
    """Wait out a retry delay while polling for a user stop request.

    The stop flag is checked immediately and then every *poll_interval*
    seconds, plus once more when the deadline is reached. A flag counts as a
    stop request when it is a dict whose ``'stop'`` value is truthy, or any
    other truthy value. On a stop request a ``task_stopped`` event is sent and
    the flag is cleared.

    Args:
        delay_seconds: total wait; ``<= 0`` returns immediately without polling.
        client_sid, username: passed through to the stop-flag callables.
        sender: callable ``sender(event_name, payload)``.
        get_stop_flag: callable ``(client_sid, username)`` returning the flag.
        clear_stop_flag: callable ``(client_sid, username)``.
        poll_interval: seconds between stop-flag checks (default 0.2).

    Returns:
        True if the wait was interrupted by a stop request, else False.
    """
    if delay_seconds <= 0:
        return False

    # Monotonic clock: unaffected by wall-clock jumps (NTP sync, manual changes),
    # which could otherwise stretch or cut short the retry delay.
    deadline = time.monotonic() + delay_seconds
    while True:
        client_stop_info = get_stop_flag(client_sid, username)
        if client_stop_info:
            stop_requested = client_stop_info.get('stop', False) if isinstance(client_stop_info, dict) else client_stop_info
            if stop_requested:
                sender('task_stopped', {
                    'message': '命令执行被用户取消',
                    'reason': 'user_stop'
                })
                clear_stop_flag(client_sid, username)
                return True

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline; the loop re-checks the flag afterwards.
        await asyncio.sleep(min(poll_interval, remaining))
|
|
|
|
|
|
|
|
def cancel_pending_tools(*, tool_calls_list, sender, messages):
    """Report every tool call in *tool_calls_list* as cancelled by the user.

    For each call an ``update_action`` event carrying a cancelled result is
    sent. When the call has an ``id``, a matching ``role="tool"`` message is
    also appended to *messages*, so the tool call id is not left without a
    result (the original note says a missing one leads to a later 400).

    NOTE(review): every entry is cancelled. The caller is presumably expected
    to pass only calls that have not produced a result yet — confirm at call
    sites.

    Args:
        tool_calls_list: iterable of tool-call dicts (``id`` and
            ``function.name`` are read); a falsy value means nothing to do.
        sender: callable ``sender(event_name, payload)``.
        messages: conversation list, mutated in place.
    """
    cancel_text = "命令执行被用户取消"

    for call in tool_calls_list or ():
        call_id = call.get("id")
        tool_name = call.get("function", {}).get("name")

        cancelled_result = {
            "success": False,
            "status": "cancelled",
            "message": cancel_text,
            "tool": tool_name,
        }
        payload = {
            'preparing_id': call_id,
            'status': 'cancelled',
            'result': cancelled_result,
        }
        sender('update_action', payload)

        # A call without an id cannot be paired with a tool message.
        if not call_id:
            continue

        messages.append({
            "role": "tool",
            "tool_call_id": call_id,
            "name": tool_name,
            "content": cancel_text,
            "metadata": {"status": "cancelled"},
        })
|