#!/usr/bin/env python3
"""
Send one streaming request to the specified model service and record the output
length and arrival interval of every data chunk.

Usage:
    python3 scripts/stream_chunk_probe.py \
        --api-base https://api.moonshot.cn/v1 \
        --api-key sk-xxx \
        --model-id kimi-k2-0905-preview \
        --prompt "Help me write a Python script..."
"""

import argparse
import asyncio
import json
import time
from typing import Dict, List

import httpx


def build_messages(prompt: str) -> List[Dict[str, str]]:
    """Build a minimal message list."""
    system_prompt = (
        "You are a friendly Chinese-language assistant. Please make your answers "
        "long enough that the streaming chunks can be observed."
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]


async def stream_once(
    api_base: str,
    api_key: str,
    model_id: str,
    prompt: str,
    timeout: float,
    max_chunks: int = 0,
) -> None:
    """Issue one streaming chat completion request and log per-chunk statistics."""
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model_id,
        "messages": build_messages(prompt),
        # Ask the server to stream the response as incremental chunks.
        "stream": True,
    }

    url = api_base.rstrip("/") + "/chat/completions"
    print(f"➡️ Starting streaming request: {url} ({model_id})")

    start_time = time.time()
    last_time = start_time
    chunk_index = 0
    total_chars = 0  # cumulative length of streamed assistant text, in characters
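    # Note (interpretation, not something the script computes separately): the Δ
    # printed for the first chunk approximates time-to-first-chunk, i.e. request
    # setup plus whatever the server does before emitting its first delta; later
    # Δs measure the gap between consecutive chunks.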
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream("POST", url, json=payload, headers=headers) as response:
            print(f"HTTP {response.status_code}")
            async for raw_line in response.aiter_lines():
                if not raw_line:
                    continue
                if not raw_line.startswith("data:"):
                    continue
                data_part = raw_line[5:].strip()
                if data_part == "[DONE]":
                    break

                chunk_index += 1
                now = time.time()
                delta = now - last_time
                last_time = now

                try:
                    data = json.loads(data_part)
                except json.JSONDecodeError:
                    print(f"[{chunk_index:03d}] Δ{delta:.3f}s | non-JSON: {data_part[:80]}")
                    continue

                choices = data.get("choices") or [{}]
                delta_obj = choices[0].get("delta", {})
                text_piece = delta_obj.get("content") or ""
                total_chars += len(text_piece)
                reasoning = delta_obj.get("reasoning_content")
                has_tool = bool(delta_obj.get("tool_calls"))
                summary = []
                if text_piece:
                    summary.append(f"text {len(text_piece)} chars")
                if reasoning:
                    summary.append(f"think {len(reasoning)} chars")
                if has_tool:
                    summary.append("tool_calls")
                if not summary:
                    summary.append("no-content")
                summary_text = ", ".join(summary)
                print(f"[{chunk_index:03d}] Δ{delta:.3f}s | {summary_text}")

                if max_chunks and chunk_index >= max_chunks:
                    print(f"⚠️ Reached max_chunks={max_chunks}, stopping the stream early.")
                    break

    total_time = last_time - start_time
    print(
        f"✅ Stream finished: {chunk_index} chunks in {total_time:.2f}s, "
        f"{total_chars} content characters in total."
    )


def main() -> None:
    parser = argparse.ArgumentParser(description="Measure the interval between streamed output chunks.")
    parser.add_argument("--api-base", required=True, help="API base URL, e.g. https://api.moonshot.cn/v1")
    parser.add_argument("--api-key", required=True, help="API key")
    parser.add_argument("--model-id", required=True, help="Model ID")
    parser.add_argument(
        "--prompt",
        default="Please explain streaming output testing in detail in Chinese, producing plenty of text.",
        help="Test prompt",
    )
    parser.add_argument("--timeout", type=float, default=120.0, help="HTTP timeout in seconds")
    parser.add_argument("--max-chunks", type=int, default=0, help="Optional cap on the number of chunks to collect (0 = no limit)")
    args = parser.parse_args()

    asyncio.run(stream_once(
        api_base=args.api_base,
        api_key=args.api_key,
        model_id=args.model_id,
        prompt=args.prompt,
        timeout=args.timeout,
        max_chunks=args.max_chunks,
    ))


if __name__ == "__main__":
    main()