agent-Specialization/server/chat_flow_task_support.py

188 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import asyncio
import time
from typing import Dict, List, Optional
from modules.sub_agent_manager import TERMINAL_STATUSES
def _current_conversation_id(web_terminal):
    """Return ``web_terminal.context_manager.current_conversation_id``, or None if either is missing."""
    context_manager = getattr(web_terminal, "context_manager", None)
    return getattr(context_manager, "current_conversation_id", None)


def _synthesize_terminal_updates(manager, web_terminal, debug_log) -> List[Dict]:
    """Build updates for finished-but-unannounced tasks that ``poll_updates()`` missed.

    A task qualifies when it is a dict whose ``status`` is in
    ``TERMINAL_STATUSES`` or is ``"terminated"``, it is not flagged
    ``notified``, and it belongs to the current conversation (or either
    conversation id is unknown). Its ``final_result`` is used, falling back
    to ``manager._check_task_status(task)``; only dict results are kept.
    Any unexpected error aborts the scan, is logged, and yields ``[]``.
    """
    synthesized: List[Dict] = []
    try:
        terminal_statuses = TERMINAL_STATUSES.union({"terminated"})
        for task in getattr(manager, "tasks", {}).values():
            if not isinstance(task, dict):
                continue
            if task.get("status") not in terminal_statuses:
                continue
            if task.get("notified"):
                continue
            task_conv_id = task.get("conversation_id")
            current_conv_id = _current_conversation_id(web_terminal)
            if task_conv_id and current_conv_id and task_conv_id != current_conv_id:
                continue
            final_result = task.get("final_result")
            if not final_result:
                try:
                    final_result = manager._check_task_status(task)
                except Exception:
                    # Best effort: a task whose status cannot be re-checked is skipped.
                    final_result = None
            if isinstance(final_result, dict):
                synthesized.append(final_result)
    except Exception as exc:
        debug_log(f"[SubAgent] synthesized updates failed: {exc}")
        return []
    return synthesized


def _is_sub_agent_notice(msg) -> bool:
    """Return True if ``msg`` carries the ``sub_agent_notice`` metadata flag set on inserted notices."""
    metadata = msg.get("metadata") if isinstance(msg, dict) else None
    return isinstance(metadata, dict) and bool(metadata.get("sub_agent_notice"))


def _find_insert_index(messages: List[Dict], after_tool_call_id: Optional[str]) -> int:
    """Return the index in ``messages`` at which a sub-agent notice should be inserted.

    Without ``after_tool_call_id``, or when no ``role == "tool"`` message with
    that ``tool_call_id`` exists, the notice goes at the end. Otherwise it
    goes right after that tool message, past any notices already placed
    there, so multiple notices keep the order in which they were processed.
    """
    if after_tool_call_id:
        for idx, msg in enumerate(messages):
            if msg.get("role") == "tool" and msg.get("tool_call_id") == after_tool_call_id:
                insert_index = idx + 1
                # Inserting every notice at idx + 1 would reverse their order;
                # step over the ones already inserted after this tool result.
                while insert_index < len(messages) and _is_sub_agent_notice(messages[insert_index]):
                    insert_index += 1
                return insert_index
    return len(messages)


async def process_sub_agent_updates(*, messages: List[Dict], inline: bool = False, after_tool_call_id: Optional[str] = None, web_terminal, sender, debug_log, maybe_mark_failure_from_message):
    """Poll sub-agent tasks, notify the frontend, and insert results into the current conversation.

    Declared ``async`` for callers, but it awaits nothing.

    Args:
        messages: Chat message dicts; notices are inserted into this list in place.
        inline: True while a model turn is in progress. Notices are then
            inserted with role ``"system"`` and de-duplicated via
            ``web_terminal._inline_sub_agent_notified``. Otherwise they are
            inserted with role ``"user"`` and a fixed notification prefix.
        after_tool_call_id: When given and a matching ``role == "tool"``
            message exists, notices are inserted after it; else appended.
        web_terminal: Supplies ``sub_agent_manager``, optionally
            ``context_manager`` and ``_record_sub_agent_message``. The
            ``_announced_sub_agent_tasks`` set is created on it if missing.
        sender: ``sender(event, payload)``; receives one ``'system_message'``
            event per delivered notice.
        debug_log: ``debug_log(text)`` diagnostics callback.
        maybe_mark_failure_from_message: Called as
            ``(web_terminal, message)`` after each delivered notice.
    """
    manager = getattr(web_terminal, "sub_agent_manager", None)
    if not manager:
        return
    # Task ids whose completion has already been announced.
    if not hasattr(web_terminal, "_announced_sub_agent_tasks"):
        web_terminal._announced_sub_agent_tasks = set()
    try:
        updates = manager.poll_updates()
        debug_log(f"[SubAgent] poll inline={inline} after_tool_call_id={after_tool_call_id} updates={len(updates)}")
    except Exception as exc:
        debug_log(f"子智能体状态检查失败: {exc}")
        return
    # Fallback: poll_updates() found nothing, but a task may already have been
    # moved to a terminal state elsewhere without being announced; emit it once.
    if not updates:
        synthesized = _synthesize_terminal_updates(manager, web_terminal, debug_log)
        if synthesized:
            updates = synthesized
            debug_log(f"[SubAgent] synthesized updates count={len(updates)}")
    if inline and not hasattr(web_terminal, "_inline_sub_agent_notified"):
        web_terminal._inline_sub_agent_notified = set()
    for update in updates:
        task_id = update.get("task_id")
        tasks = getattr(manager, "tasks", None) or {}
        task_info = tasks.get(task_id) if task_id else None
        current_conv_id = _current_conversation_id(web_terminal)
        task_conv_id = task_info.get("conversation_id") if isinstance(task_info, dict) else None
        if task_conv_id and current_conv_id and task_conv_id != current_conv_id:
            debug_log(f"[SubAgent] 跳过非当前对话任务: task={task_id} conv={task_conv_id} current={current_conv_id}")
            continue
        if task_id and task_info is None:
            debug_log(f"[SubAgent] 找不到任务详情,跳过: task={task_id}")
            continue
        # Skip tasks that were already announced.
        if task_id and task_id in web_terminal._announced_sub_agent_tasks:
            debug_log(f"[SubAgent] 任务 {task_id} 已通知过,跳过")
            continue
        message = update.get("system_message")
        if not message:
            debug_log(f"[SubAgent] update missing system_message: task={task_id} keys={list(update.keys())}")
            continue
        inline_key = None
        if inline:
            inline_key = ("task", task_id) if task_id else ("msg", message)
            if inline_key in web_terminal._inline_sub_agent_notified:
                debug_log(f"[SubAgent] inline 通知已发送,跳过: key={inline_key}")
                continue
        debug_log(f"[SubAgent] update task={task_id} inline={inline} msg={message}")
        # Record in the conversation history (later converted to a user message by build_messages).
        if hasattr(web_terminal, "_record_sub_agent_message"):
            try:
                web_terminal._record_sub_agent_message(message, task_id, inline=inline)
            except Exception as exc:
                debug_log(f"[SubAgent] 记录子智能体消息失败: {exc}")
        # Mark the task as announced, in memory and in the manager's saved state.
        if task_id:
            web_terminal._announced_sub_agent_tasks.add(task_id)
            if isinstance(task_info, dict):
                task_info["notified"] = True
                task_info["updated_at"] = time.time()
                try:
                    manager._save_state()
                except Exception as exc:
                    debug_log(f"[SubAgent] 保存通知状态失败: {exc}")
        debug_log("[SubAgent] 计算插入位置")
        insert_index = _find_insert_index(messages, after_tool_call_id)
        # While a turn is running, insert a system message so no new user turn
        # is triggered; otherwise keep a user-role notification.
        insert_role = "system" if inline else "user"
        if not inline:
            prefix = "这是一句系统自动发送的user消息用于通知你子智能体已经运行完成"
            if not message.startswith(prefix):
                message = f"{prefix}\n\n{message}"
        messages.insert(insert_index, {
            "role": insert_role,
            "content": message,
            "metadata": {"sub_agent_notice": True, "inline": inline, "task_id": task_id},
        })
        if inline:
            web_terminal._inline_sub_agent_notified.add(inline_key)
        debug_log(f"[SubAgent] 插入子智能体通知位置: {insert_index} role={insert_role} after_tool_call_id={after_tool_call_id}")
        sender('system_message', {
            'content': message,
            'inline': inline,
            'sub_agent_notice': True,
        })
        maybe_mark_failure_from_message(web_terminal, message)
async def wait_retry_delay(*, delay_seconds: int, client_sid: str, username: str, sender, get_stop_flag, clear_stop_flag, poll_interval: float = 0.2) -> bool:
    """Wait out a retry delay while checking whether the user asked to stop.

    Args:
        delay_seconds: Length of the wait in seconds; ``<= 0`` returns False
            immediately without checking the stop flag.
        client_sid: Passed through to ``get_stop_flag``/``clear_stop_flag``.
        username: Passed through to ``get_stop_flag``/``clear_stop_flag``.
        sender: ``sender(event, payload)``; receives ``'task_stopped'`` when a
            stop request is seen.
        get_stop_flag: ``get_stop_flag(client_sid, username)`` returning a
            falsy value, a truthy non-dict flag, or a dict whose ``'stop'``
            key holds the flag.
        clear_stop_flag: ``clear_stop_flag(client_sid, username)``, called once
            a stop request has been handled.
        poll_interval: Seconds between stop-flag checks (default 0.2).

    Returns:
        True if a stop request interrupted the wait, False if the delay elapsed.
    """
    if delay_seconds <= 0:
        return False
    # time.monotonic() is unaffected by wall-clock adjustments (NTP sync,
    # manual changes), so the wait cannot be stretched or cut short by them.
    deadline = time.monotonic() + delay_seconds
    while time.monotonic() < deadline:
        client_stop_info = get_stop_flag(client_sid, username)
        if client_stop_info:
            stop_requested = client_stop_info.get('stop', False) if isinstance(client_stop_info, dict) else client_stop_info
            if stop_requested:
                sender('task_stopped', {
                    'message': '命令执行被用户取消',
                    'reason': 'user_stop'
                })
                clear_stop_flag(client_sid, username)
                return True
        # Never sleep past the deadline: the last nap is clamped to the time left.
        remaining = deadline - time.monotonic()
        await asyncio.sleep(min(poll_interval, max(0.0, remaining)))
    return False
def cancel_pending_tools(*, tool_calls_list, sender, messages):
    """Produce "cancelled" results for tool calls that have no result yet.

    A missing tool result for a ``tool_call_id`` would cause a later request
    to fail with HTTP 400, so each pending call gets a ``role: "tool"``
    message reporting the cancellation.

    Args:
        tool_calls_list: Tool-call dicts with an ``"id"`` and a
            ``"function"`` dict holding ``"name"``. Falsy means nothing to do.
        sender: ``sender(event, payload)``; receives one ``'update_action'``
            event per cancelled call.
        messages: Chat message list; cancellation results are appended in place.

    Calls whose id already has a ``role: "tool"`` message in ``messages`` are
    skipped. This avoids duplicate results and avoids overwriting a finished
    action's status in the UI. Calls without an id still get the frontend
    event, but no tool message.
    """
    if not tool_calls_list:
        return
    # tool_call_ids that already have a result in the conversation.
    answered_ids = {
        msg.get("tool_call_id")
        for msg in messages
        if isinstance(msg, dict) and msg.get("role") == "tool"
    }
    for tc in tool_calls_list:
        tc_id = tc.get("id")
        if tc_id and tc_id in answered_ids:
            continue
        # "function" may be present but None; treat that like a missing entry.
        func_name = (tc.get("function") or {}).get("name")
        sender('update_action', {
            'preparing_id': tc_id,
            'status': 'cancelled',
            'result': {
                "success": False,
                "status": "cancelled",
                "message": "命令执行被用户取消",
                "tool": func_name
            }
        })
        if tc_id:
            messages.append({
                "role": "tool",
                "tool_call_id": tc_id,
                "name": func_name,
                "content": "命令执行被用户取消",
                "metadata": {"status": "cancelled"}
            })
            answered_ids.add(tc_id)