197 lines
6.7 KiB
Python
197 lines
6.7 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
使用现有“文件追加/修改”对话上下文对不同模型服务发起一次 Chat Completions 请求,
|
||
用于复现“单次工具调用对应多个 tool 消息”在不同 API 上的兼容性差异。
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Tuple
|
||
|
||
import httpx
|
||
|
||
|
||
DEFAULT_CONVERSATION = Path("data/conversations/conv_20251009_161243_189.json")
|
||
DEFAULT_OUTPUT_DIR = Path("logs/api_experiment")
|
||
|
||
|
||
def convert_messages(raw_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||
"""将存档中的消息转换为 OpenAI Chat Completions 兼容格式。"""
|
||
converted: List[Dict[str, Any]] = []
|
||
for msg in raw_messages:
|
||
role = msg.get("role")
|
||
if not role:
|
||
continue
|
||
entry: Dict[str, Any] = {
|
||
"role": role,
|
||
"content": msg.get("content", "") or ""
|
||
}
|
||
if role == "tool":
|
||
entry["tool_call_id"] = msg.get("tool_call_id")
|
||
if msg.get("name"):
|
||
entry["name"] = msg["name"]
|
||
if msg.get("tool_calls"):
|
||
entry["tool_calls"] = msg["tool_calls"]
|
||
converted.append(entry)
|
||
return converted
|
||
|
||
|
||
def load_conversation_messages(path: Path) -> List[Dict[str, Any]]:
|
||
"""读取对话文件并返回 messages 列表。"""
|
||
data = json.loads(path.read_text(encoding="utf-8"))
|
||
raw_messages = data.get("messages")
|
||
if not isinstance(raw_messages, list):
|
||
raise ValueError(f"{path} 中缺少 messages 数据")
|
||
return convert_messages(raw_messages)
|
||
|
||
|
||
def minimal_tool_definitions() -> List[Dict[str, Any]]:
|
||
"""返回涵盖 append/modify 的最小工具定义集合。"""
|
||
return [
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": "append_to_file",
|
||
"description": (
|
||
"准备向文件追加大段内容。调用后系统会发放 <<<APPEND:path>>>…<<<END_APPEND>>> "
|
||
"格式的写入窗口,AI 必须在窗口内一次性输出需要追加的全部内容。"
|
||
),
|
||
"parameters": {
|
||
"type": "object",
|
||
"properties": {
|
||
"path": {"type": "string", "description": "目标文件的相对路径"},
|
||
"reason": {"type": "string", "description": "为什么需要追加(可选)"}
|
||
},
|
||
"required": ["path"]
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": "modify_file",
|
||
"description": (
|
||
"准备替换文件中的指定内容。模型必须在 <<<MODIFY:path>>>…<<<END_MODIFY>>> "
|
||
"结构内输出若干 [replace:n] 补丁块,每块包含 <<OLD>> 原文 和 <<NEW>> 新内容。"
|
||
),
|
||
"parameters": {
|
||
"type": "object",
|
||
"properties": {
|
||
"path": {"type": "string", "description": "目标文件的相对路径"}
|
||
},
|
||
"required": ["path"]
|
||
}
|
||
}
|
||
}
|
||
]
|
||
|
||
|
||
def send_request(
|
||
api_base: str,
|
||
api_key: str,
|
||
model_id: str,
|
||
messages: List[Dict[str, Any]],
|
||
tools: List[Dict[str, Any]],
|
||
timeout: float = 60.0
|
||
) -> Tuple[int, Dict[str, Any], str]:
|
||
"""向指定 API 发送一次非流式请求,返回状态码、JSON/空字典、原始文本。"""
|
||
url = api_base.rstrip("/") + "/chat/completions"
|
||
payload = {
|
||
"model": model_id,
|
||
"messages": messages,
|
||
"tools": tools,
|
||
"tool_choice": "auto",
|
||
"stream": False
|
||
}
|
||
headers = {
|
||
"Authorization": f"Bearer {api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
with httpx.Client(timeout=timeout) as client:
|
||
response = client.post(url, json=payload, headers=headers)
|
||
text = response.text
|
||
try:
|
||
data = response.json()
|
||
except ValueError:
|
||
data = {}
|
||
return response.status_code, data, text
|
||
|
||
|
||
def dump_result(
|
||
output_dir: Path,
|
||
label: str,
|
||
payload: Dict[str, Any],
|
||
status_code: int,
|
||
json_body: Dict[str, Any],
|
||
raw_text: str
|
||
) -> Path:
|
||
"""将实验结果落盘,便于后续分析。"""
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
filename = output_dir / f"{label}_{timestamp}.json"
|
||
record = {
|
||
"label": label,
|
||
"status_code": status_code,
|
||
"request_payload": payload,
|
||
"response_json": json_body,
|
||
"response_text": raw_text
|
||
}
|
||
filename.write_text(json.dumps(record, ensure_ascii=False, indent=2), encoding="utf-8")
|
||
return filename
|
||
|
||
|
||
def main() -> None:
|
||
parser = argparse.ArgumentParser(description="对比不同 API 对工具消息结构的兼容性。")
|
||
parser.add_argument("--conversation-file", type=Path, default=DEFAULT_CONVERSATION,
|
||
help="使用的对话存档 JSON 文件路径")
|
||
parser.add_argument("--api-base", required=True, help="API 基础地址,如 https://api.example.com/v1")
|
||
parser.add_argument("--api-key", required=True, help="API Key")
|
||
parser.add_argument("--model-id", required=True, help="模型 ID")
|
||
parser.add_argument("--label", required=True, help="本次实验标签,用于输出文件命名")
|
||
parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR,
|
||
help="实验结果输出目录")
|
||
parser.add_argument("--timeout", type=float, default=60.0, help="HTTP 请求超时时间(秒)")
|
||
args = parser.parse_args()
|
||
|
||
messages = load_conversation_messages(args.conversation_file)
|
||
tools = minimal_tool_definitions()
|
||
payload = {
|
||
"model": args.model_id,
|
||
"messages": messages,
|
||
"tools": tools,
|
||
"tool_choice": "auto",
|
||
"stream": False
|
||
}
|
||
|
||
print(f"📨 发送消息数: {len(messages)},工具定义数: {len(tools)}")
|
||
print(f"➡️ 目标: {args.api_base} / {args.model_id} (label={args.label})")
|
||
|
||
status_code, json_body, raw_text = send_request(
|
||
api_base=args.api_base,
|
||
api_key=args.api_key,
|
||
model_id=args.model_id,
|
||
messages=messages,
|
||
tools=tools,
|
||
timeout=args.timeout
|
||
)
|
||
|
||
output_path = dump_result(
|
||
output_dir=args.output_dir,
|
||
label=args.label,
|
||
payload=payload,
|
||
status_code=status_code,
|
||
json_body=json_body,
|
||
raw_text=raw_text
|
||
)
|
||
|
||
print(f"✅ HTTP {status_code},结果已保存: {output_path}")
|
||
if status_code >= 400:
|
||
print("⚠️ 响应出现错误,请查看 response_json/response_text 获取详细信息。")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|