fix: sanitize historical tool call order
This commit is contained in:
parent
4652720c99
commit
20213f30ea
@ -4,7 +4,7 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Set
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
@ -2064,6 +2064,30 @@ class MainTerminal:
|
||||
# 构建上下文
|
||||
return self.context_manager.build_main_context(memory)
|
||||
|
||||
def _tool_calls_followed_by_tools(self, conversation: List[Dict], start_idx: int, tool_calls: List[Dict]) -> bool:
|
||||
"""判断指定助手消息的工具调用是否拥有后续的工具响应。"""
|
||||
if not tool_calls:
|
||||
return False
|
||||
expected_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
|
||||
if not expected_ids:
|
||||
return False
|
||||
matched: Set[str] = set()
|
||||
idx = start_idx + 1
|
||||
total = len(conversation)
|
||||
while idx < total and len(matched) < len(expected_ids):
|
||||
next_conv = conversation[idx]
|
||||
role = next_conv.get("role")
|
||||
if role == "tool":
|
||||
call_id = next_conv.get("tool_call_id")
|
||||
if call_id in expected_ids:
|
||||
matched.add(call_id)
|
||||
else:
|
||||
break
|
||||
elif role in ("assistant", "user"):
|
||||
break
|
||||
idx += 1
|
||||
return len(matched) == len(expected_ids)
|
||||
|
||||
def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
|
||||
"""构建消息列表(添加终端内容注入)"""
|
||||
# 加载系统提示
|
||||
@ -2097,7 +2121,8 @@ class MainTerminal:
|
||||
messages.append({"role": "system", "content": thinking_prompt})
|
||||
|
||||
# 添加对话历史(保留完整结构,包括tool_calls和tool消息)
|
||||
for conv in context["conversation"]:
|
||||
conversation = context["conversation"]
|
||||
for idx, conv in enumerate(conversation):
|
||||
metadata = conv.get("metadata") or {}
|
||||
if conv["role"] == "assistant":
|
||||
# Assistant消息可能包含工具调用
|
||||
@ -2106,8 +2131,9 @@ class MainTerminal:
|
||||
"content": conv["content"]
|
||||
}
|
||||
# 如果有工具调用信息,添加到消息中
|
||||
if "tool_calls" in conv and conv["tool_calls"]:
|
||||
message["tool_calls"] = conv["tool_calls"]
|
||||
tool_calls = conv.get("tool_calls") or []
|
||||
if tool_calls and self._tool_calls_followed_by_tools(conversation, idx, tool_calls):
|
||||
message["tool_calls"] = tool_calls
|
||||
messages.append(message)
|
||||
|
||||
elif conv["role"] == "tool":
|
||||
|
||||
@ -4,7 +4,7 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Set
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
@ -2033,6 +2033,29 @@ class MainTerminal:
|
||||
# 构建上下文
|
||||
return self.context_manager.build_main_context(memory)
|
||||
|
||||
def _tool_calls_followed_by_tools(self, conversation: List[Dict], start_idx: int, tool_calls: List[Dict]) -> bool:
|
||||
if not tool_calls:
|
||||
return False
|
||||
expected_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
|
||||
if not expected_ids:
|
||||
return False
|
||||
matched: Set[str] = set()
|
||||
idx = start_idx + 1
|
||||
total = len(conversation)
|
||||
while idx < total and len(matched) < len(expected_ids):
|
||||
next_conv = conversation[idx]
|
||||
role = next_conv.get("role")
|
||||
if role == "tool":
|
||||
call_id = next_conv.get("tool_call_id")
|
||||
if call_id in expected_ids:
|
||||
matched.add(call_id)
|
||||
else:
|
||||
break
|
||||
elif role in ("assistant", "user"):
|
||||
break
|
||||
idx += 1
|
||||
return len(matched) == len(expected_ids)
|
||||
|
||||
def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
|
||||
"""构建消息列表(添加终端内容注入)"""
|
||||
# 加载系统提示
|
||||
@ -2056,7 +2079,8 @@ class MainTerminal:
|
||||
messages.append({"role": "system", "content": todo_prompt})
|
||||
|
||||
# 添加对话历史(保留完整结构,包括tool_calls和tool消息)
|
||||
for conv in context["conversation"]:
|
||||
conversation = context["conversation"]
|
||||
for idx, conv in enumerate(conversation):
|
||||
metadata = conv.get("metadata") or {}
|
||||
if conv["role"] == "assistant":
|
||||
# Assistant消息可能包含工具调用
|
||||
@ -2065,8 +2089,9 @@ class MainTerminal:
|
||||
"content": conv["content"]
|
||||
}
|
||||
# 如果有工具调用信息,添加到消息中
|
||||
if "tool_calls" in conv and conv["tool_calls"]:
|
||||
message["tool_calls"] = conv["tool_calls"]
|
||||
tool_calls = conv.get("tool_calls") or []
|
||||
if tool_calls and self._tool_calls_followed_by_tools(conversation, idx, tool_calls):
|
||||
message["tool_calls"] = tool_calls
|
||||
messages.append(message)
|
||||
|
||||
elif conv["role"] == "tool":
|
||||
|
||||
Loading…
Reference in New Issue
Block a user