agent/scripts/stream_chunk_probe.py

121 lines
4.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
对指定模型服务发起一次流式请求,并记录每个 data chunk 的输出长度与时间间隔。
Usage:
python3 scripts/stream_chunk_probe.py \
--api-base https://api.moonshot.cn/v1 \
--api-key sk-xxx \
--model-id kimi-k2-0905-preview \
--prompt "帮我写一个Python脚本..."
"""
import argparse
import asyncio
import json
import time
from pathlib import Path
from typing import Any, Dict, List
import httpx
def build_messages(prompt: str) -> List[Dict[str, str]]:
    """Return a minimal chat history: a fixed system prompt, then *prompt* as the user turn.

    The system prompt asks the model for a long answer so that enough stream
    chunks are produced to observe their sizes and timing.
    """
    turns = (
        ("system", "你是一个友好的中文助手。请在回答时输出足够长的内容以便观察流式分片。"),
        ("user", prompt),
    )
    return [{"role": role, "content": content} for role, content in turns]
def _extract_delta(data: Any) -> Dict[str, Any]:
    """Return the ``delta`` dict of the first choice in a parsed stream chunk.

    Chunks without a usable delta yield an empty dict, so they are reported as
    "no-content" instead of crashing the probe. This covers non-dict payloads,
    missing/``null``/empty ``choices`` (e.g. usage-only chunks) and a ``null``
    delta.
    """
    if not isinstance(data, dict):
        return {}
    choices = data.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return {}
    delta = choices[0].get("delta")
    return delta if isinstance(delta, dict) else {}


def _summarize_delta(delta: Dict[str, Any]) -> str:
    """Describe which kinds of content (answer text, reasoning, tool calls) a delta carries."""
    parts: List[str] = []
    text_piece = delta.get("content") or ""
    reasoning = delta.get("reasoning_content")
    if text_piece:
        parts.append(f"text {len(text_piece)} chars")
    if reasoning:
        parts.append(f"think {len(reasoning)} chars")
    if delta.get("tool_calls"):
        parts.append("tool_calls")
    return ", ".join(parts) if parts else "no-content"


async def stream_once(api_base: str, api_key: str, model_id: str, prompt: str, timeout: float, max_chunks: int = 0) -> None:
    """Send one streaming chat-completions request and print per-chunk timing.

    Every SSE ``data:`` line (other than ``[DONE]``) counts as one chunk. For
    each chunk this prints:

    * its 1-based index;
    * the time since the previous chunk (for the first chunk, since just before
      the client was created);
    * a summary of what the delta carried.

    A final line reports the chunk count, the elapsed time up to the last chunk,
    and the accumulated number of answer-text (``content``) characters.

    Args:
        api_base: Base URL of the API; ``/chat/completions`` is appended.
        api_key: Sent as a Bearer token.
        model_id: Value of the request's ``model`` field.
        prompt: User message (a fixed system prompt is prepended).
        timeout: httpx client timeout in seconds.
        max_chunks: Stop after this many chunks; ``0`` (or any non-positive
            value) means no limit.

    If the server answers with a 4xx/5xx status, its body is printed and the
    function returns without reading a stream.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model_id,
        "messages": build_messages(prompt),
        "stream": True,
    }
    url = api_base.rstrip("/") + "/chat/completions"
    print(f"➡️ 发起流式请求: {url} ({model_id})")

    # perf_counter is monotonic, so intervals are immune to system clock changes.
    start = time.perf_counter()
    last = start
    chunk_index = 0
    text_chars = 0
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream("POST", url, json=payload, headers=headers) as response:
            print(f"HTTP {response.status_code}")
            if response.status_code >= 400:
                # Error bodies are plain JSON/text, not SSE; without this the loop
                # below would silently find zero chunks and hide the reason.
                await response.aread()
                print(f"❌ 请求失败: {response.text[:500]}")
                return
            async for raw_line in response.aiter_lines():
                # Skip blank keep-alive lines and non-data SSE fields (event:, id:, ...).
                if not raw_line.startswith("data:"):
                    continue
                data_part = raw_line[5:].strip()
                if data_part == "[DONE]":
                    break
                chunk_index += 1
                now = time.perf_counter()
                gap = now - last
                last = now
                try:
                    data = json.loads(data_part)
                except json.JSONDecodeError:
                    print(f"[{chunk_index:03d}] Δ{gap:.3f}s | 非JSON: {data_part[:80]}")
                else:
                    delta_obj = _extract_delta(data)
                    text_chars += len(delta_obj.get("content") or "")
                    print(f"[{chunk_index:03d}] Δ{gap:.3f}s | {_summarize_delta(delta_obj)}")
                # Checked for every chunk (JSON or not) so the limit is never overshot.
                if max_chunks > 0 and chunk_index >= max_chunks:
                    print(f"⚠️ 已达到 max_chunks={max_chunks},提前停止流式读取。")
                    break
    total_time = last - start
    print(f"✅ 流结束,共 {chunk_index} 个 chunk,用时 {total_time:.2f}s,累计正文字符 {text_chars}")
def main() -> None:
    """Command-line entry point: parse options and run a single streaming probe."""
    cli = argparse.ArgumentParser(description="采集流式输出 chunk 间隔。")
    # (flag, add_argument keyword arguments), registered in declaration order.
    option_specs = (
        ("--api-base", {"required": True, "help": "API 基础地址,例如 https://api.moonshot.cn/v1"}),
        ("--api-key", {"required": True, "help": "API Key"}),
        ("--model-id", {"required": True, "help": "模型 ID"}),
        ("--prompt", {"default": "请用中文详细说明流式输出测试,输出足够多的文字。", "help": "测试用 prompt"}),
        ("--timeout", {"type": float, "default": 120.0, "help": "HTTP 超时时间(秒)"}),
        ("--max-chunks", {"type": int, "default": 0, "help": "可选,限制最多采集的 chunk 数"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    probe = stream_once(
        api_base=opts.api_base,
        api_key=opts.api_key,
        model_id=opts.model_id,
        prompt=opts.prompt,
        timeout=opts.timeout,
        max_chunks=opts.max_chunks,
    )
    asyncio.run(probe)
if __name__ == "__main__":
    # Run the probe only when executed as a script, not when imported.
    main()