agent-Specialization/server/chat_flow_tool_loop.py
JOJO 823b1e105e feat: implement graceful tool cancellation on stop request
- Add stop flag monitoring loop (checks every 100ms during tool execution)
- Cancel tool task immediately when stop flag is detected
- Return cancellation message to conversation history with role=tool
- Save the cancellation result '命令执行被用户取消' ("command execution cancelled by user")
- Clean up pending tasks to prevent 'Task was destroyed but it is pending' warnings
- Fix terminal_ops.py to properly cancel stdout/stderr read tasks

Known issue: when a cancelled tool call is expanded in the frontend, it still shows the tool's arguments instead of the cancellation message.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-03-08 04:12:50 +08:00

489 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import asyncio
import json
import time
from typing import Optional
from .utils_common import debug_log, brief_log
from .state import MONITOR_FILE_TOOLS, MONITOR_MEMORY_TOOLS, MONITOR_SNAPSHOT_CHAR_LIMIT, MONITOR_MEMORY_ENTRY_LIMIT
from .monitor import cache_monitor_snapshot
from .security import compact_web_search_result
from .chat_flow_helpers import detect_tool_failure
from .chat_flow_runner_helpers import resolve_monitor_path, resolve_monitor_memory, capture_monitor_snapshot
from utils.tool_result_formatter import format_tool_result_for_context
from config import TOOL_CALL_COOLDOWN
# Text recorded as the tool result (and shown to the user) when a tool call is stopped.
_CANCELLED_MESSAGE = "命令执行被用户取消"


def _is_stop_requested(get_stop_flag, client_sid, username):
    """Return True when a stop has been requested for this client/user.

    ``get_stop_flag`` may return a falsy value, a dict whose ``'stop'`` key
    carries the flag, or any other value whose truthiness is the flag.
    """
    stop_info = get_stop_flag(client_sid, username)
    if not stop_info:
        return False
    if isinstance(stop_info, dict):
        return bool(stop_info.get('stop', False))
    return bool(stop_info)


def _record_tool_cancelled(*, web_terminal, sender, messages, tool_call_id, function_name,
                           conversation_id=None, tool_display_id=None):
    """Report one tool call as cancelled by the user and persist that result.

    Always sends an ``update_action`` event with status ``cancelled``. When the
    call has an id, a ``role=tool`` cancellation reply is also appended to
    ``messages`` and saved through ``context_manager.add_conversation``. This
    answers the assistant's tool call and keeps a reload from showing it as
    still running.
    """
    update_payload = {
        'preparing_id': tool_call_id,
        'status': 'cancelled',
        'result': {
            "success": False,
            "status": "cancelled",
            "message": _CANCELLED_MESSAGE,
            "tool": function_name,
        },
        'message': _CANCELLED_MESSAGE,
        'conversation_id': conversation_id,
    }
    if tool_display_id is not None:
        # tool_start and the completed update_action carry this same 'id'. Sending it
        # presumably lets the frontend match the cancellation to the card it already shows.
        update_payload['id'] = tool_display_id
    sender('update_action', update_payload)
    if not tool_call_id:
        return
    messages.append({
        "role": "tool",
        "tool_call_id": tool_call_id,
        "name": function_name,
        "content": _CANCELLED_MESSAGE,
        "metadata": {"status": "cancelled"},
    })
    web_terminal.context_manager.add_conversation(
        "tool",
        _CANCELLED_MESSAGE,
        tool_call_id=tool_call_id,
        name=function_name,
        metadata={"status": "cancelled"},
    )


def _finish_stopped_batch(*, web_terminal, sender, messages, pending_tool_calls, conversation_id,
                          client_sid, username, clear_stop_flag):
    """Mark every tool call that will no longer run as cancelled, announce the stop, clear the flag."""
    # Each tool_call_id in the assistant message gets a reply. OpenAI-compatible chat APIs
    # reject a tool_calls message whose ids are left unanswered, and the frontend
    # cards for these calls would otherwise stay in their "preparing" state.
    for pending in pending_tool_calls:
        function_info = pending.get("function") or {}
        _record_tool_cancelled(
            web_terminal=web_terminal,
            sender=sender,
            messages=messages,
            tool_call_id=pending.get("id"),
            function_name=function_info.get("name"),
            conversation_id=conversation_id,
        )
    sender('task_stopped', {
        'message': _CANCELLED_MESSAGE,
        'reason': 'user_stop'
    })
    clear_stop_flag(client_sid, username)


def _report_argument_error(*, web_terminal, sender, messages, tool_call_id, function_name, error_text):
    """Report a tool-argument parse failure to the frontend and record it as the tool's result."""
    error_payload = {
        "success": False,
        "error": error_text,
        "error_type": "parameter_format_error",
        "tool_name": function_name,
        "tool_call_id": tool_call_id,
        "message": error_text
    }
    sender('error', {'message': error_text})
    sender('update_action', {
        'preparing_id': tool_call_id,
        'status': 'completed',
        'result': error_payload,
        'message': error_text
    })
    error_content = json.dumps(error_payload, ensure_ascii=False)
    web_terminal.context_manager.add_conversation(
        "tool",
        error_content,
        tool_call_id=tool_call_id,
        name=function_name
    )
    messages.append({
        "role": "tool",
        "tool_call_id": tool_call_id,
        "name": function_name,
        "content": error_content
    })


def _repair_json_arguments(raw):
    """Apply cheap heuristic repairs to a (possibly truncated) JSON object string.

    Returns ``(repaired_string, repair_descriptions)``; the list is empty when
    no repair applied.
    """
    repaired = raw.strip()
    attempts = []
    # Repair 1: an odd number of double quotes suggests an unterminated string.
    # (Escaped quotes are counted too; this is only a heuristic.)
    if repaired.count('"') % 2 == 1:
        repaired += '"'
        attempts.append("添加闭合引号")
    # Repair 2: an object that was opened but never closed.
    if repaired.startswith('{') and not repaired.rstrip().endswith('}'):
        repaired = repaired.rstrip() + '}'
        attempts.append("添加闭合括号")
    # Repair 3: only when nothing else applied, cut at the last comma and close the
    # object, presumably dropping an incomplete trailing key/value pair.
    if not attempts:
        last_comma = repaired.rfind(',')
        if last_comma > 0:
            repaired = repaired[:last_comma] + '}'
            attempts.append("移除不完整的键值对")
    return repaired, attempts


def _non_object_error(value):
    """Error text for arguments that parsed as JSON but are not a JSON object."""
    return f'工具参数解析失败: 参数必须是JSON对象,实际为 {type(value).__name__}'


def _parse_tool_arguments(web_terminal, arguments_str, function_name):
    """Parse a tool call's JSON argument string into a dict.

    Prefers ``web_terminal.api_client._safe_tool_arguments_parse`` when present.
    Otherwise it uses ``json.loads`` followed by ``_repair_json_arguments`` on
    failure.

    Returns ``(arguments, None)`` on success or ``(None, error_text)`` on failure.
    """
    api_client = getattr(web_terminal, 'api_client', None)
    safe_parse = getattr(api_client, '_safe_tool_arguments_parse', None)
    if safe_parse is not None:
        success, arguments, error_msg = safe_parse(arguments_str, function_name)
        if not success:
            debug_log(f"安全解析失败: {error_msg}")
            return None, f'工具参数解析失败: {error_msg}'
        if not isinstance(arguments, dict):
            return None, _non_object_error(arguments)
        debug_log(f"使用安全解析成功,参数键: {list(arguments.keys())}")
        return arguments, None

    try:
        arguments = json.loads(arguments_str) if arguments_str.strip() else {}
    except json.JSONDecodeError as exc:
        debug_log(f"原始JSON解析失败: {exc}")
        repaired_str, repair_attempts = _repair_json_arguments(arguments_str)
        try:
            arguments = json.loads(repaired_str)
        except json.JSONDecodeError as repair_error:
            debug_log(f"JSON修复也失败: {repair_error}")
            debug_log(f"修复尝试: {repair_attempts}")
            debug_log(f"修复后内容前100字符: {repaired_str[:100]}")
            # Report the original error: it describes the model's actual output.
            return None, f'工具参数解析失败: {exc}'
        debug_log(f"JSON修复成功: {', '.join(repair_attempts)}")
        if not isinstance(arguments, dict):
            return None, _non_object_error(arguments)
        debug_log(f"修复后参数键: {list(arguments.keys())}")
        return arguments, None

    if not isinstance(arguments, dict):
        return None, _non_object_error(arguments)
    debug_log(f"直接JSON解析成功参数键: {list(arguments.keys())}")
    return arguments, None


def _read_memory_entries(web_terminal, memory_type, log_label=""):
    """Best-effort read of memory entries for a monitor snapshot; returns None on any failure."""
    try:
        return resolve_monitor_memory(
            web_terminal.memory_manager._read_entries(memory_type),
            MONITOR_MEMORY_ENTRY_LIMIT,
        )
    except Exception as exc:
        debug_log(f"[MonitorSnapshot] 读取记忆失败{log_label}: {memory_type} ({exc})")
        return None


def _capture_before_snapshot(web_terminal, function_name, arguments, tool_display_id):
    """Capture and cache the 'before' monitor snapshot for file/memory tools.

    Returns ``(monitor_snapshot, snapshot_path, memory_snapshot_type)``;
    entries that do not apply to this tool are None.
    """
    monitor_snapshot = None
    snapshot_path = None
    memory_snapshot_type = None
    if function_name in MONITOR_FILE_TOOLS:
        snapshot_path = resolve_monitor_path(arguments)
        monitor_snapshot = capture_monitor_snapshot(
            web_terminal.file_manager, snapshot_path, MONITOR_SNAPSHOT_CHAR_LIMIT, debug_log
        )
    elif function_name in MONITOR_MEMORY_TOOLS:
        memory_snapshot_type = str(arguments.get('memory_type') or 'main').lower()
        before_entries = _read_memory_entries(web_terminal, memory_snapshot_type)
        if before_entries is not None:
            monitor_snapshot = {
                'memory_type': memory_snapshot_type,
                'entries': before_entries
            }
    if monitor_snapshot:
        cache_monitor_snapshot(tool_display_id, 'before', monitor_snapshot)
    return monitor_snapshot, snapshot_path, memory_snapshot_type


def _capture_after_snapshot(web_terminal, function_name, arguments, result_data,
                            snapshot_path, memory_snapshot_type):
    """Capture the 'after' monitor snapshot for file/memory tools, or return None."""
    if function_name in MONITOR_FILE_TOOLS:
        result_path = None
        if isinstance(result_data, dict):
            result_path = resolve_monitor_path(result_data)
            if not result_path:
                candidate_path = result_data.get('path')
                if isinstance(candidate_path, str) and candidate_path.strip():
                    result_path = candidate_path.strip()
        if not result_path:
            result_path = resolve_monitor_path(arguments, snapshot_path) or snapshot_path
        return capture_monitor_snapshot(
            web_terminal.file_manager, result_path, MONITOR_SNAPSHOT_CHAR_LIMIT, debug_log
        )
    if function_name in MONITOR_MEMORY_TOOLS:
        memory_after_type = str(
            arguments.get('memory_type')
            or (isinstance(result_data, dict) and result_data.get('memory_type'))
            or memory_snapshot_type
            or 'main'
        ).lower()
        after_entries = _read_memory_entries(web_terminal, memory_after_type, "(after)")
        if after_entries is not None:
            return {
                'memory_type': memory_after_type,
                'entries': after_entries
            }
    return None


async def _run_tool_until_stopped(*, web_terminal, function_name, arguments,
                                  get_stop_flag, client_sid, username):
    """Run one tool call, cancelling it as soon as the user's stop flag is set.

    Returns ``(True, None)`` when the tool was cancelled, otherwise
    ``(False, result)``. Exceptions raised by the tool propagate. If this
    coroutine is itself cancelled, the tool task is cancelled as well, so it
    is not left running unowned.
    """
    debug_log(f"[停止检测] 开始执行工具: {function_name}")
    tool_task = asyncio.create_task(web_terminal.handle_tool_call(function_name, arguments))
    cancelled = False
    check_count = 0
    try:
        while True:
            # Returns as soon as the tool finishes, or after the 100 ms poll interval.
            done, _ = await asyncio.wait({tool_task}, timeout=0.1)
            if done:
                break
            check_count += 1
            if _is_stop_requested(get_stop_flag, client_sid, username):
                debug_log(f"[停止检测] 工具执行过程中检测到停止请求(检查次数:{check_count}),立即取消工具")
                tool_task.cancel()
                cancelled = True
                break
    except asyncio.CancelledError:
        # asyncio.wait does not cancel the tasks it waits on; do it explicitly.
        tool_task.cancel()
        raise
    debug_log(f"[停止检测] 工具执行完成cancelled={cancelled}, 检查次数={check_count}")

    if not cancelled:
        # Re-raises the tool's exception, if any, just like awaiting the task.
        return False, tool_task.result()

    # Let the cancellation settle so no pending task is left behind.
    try:
        await tool_task
    except asyncio.CancelledError:
        debug_log("[停止检测] 工具任务已被取消(CancelledError)")
    except Exception as exc:
        debug_log(f"[停止检测] 工具任务取消时发生异常: {exc}")
    return True, None


def _attach_pending_media(*, web_terminal, sender, function_name, tool_failed, result_data,
                          tool_result_content, metadata_payload):
    """Attach the image/video queued by view_image / view_video to the tool result.

    The pending attribute on ``web_terminal`` is consumed (reset to None)
    whenever it is set. Media is only attached when the tool did not fail and
    its result does not report ``success: False``.

    Returns ``(tool_message_content, metadata_payload, tool_images, tool_videos)``.
    """
    tool_message_content = tool_result_content
    tool_images = None
    tool_videos = None
    succeeded = (
        not tool_failed
        and isinstance(result_data, dict)
        and result_data.get("success") is not False
    )
    text_part = tool_result_content if isinstance(tool_result_content, str) else ""

    if function_name == "view_image" and getattr(web_terminal, "pending_image_view", None):
        pending = web_terminal.pending_image_view
        web_terminal.pending_image_view = None
        img_path = pending.get("path") if isinstance(pending, dict) else None
        if succeeded and img_path:
            tool_message_content = web_terminal.context_manager._build_content_with_images(
                text_part,
                [img_path]
            )
            tool_images = [img_path]
            if metadata_payload is None:
                metadata_payload = {}
            metadata_payload["tool_image_path"] = img_path
            sender('system_message', {
                'content': f'系统已按模型请求将图片附加到工具结果: {img_path}'
            })

    if function_name == "view_video" and getattr(web_terminal, "pending_video_view", None):
        pending = web_terminal.pending_video_view
        web_terminal.pending_video_view = None
        video_path = pending.get("path") if isinstance(pending, dict) else None
        if succeeded and video_path:
            tool_message_content = web_terminal.context_manager._build_content_with_images(
                text_part,
                [],
                [video_path]
            )
            tool_videos = [video_path]
            if metadata_payload is None:
                metadata_payload = {}
            metadata_payload["tool_video_path"] = video_path
            sender('system_message', {
                'content': f'系统已按模型请求将视频附加到工具结果: {video_path}'
            })

    return tool_message_content, metadata_payload, tool_images, tool_videos


async def execute_tool_calls(*, web_terminal, tool_calls, sender, messages, client_sid: str,
                             username: str, iteration: int, conversation_id: Optional[str],
                             last_tool_call_time: float, process_sub_agent_updates,
                             maybe_mark_failure_from_message, mark_force_thinking,
                             get_stop_flag, clear_stop_flag):
    """Execute the model's tool calls in order and record their results.

    For each call this function:
      - parses the JSON arguments;
      - waits out ``TOOL_CALL_COOLDOWN`` since the previous call;
      - emits ``tool_start`` / ``update_action`` events through ``sender``;
      - runs ``web_terminal.handle_tool_call`` while polling the stop flag;
      - saves the result to ``web_terminal.context_manager`` and appends it to
        ``messages``.

    A stop request, seen before a call starts or while one runs, ends the
    batch. The running call (if any) is cancelled, it and every remaining call
    are recorded as cancelled, ``task_stopped`` is emitted and the stop flag is
    cleared.

    Args:
        web_terminal: project object providing handle_tool_call, context_manager,
            file_manager and memory_manager.
        tool_calls: tool-call dicts shaped ``{"id", "function": {"name", "arguments"}}``.
        sender: callable ``sender(event_name, payload)`` used for frontend events.
        messages: chat message list; tool replies are appended in place.
        client_sid, username: identify the stop flag passed to get_stop_flag / clear_stop_flag.
        iteration: outer loop counter, embedded in the frontend tool id.
        conversation_id: forwarded in frontend events.
        last_tool_call_time: ``time.time()`` of the previous tool call, or 0 for none.
        process_sub_agent_updates, maybe_mark_failure_from_message, mark_force_thinking:
            callbacks from the chat flow.

    Returns:
        ``{"stopped": bool, "last_tool_call_time": float}`` on every path.
        ``stopped`` is True when a stop request ended the batch early.
    """
    # Materialise once so the not-yet-run calls can be sliced off on stop.
    tool_calls = list(tool_calls)
    for index, tool_call in enumerate(tool_calls):
        if _is_stop_requested(get_stop_flag, client_sid, username):
            debug_log("在工具调用过程中检测到停止状态")
            _finish_stopped_batch(
                web_terminal=web_terminal,
                sender=sender,
                messages=messages,
                pending_tool_calls=tool_calls[index:],
                conversation_id=conversation_id,
                client_sid=client_sid,
                username=username,
                clear_stop_flag=clear_stop_flag,
            )
            # Same shape as every other return so callers can always read result["stopped"].
            return {"stopped": True, "last_tool_call_time": last_tool_call_time}

        # Enforce the minimum interval between consecutive tool calls.
        if last_tool_call_time > 0:
            elapsed = time.time() - last_tool_call_time
            if elapsed < TOOL_CALL_COOLDOWN:
                await asyncio.sleep(TOOL_CALL_COOLDOWN - elapsed)
        last_tool_call_time = time.time()

        function_name = tool_call["function"]["name"]
        # A missing (None) argument string is treated as empty rather than crashing len().
        arguments_str = tool_call["function"]["arguments"] or ""
        tool_call_id = tool_call["id"]
        debug_log(f"准备解析JSON工具: {function_name}, 参数长度: {len(arguments_str)}")
        debug_log(f"JSON参数前200字符: {arguments_str[:200]}")
        debug_log(f"JSON参数后200字符: {arguments_str[-200:]}")

        arguments, error_text = _parse_tool_arguments(web_terminal, arguments_str, function_name)
        if error_text is not None:
            _report_argument_error(
                web_terminal=web_terminal,
                sender=sender,
                messages=messages,
                tool_call_id=tool_call_id,
                function_name=function_name,
                error_text=error_text,
            )
            continue

        debug_log(f"执行工具: {function_name} (ID: {tool_call_id})")
        tool_display_id = f"tool_{iteration}_{function_name}_{time.time()}"
        monitor_snapshot, snapshot_path, memory_snapshot_type = _capture_before_snapshot(
            web_terminal, function_name, arguments, tool_display_id
        )
        sender('tool_start', {
            'id': tool_display_id,
            'name': function_name,
            'arguments': arguments,
            'preparing_id': tool_call_id,
            'monitor_snapshot': monitor_snapshot,
            'conversation_id': conversation_id
        })
        brief_log(f"调用了工具: {function_name}")
        await asyncio.sleep(0.3)
        start_time = time.time()

        cancelled, tool_result = await _run_tool_until_stopped(
            web_terminal=web_terminal,
            function_name=function_name,
            arguments=arguments,
            get_stop_flag=get_stop_flag,
            client_sid=client_sid,
            username=username,
        )
        if cancelled:
            debug_log("[停止检测] 发送取消通知到前端")
            _record_tool_cancelled(
                web_terminal=web_terminal,
                sender=sender,
                messages=messages,
                tool_call_id=tool_call_id,
                function_name=function_name,
                conversation_id=conversation_id,
                tool_display_id=tool_display_id,
            )
            debug_log("[停止检测] 取消结果已保存到对话历史")
            _finish_stopped_batch(
                web_terminal=web_terminal,
                sender=sender,
                messages=messages,
                pending_tool_calls=tool_calls[index + 1:],
                conversation_id=conversation_id,
                client_sid=client_sid,
                username=username,
                clear_stop_flag=clear_stop_flag,
            )
            debug_log("[停止检测] 返回stopped=True")
            return {"stopped": True, "last_tool_call_time": last_tool_call_time}

        debug_log(f"工具结果: {str(tool_result)[:200]}...")
        # Enforce a minimum of 1.5 s per call (presumably so the running state is visible in the UI).
        execution_time = time.time() - start_time
        if execution_time < 1.5:
            await asyncio.sleep(1.5 - execution_time)

        try:
            result_data = json.loads(tool_result)
        except (TypeError, ValueError):
            # Non-JSON (or non-string) results are wrapped so they still render as a dict.
            result_data = {'output': tool_result}
        # json.loads may also produce a list/str/number; key lookups then see an empty dict.
        result_fields = result_data if isinstance(result_data, dict) else {}
        tool_failed = detect_tool_failure(result_data)

        action_message = None
        if function_name in {"write_file", "edit_file"}:
            summary = result_fields.get("summary") or result_fields.get("message")
            if summary:
                action_message = summary
            debug_log(f"{function_name} 执行完成: {summary or '无摘要'}")

        # wait_sub_agent may return a system message for the model. It is surfaced now,
        # but appended to `messages` only after this call's tool reply (see below).
        sub_agent_system_msg = None
        if function_name == "wait_sub_agent":
            sub_agent_system_msg = result_fields.get("system_message")
            if sub_agent_system_msg:
                sender('system_message', {
                    'content': sub_agent_system_msg,
                    'inline': False
                })
                maybe_mark_failure_from_message(web_terminal, sub_agent_system_msg)

        monitor_snapshot_after = _capture_after_snapshot(
            web_terminal, function_name, arguments, result_data, snapshot_path, memory_snapshot_type
        )
        update_payload = {
            'id': tool_display_id,
            'status': 'completed',
            'result': result_data,
            'preparing_id': tool_call_id,
            'conversation_id': conversation_id
        }
        if action_message:
            update_payload['message'] = action_message
        if monitor_snapshot_after:
            update_payload['monitor_snapshot_after'] = monitor_snapshot_after
            cache_monitor_snapshot(tool_display_id, 'after', monitor_snapshot_after)
        sender('update_action', update_payload)

        if function_name in ['create_file', 'delete_file', 'rename_file', 'create_folder']:
            if not web_terminal.context_manager._is_host_mode_without_safety():
                structure = web_terminal.context_manager.get_project_structure()
                sender('file_tree_update', structure)

        # ===== Incremental save: persist the tool result immediately =====
        metadata_payload = None
        if isinstance(result_data, dict):
            if function_name == "web_search":
                # Keep a compact, frontend-renderable structure so history can replay the results.
                # NOTE(review): indentation was lost in the reviewed copy; confirm whether web_search
                # results should also carry the tool_payload metadata set in the else-branch.
                try:
                    tool_result_content = json.dumps(compact_web_search_result(result_data), ensure_ascii=False)
                except Exception:
                    tool_result_content = tool_result
            else:
                tool_result_content = format_tool_result_for_context(function_name, result_data, tool_result)
                metadata_payload = {"tool_payload": result_data}
        else:
            tool_result_content = tool_result

        # view_image / view_video: attach the media directly to the tool result
        # instead of inserting a separate user message.
        tool_message_content, metadata_payload, tool_images, tool_videos = _attach_pending_media(
            web_terminal=web_terminal,
            sender=sender,
            function_name=function_name,
            tool_failed=tool_failed,
            result_data=result_data,
            tool_result_content=tool_result_content,
            metadata_payload=metadata_payload,
        )

        web_terminal.context_manager.add_conversation(
            "tool",
            tool_result_content,
            tool_call_id=tool_call_id,
            name=function_name,
            metadata=metadata_payload,
            images=tool_images,
            videos=tool_videos
        )
        debug_log(f"💾 增量保存:工具结果 {function_name}")

        system_message = result_fields.get("system_message")
        if system_message:
            web_terminal._record_sub_agent_message(system_message, result_fields.get("task_id"), inline=False)
            maybe_mark_failure_from_message(web_terminal, system_message)

        # Tool reply consumed by the API when the conversation continues.
        messages.append({
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": function_name,
            "content": tool_message_content
        })
        if sub_agent_system_msg:
            # Placed after the tool reply rather than before it, so the assistant
            # tool_calls message is not immediately followed by a non-tool message.
            messages.append({
                "role": "system",
                "content": sub_agent_system_msg
            })

        if function_name not in {'write_file', 'edit_file'}:
            await process_sub_agent_updates(
                messages=messages,
                inline=True,
                after_tool_call_id=tool_call_id,
                web_terminal=web_terminal,
                sender=sender,
                debug_log=debug_log,
                maybe_mark_failure_from_message=maybe_mark_failure_from_message,
            )
        await asyncio.sleep(0.2)
        if tool_failed:
            mark_force_thinking(web_terminal, reason=f"{function_name}_failed")
    return {"stopped": False, "last_tool_call_time": last_tool_call_time}