agent-Specialization/scripts/api_tool_role_experiment.py

197 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
使用现有“文件追加/修改”对话上下文对不同模型服务发起一次 Chat Completions 请求,
用于复现“单次工具调用对应多个 tool 消息”在不同 API 上的兼容性差异。
"""
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple
import httpx
# Conversation archive replayed when --conversation-file is not given.
# Relative path, so it is resolved against the current working directory.
DEFAULT_CONVERSATION: Path = Path("data/conversations/conv_20251009_161243_189.json")
# Default --output-dir: where main() asks dump_result() to write experiment records.
DEFAULT_OUTPUT_DIR: Path = Path("logs/api_experiment")
def convert_messages(raw_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert archived messages into OpenAI Chat Completions message dicts.

    Messages whose ``role`` is missing or falsy are dropped. Every kept message
    gets ``role`` and ``content`` (a missing or falsy content becomes ``""``).
    ``tool`` messages always carry ``tool_call_id``, even when it is ``None``.
    ``name`` and ``tool_calls`` are copied only when they are truthy. All other
    archive fields are discarded.

    Args:
        raw_messages: Message dicts as stored in the conversation archive.

    Returns:
        A new list of message dicts in the order they appeared in the input.
    """

    def to_chat_message(source: Dict[str, Any]) -> Dict[str, Any]:
        # Keys are inserted in a fixed order so the serialised payload is stable.
        kind = source.get("role")
        message: Dict[str, Any] = {"role": kind, "content": source.get("content") or ""}
        if kind == "tool":
            # Forwarded as-is, even when absent, so API differences are reproduced.
            message["tool_call_id"] = source.get("tool_call_id")
        speaker = source.get("name")
        if speaker:
            message["name"] = speaker
        calls = source.get("tool_calls")
        if calls:
            message["tool_calls"] = calls
        return message

    return [to_chat_message(item) for item in raw_messages if item.get("role")]
def load_conversation_messages(path: Path) -> List[Dict[str, Any]]:
    """Read a conversation archive and return its messages in Chat Completions form.

    The file is read as UTF-8 JSON. Its root must be an object with a
    ``messages`` list, which is passed through ``convert_messages``.

    Args:
        path: Path of the conversation archive JSON file.

    Returns:
        The converted message list.

    Raises:
        OSError: If the file cannot be read.
        ValueError: If the file is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass), if the JSON root is not an object, or if
            ``messages`` is missing or not a list.
    """
    data = json.loads(path.read_text(encoding="utf-8"))
    # A top-level array/scalar/null has no .get(); treat it as "no messages" so
    # callers see the documented ValueError instead of an AttributeError.
    raw_messages = data.get("messages") if isinstance(data, dict) else None
    if not isinstance(raw_messages, list):
        raise ValueError(f"{path} 中缺少 messages 数据")
    return convert_messages(raw_messages)
def minimal_tool_definitions() -> List[Dict[str, Any]]:
    """Return the minimal tool set that covers the append and modify flows.

    Two OpenAI ``function`` tools are produced: ``append_to_file`` (required
    ``path``, optional ``reason``) and ``modify_file`` (required ``path``).
    A fresh list of fresh dicts is built on every call.
    """

    def as_function_tool(
        name: str,
        description: str,
        properties: Dict[str, Any],
        required: List[str],
    ) -> Dict[str, Any]:
        # Wrap a JSON-schema object in the Chat Completions tool envelope.
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    append_tool = as_function_tool(
        "append_to_file",
        "准备向文件追加大段内容。调用后系统会发放 <<<APPEND:path>>>…<<<END_APPEND>>> "
        "格式的写入窗口AI 必须在窗口内一次性输出需要追加的全部内容。",
        {
            "path": {"type": "string", "description": "目标文件的相对路径"},
            "reason": {"type": "string", "description": "为什么需要追加(可选)"},
        },
        ["path"],
    )
    modify_tool = as_function_tool(
        "modify_file",
        "准备替换文件中的指定内容。模型必须在 <<<MODIFY:path>>>…<<<END_MODIFY>>> "
        "结构内输出若干 [replace:n] 补丁块,每块包含 <<OLD>> 原文 和 <<NEW>> 新内容。",
        {"path": {"type": "string", "description": "目标文件的相对路径"}},
        ["path"],
    )
    return [append_tool, modify_tool]
def send_request(
    api_base: str,
    api_key: str,
    model_id: str,
    messages: List[Dict[str, Any]],
    tools: List[Dict[str, Any]],
    timeout: float = 60.0
) -> Tuple[int, Dict[str, Any], str]:
    """Send one non-streaming Chat Completions request.

    The request is POSTed to ``<api_base>/chat/completions`` (trailing slashes
    on *api_base* are stripped) with ``tool_choice="auto"`` and a Bearer token.

    Returns:
        A tuple ``(status_code, parsed_body, raw_text)``. ``parsed_body`` is
        whatever ``response.json()`` yields, or ``{}`` when the body is not
        valid JSON. ``raw_text`` is the undecoded-as-JSON response text.

    Note:
        httpx transport errors (e.g. timeouts) are not caught here and
        propagate to the caller.
    """
    endpoint = f"{api_base.rstrip('/')}/chat/completions"
    request_body = dict(
        model=model_id,
        messages=messages,
        tools=tools,
        tool_choice="auto",
        stream=False,
    )
    auth_headers = dict()
    auth_headers["Authorization"] = f"Bearer {api_key}"
    auth_headers["Content-Type"] = "application/json"

    with httpx.Client(timeout=timeout) as http:
        reply = http.post(endpoint, json=request_body, headers=auth_headers)
        body_text = reply.text
        try:
            parsed = reply.json()
        except ValueError:
            # Non-JSON error pages are still captured via body_text.
            parsed = {}
        return reply.status_code, parsed, body_text
def dump_result(
    output_dir: Path,
    label: str,
    payload: Dict[str, Any],
    status_code: int,
    json_body: Dict[str, Any],
    raw_text: str
) -> Path:
    """Persist one experiment run as a pretty-printed UTF-8 JSON file.

    The file is named ``<label>_<YYYYmmdd_HHMMSS>.json`` inside *output_dir*,
    which is created if missing. Path separators in *label* are replaced with
    ``_`` in the filename only, so the record always lands directly in
    *output_dir*; the record itself keeps the unmodified label. If the target
    name already exists (e.g. two runs with the same label within one second),
    a ``_1``, ``_2``, ... suffix is added instead of overwriting the earlier
    record.

    Args:
        output_dir: Directory that receives the record file.
        label: Experiment label, used in the filename and stored in the record.
        payload: The request payload that was sent.
        status_code: HTTP status code of the response.
        json_body: Parsed response body.
        raw_text: Raw response text.

    Returns:
        Path of the newly written file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # A label such as "vendor/model" would otherwise point into a (usually
    # non-existent) sub-directory, or even outside output_dir via "..".
    safe_label = label.replace("/", "_").replace("\\", "_")
    record = {
        "label": label,
        "status_code": status_code,
        "request_payload": payload,
        "response_json": json_body,
        "response_text": raw_text
    }
    # Serialise first so a failure here never leaves an empty file behind.
    serialized = json.dumps(record, ensure_ascii=False, indent=2)
    stem = f"{safe_label}_{timestamp}"
    attempt = 0
    while True:
        name = stem if attempt == 0 else f"{stem}_{attempt}"
        filename = output_dir / f"{name}.json"
        try:
            # Exclusive-create mode: fails instead of clobbering an earlier record.
            with filename.open("x", encoding="utf-8") as handle:
                handle.write(serialized)
        except FileExistsError:
            attempt += 1
            continue
        return filename
def main() -> None:
    """Command-line entry point: send one request and save the result.

    Parses the CLI options, loads the archived conversation and sends it with
    the minimal append/modify tool set via ``send_request``. The request
    payload and response are written through ``dump_result``, and a summary
    plus a warning for HTTP status codes >= 400 are printed.
    """
    cli = argparse.ArgumentParser(description="对比不同 API 对工具消息结构的兼容性。")
    # (flag, add_argument keyword options) in the order shown by --help.
    cli_options = (
        ("--conversation-file", dict(type=Path, default=DEFAULT_CONVERSATION,
                                     help="使用的对话存档 JSON 文件路径")),
        ("--api-base", dict(required=True, help="API 基础地址,如 https://api.example.com/v1")),
        ("--api-key", dict(required=True, help="API Key")),
        ("--model-id", dict(required=True, help="模型 ID")),
        ("--label", dict(required=True, help="本次实验标签,用于输出文件命名")),
        ("--output-dir", dict(type=Path, default=DEFAULT_OUTPUT_DIR,
                              help="实验结果输出目录")),
        ("--timeout", dict(type=float, default=60.0, help="HTTP 请求超时时间(秒)")),
    )
    for flag, options in cli_options:
        cli.add_argument(flag, **options)
    opts = cli.parse_args()

    conversation = load_conversation_messages(opts.conversation_file)
    tool_specs = minimal_tool_definitions()
    # Mirrors the body send_request builds, so the saved record shows what was sent.
    request_payload = dict(
        model=opts.model_id,
        messages=conversation,
        tools=tool_specs,
        tool_choice="auto",
        stream=False,
    )

    print(f"📨 发送消息数: {len(conversation)},工具定义数: {len(tool_specs)}")
    print(f"➡️ 目标: {opts.api_base} / {opts.model_id} (label={opts.label})")

    status_code, json_body, raw_text = send_request(
        api_base=opts.api_base,
        api_key=opts.api_key,
        model_id=opts.model_id,
        messages=conversation,
        tools=tool_specs,
        timeout=opts.timeout,
    )
    saved_to = dump_result(
        output_dir=opts.output_dir,
        label=opts.label,
        payload=request_payload,
        status_code=status_code,
        json_body=json_body,
        raw_text=raw_text,
    )

    print(f"✅ HTTP {status_code},结果已保存: {saved_to}")
    if status_code >= 400:
        print("⚠️ 响应出现错误,请查看 response_json/response_text 获取详细信息。")
# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()