fix: sanitize historical tool call order
This commit is contained in:
parent
4652720c99
commit
20213f30ea
@ -4,7 +4,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Set
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -2064,6 +2064,30 @@ class MainTerminal:
|
|||||||
# 构建上下文
|
# 构建上下文
|
||||||
return self.context_manager.build_main_context(memory)
|
return self.context_manager.build_main_context(memory)
|
||||||
|
|
||||||
|
def _tool_calls_followed_by_tools(self, conversation: List[Dict], start_idx: int, tool_calls: List[Dict]) -> bool:
|
||||||
|
"""判断指定助手消息的工具调用是否拥有后续的工具响应。"""
|
||||||
|
if not tool_calls:
|
||||||
|
return False
|
||||||
|
expected_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
|
||||||
|
if not expected_ids:
|
||||||
|
return False
|
||||||
|
matched: Set[str] = set()
|
||||||
|
idx = start_idx + 1
|
||||||
|
total = len(conversation)
|
||||||
|
while idx < total and len(matched) < len(expected_ids):
|
||||||
|
next_conv = conversation[idx]
|
||||||
|
role = next_conv.get("role")
|
||||||
|
if role == "tool":
|
||||||
|
call_id = next_conv.get("tool_call_id")
|
||||||
|
if call_id in expected_ids:
|
||||||
|
matched.add(call_id)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
elif role in ("assistant", "user"):
|
||||||
|
break
|
||||||
|
idx += 1
|
||||||
|
return len(matched) == len(expected_ids)
|
||||||
|
|
||||||
def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
|
def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
|
||||||
"""构建消息列表(添加终端内容注入)"""
|
"""构建消息列表(添加终端内容注入)"""
|
||||||
# 加载系统提示
|
# 加载系统提示
|
||||||
@ -2097,7 +2121,8 @@ class MainTerminal:
|
|||||||
messages.append({"role": "system", "content": thinking_prompt})
|
messages.append({"role": "system", "content": thinking_prompt})
|
||||||
|
|
||||||
# 添加对话历史(保留完整结构,包括tool_calls和tool消息)
|
# 添加对话历史(保留完整结构,包括tool_calls和tool消息)
|
||||||
for conv in context["conversation"]:
|
conversation = context["conversation"]
|
||||||
|
for idx, conv in enumerate(conversation):
|
||||||
metadata = conv.get("metadata") or {}
|
metadata = conv.get("metadata") or {}
|
||||||
if conv["role"] == "assistant":
|
if conv["role"] == "assistant":
|
||||||
# Assistant消息可能包含工具调用
|
# Assistant消息可能包含工具调用
|
||||||
@ -2106,8 +2131,9 @@ class MainTerminal:
|
|||||||
"content": conv["content"]
|
"content": conv["content"]
|
||||||
}
|
}
|
||||||
# 如果有工具调用信息,添加到消息中
|
# 如果有工具调用信息,添加到消息中
|
||||||
if "tool_calls" in conv and conv["tool_calls"]:
|
tool_calls = conv.get("tool_calls") or []
|
||||||
message["tool_calls"] = conv["tool_calls"]
|
if tool_calls and self._tool_calls_followed_by_tools(conversation, idx, tool_calls):
|
||||||
|
message["tool_calls"] = tool_calls
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
|
||||||
elif conv["role"] == "tool":
|
elif conv["role"] == "tool":
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Set
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -2033,6 +2033,29 @@ class MainTerminal:
|
|||||||
# 构建上下文
|
# 构建上下文
|
||||||
return self.context_manager.build_main_context(memory)
|
return self.context_manager.build_main_context(memory)
|
||||||
|
|
||||||
|
def _tool_calls_followed_by_tools(self, conversation: List[Dict], start_idx: int, tool_calls: List[Dict]) -> bool:
|
||||||
|
if not tool_calls:
|
||||||
|
return False
|
||||||
|
expected_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
|
||||||
|
if not expected_ids:
|
||||||
|
return False
|
||||||
|
matched: Set[str] = set()
|
||||||
|
idx = start_idx + 1
|
||||||
|
total = len(conversation)
|
||||||
|
while idx < total and len(matched) < len(expected_ids):
|
||||||
|
next_conv = conversation[idx]
|
||||||
|
role = next_conv.get("role")
|
||||||
|
if role == "tool":
|
||||||
|
call_id = next_conv.get("tool_call_id")
|
||||||
|
if call_id in expected_ids:
|
||||||
|
matched.add(call_id)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
elif role in ("assistant", "user"):
|
||||||
|
break
|
||||||
|
idx += 1
|
||||||
|
return len(matched) == len(expected_ids)
|
||||||
|
|
||||||
def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
|
def build_messages(self, context: Dict, user_input: str) -> List[Dict]:
|
||||||
"""构建消息列表(添加终端内容注入)"""
|
"""构建消息列表(添加终端内容注入)"""
|
||||||
# 加载系统提示
|
# 加载系统提示
|
||||||
@ -2056,7 +2079,8 @@ class MainTerminal:
|
|||||||
messages.append({"role": "system", "content": todo_prompt})
|
messages.append({"role": "system", "content": todo_prompt})
|
||||||
|
|
||||||
# 添加对话历史(保留完整结构,包括tool_calls和tool消息)
|
# 添加对话历史(保留完整结构,包括tool_calls和tool消息)
|
||||||
for conv in context["conversation"]:
|
conversation = context["conversation"]
|
||||||
|
for idx, conv in enumerate(conversation):
|
||||||
metadata = conv.get("metadata") or {}
|
metadata = conv.get("metadata") or {}
|
||||||
if conv["role"] == "assistant":
|
if conv["role"] == "assistant":
|
||||||
# Assistant消息可能包含工具调用
|
# Assistant消息可能包含工具调用
|
||||||
@ -2065,8 +2089,9 @@ class MainTerminal:
|
|||||||
"content": conv["content"]
|
"content": conv["content"]
|
||||||
}
|
}
|
||||||
# 如果有工具调用信息,添加到消息中
|
# 如果有工具调用信息,添加到消息中
|
||||||
if "tool_calls" in conv and conv["tool_calls"]:
|
tool_calls = conv.get("tool_calls") or []
|
||||||
message["tool_calls"] = conv["tool_calls"]
|
if tool_calls and self._tool_calls_followed_by_tools(conversation, idx, tool_calls):
|
||||||
|
message["tool_calls"] = tool_calls
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
|
||||||
elif conv["role"] == "tool":
|
elif conv["role"] == "tool":
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user