fix: sanitize historical tool call order

This commit is contained in:
JOJO 2025-11-18 17:42:19 +08:00
parent 4652720c99
commit 20213f30ea
2 changed files with 59 additions and 8 deletions

View File

@ -4,7 +4,7 @@ import asyncio
import json import json
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional, Set
from datetime import datetime from datetime import datetime
try: try:
@ -2064,6 +2064,30 @@ class MainTerminal:
# 构建上下文 # 构建上下文
return self.context_manager.build_main_context(memory) return self.context_manager.build_main_context(memory)
def _tool_calls_followed_by_tools(self, conversation: List[Dict], start_idx: int, tool_calls: List[Dict]) -> bool:
"""判断指定助手消息的工具调用是否拥有后续的工具响应。"""
if not tool_calls:
return False
expected_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
if not expected_ids:
return False
matched: Set[str] = set()
idx = start_idx + 1
total = len(conversation)
while idx < total and len(matched) < len(expected_ids):
next_conv = conversation[idx]
role = next_conv.get("role")
if role == "tool":
call_id = next_conv.get("tool_call_id")
if call_id in expected_ids:
matched.add(call_id)
else:
break
elif role in ("assistant", "user"):
break
idx += 1
return len(matched) == len(expected_ids)
def build_messages(self, context: Dict, user_input: str) -> List[Dict]: def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
"""构建消息列表(添加终端内容注入)""" """构建消息列表(添加终端内容注入)"""
# 加载系统提示 # 加载系统提示
@ -2097,7 +2121,8 @@ class MainTerminal:
messages.append({"role": "system", "content": thinking_prompt}) messages.append({"role": "system", "content": thinking_prompt})
# 添加对话历史保留完整结构包括tool_calls和tool消息 # 添加对话历史保留完整结构包括tool_calls和tool消息
for conv in context["conversation"]: conversation = context["conversation"]
for idx, conv in enumerate(conversation):
metadata = conv.get("metadata") or {} metadata = conv.get("metadata") or {}
if conv["role"] == "assistant": if conv["role"] == "assistant":
# Assistant消息可能包含工具调用 # Assistant消息可能包含工具调用
@ -2106,8 +2131,9 @@ class MainTerminal:
"content": conv["content"] "content": conv["content"]
} }
# 如果有工具调用信息,添加到消息中 # 如果有工具调用信息,添加到消息中
if "tool_calls" in conv and conv["tool_calls"]: tool_calls = conv.get("tool_calls") or []
message["tool_calls"] = conv["tool_calls"] if tool_calls and self._tool_calls_followed_by_tools(conversation, idx, tool_calls):
message["tool_calls"] = tool_calls
messages.append(message) messages.append(message)
elif conv["role"] == "tool": elif conv["role"] == "tool":

View File

@ -4,7 +4,7 @@ import asyncio
import json import json
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional, Set
from datetime import datetime from datetime import datetime
try: try:
@ -2033,6 +2033,29 @@ class MainTerminal:
# 构建上下文 # 构建上下文
return self.context_manager.build_main_context(memory) return self.context_manager.build_main_context(memory)
def _tool_calls_followed_by_tools(self, conversation: List[Dict], start_idx: int, tool_calls: List[Dict]) -> bool:
if not tool_calls:
return False
expected_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
if not expected_ids:
return False
matched: Set[str] = set()
idx = start_idx + 1
total = len(conversation)
while idx < total and len(matched) < len(expected_ids):
next_conv = conversation[idx]
role = next_conv.get("role")
if role == "tool":
call_id = next_conv.get("tool_call_id")
if call_id in expected_ids:
matched.add(call_id)
else:
break
elif role in ("assistant", "user"):
break
idx += 1
return len(matched) == len(expected_ids)
def build_messages(self, context: Dict, user_input: str) -> List[Dict]: def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
"""构建消息列表(添加终端内容注入)""" """构建消息列表(添加终端内容注入)"""
# 加载系统提示 # 加载系统提示
@ -2056,7 +2079,8 @@ class MainTerminal:
messages.append({"role": "system", "content": todo_prompt}) messages.append({"role": "system", "content": todo_prompt})
# 添加对话历史保留完整结构包括tool_calls和tool消息 # 添加对话历史保留完整结构包括tool_calls和tool消息
for conv in context["conversation"]: conversation = context["conversation"]
for idx, conv in enumerate(conversation):
metadata = conv.get("metadata") or {} metadata = conv.get("metadata") or {}
if conv["role"] == "assistant": if conv["role"] == "assistant":
# Assistant消息可能包含工具调用 # Assistant消息可能包含工具调用
@ -2065,8 +2089,9 @@ class MainTerminal:
"content": conv["content"] "content": conv["content"]
} }
# 如果有工具调用信息,添加到消息中 # 如果有工具调用信息,添加到消息中
if "tool_calls" in conv and conv["tool_calls"]: tool_calls = conv.get("tool_calls") or []
message["tool_calls"] = conv["tool_calls"] if tool_calls and self._tool_calls_followed_by_tools(conversation, idx, tool_calls):
message["tool_calls"] = tool_calls
messages.append(message) messages.append(message)
elif conv["role"] == "tool": elif conv["role"] == "tool":