"""Moonshot Kimi 缓存命中测试脚本。 This script replays the three-turn conversation described in the QA instructions to inspect whether the `cached_tokens` field grows across requests. 结果会被写入 `logs/kimi_cache_usage.json`,方便后续排查缓存命中情况。 """ from __future__ import annotations import json from pathlib import Path from typing import Any, Dict, List from openai import OpenAI SYSTEM_PROMPT = ( "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。" "你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义," "种族歧视,黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。" ) QUERIES = ["你好!", "地球的自转周期是多少?", "月球呢?"] ENV_PATH = Path(".env") OUTPUT_PATH = Path("logs/kimi_cache_usage.json") def load_env_values(path: Path) -> Dict[str, str]: """Load simple KEY=VALUE pairs from `.env`.""" data: Dict[str, str] = {} for line in path.read_text(encoding="utf-8").splitlines(): stripped = line.strip() if not stripped or stripped.startswith("#") or "=" not in stripped: continue key, value = stripped.split("=", 1) data[key.strip()] = value.strip() return data def ensure_required(keys: List[str], env: Dict[str, str]) -> None: """Guard that required env variables exist.""" missing = [key for key in keys if not env.get(key)] if missing: raise RuntimeError(f"Missing required env vars: {', '.join(missing)}") def main() -> None: env = load_env_values(ENV_PATH) ensure_required(["AGENT_API_KEY", "AGENT_API_BASE_URL", "AGENT_MODEL_ID"], env) client = OpenAI(api_key=env["AGENT_API_KEY"], base_url=env["AGENT_API_BASE_URL"]) model = env["AGENT_MODEL_ID"] history: List[Dict[str, str]] = [{"role": "system", "content": SYSTEM_PROMPT}] records: List[Dict[str, Any]] = [] for idx, query in enumerate(QUERIES, start=1): history.append({"role": "user", "content": query}) completion = client.chat.completions.create( model=model, messages=history, temperature=0.6, ) answer = completion.choices[0].message.content history.append({"role": "assistant", "content": answer}) payload = completion.model_dump() usage_dict = payload.get("usage", {}) records.append( { "round": idx, "query": query, "response": answer, "usage": usage_dict, } ) usage_round1 = records[0]["usage"] usage_round2 = records[1]["usage"] usage_round3 = records[2]["usage"] analysis = { "round2_cached_equals_round1_total": usage_round2.get("cached_tokens") == usage_round1.get("total_tokens"), "round3_cached_equals_round2_total": usage_round3.get("cached_tokens") == usage_round2.get("total_tokens"), "cached_tokens_sequence": [ usage_round1.get("cached_tokens"), usage_round2.get("cached_tokens"), usage_round3.get("cached_tokens"), ], } OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True) OUTPUT_PATH.write_text( json.dumps({"details": records, "analysis": analysis}, ensure_ascii=False, indent=2), encoding="utf-8", ) print(f"Wrote results to {OUTPUT_PATH}") if __name__ == "__main__": main()