llm_learn/蒸馏/数据集/问题/deepseek_batch_processor.py
2025-10-16 08:46:13 +08:00

270 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import os
import time
import traceback
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

import httpx
# --- Configuration ---
API_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
# SECURITY NOTE: the API key was previously committed in plain text. Prefer the
# ARK_API_KEY environment variable; the hard-coded fallback only preserves the
# old behavior — rotate this key and remove the fallback as soon as possible.
API_KEY = os.environ.get("ARK_API_KEY", "3e96a682-919d-45c1-acb2-53bc4e9660d3")
MODEL_ID = "deepseek-v3-1-250821"
# File locations
INPUT_FILE = "/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题/questions_merged_final.json"
OUTPUT_DIR = Path("/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题")
# Concurrency / retry tuning
BATCH_SIZE = 200       # questions per batch
TOTAL_BATCHES = 5      # number of batches to run
MAX_RETRIES = 3        # attempts per question before giving up
REQUEST_TIMEOUT = 300  # per-request timeout, in seconds
class DeepSeekProcessor:
    """Concurrently sends questions to the DeepSeek chat-completions API and
    saves the formatted answers batch by batch.

    Relies on the module-level constants (API_KEY, MODEL_ID, BATCH_SIZE, ...)
    for all configuration.
    """

    def __init__(self):
        # Auth headers are identical for every request, so build them once.
        self.headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        }
        self.processed_count = 0   # number of successfully answered questions
        self.failed_items = []     # NOTE(review): written nowhere in this file — kept for interface compatibility
        self.start_time = time.time()

    def _make_result(self, question_id, question, answer: str) -> Dict:
        """Build one output record; `answer` may be an "ERROR: ..." marker."""
        return {
            "id": question_id,
            "question": question,
            "answer": answer,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_ID,
        }

    async def process_single_question(self, item: Dict, session: "httpx.AsyncClient",
                                      semaphore: asyncio.Semaphore) -> Dict:
        """Ask the model a single question, retrying transient failures.

        Never raises: on unrecoverable failure the returned record carries an
        "ERROR: ..." string in its "answer" field instead.
        """
        async with semaphore:
            question_id = item.get("id", "unknown")
            question = item.get("question", "")
            payload = {
                "model": MODEL_ID,
                "messages": [
                    {"role": "user", "content": question}
                ],
                "thinking": {"type": "enabled"},
                "stream": False,  # non-streaming: one complete JSON response
            }
            for retry in range(MAX_RETRIES):
                try:
                    response = await session.post(
                        f"{API_BASE_URL}/chat/completions",
                        json=payload,
                        headers=self.headers,
                        timeout=REQUEST_TIMEOUT,
                    )
                    if response.status_code == 200:
                        formatted_answer = self.parse_response(response.json())
                        result = self._make_result(question_id, question, formatted_answer)
                        print(f"✓ 处理完成: {question_id} (第{self.processed_count + 1}个)")
                        self.processed_count += 1
                        return result
                    elif response.status_code in (429, 500, 502, 503, 504):
                        # Transient server-side error: exponential backoff, then retry.
                        wait_time = 2 ** (retry + 1)
                        print(f"{question_id} 遇到错误 {response.status_code},等待 {wait_time}秒后重试...")
                        await asyncio.sleep(wait_time)
                    else:
                        # Non-retryable HTTP error: record it and give up immediately.
                        error_msg = f"HTTP {response.status_code}: {response.text[:200]}"
                        print(f"{question_id} 请求失败: {error_msg}")
                        return self._make_result(question_id, question, f"ERROR: {error_msg}")
                except (asyncio.TimeoutError, httpx.TimeoutException):
                    # BUG FIX: httpx raises httpx.TimeoutException (ReadTimeout etc.),
                    # not asyncio.TimeoutError — the old handler never fired.
                    print(f"{question_id} 请求超时,重试 {retry + 1}/{MAX_RETRIES}")
                    if retry == MAX_RETRIES - 1:
                        return self._make_result(question_id, question, "ERROR: 请求超时")
                    await asyncio.sleep(2 ** (retry + 1))  # back off before the next attempt
                except Exception as e:
                    print(f"{question_id} 发生异常: {str(e)}")
                    if retry == MAX_RETRIES - 1:
                        return self._make_result(question_id, question, f"ERROR: {str(e)}")
                    await asyncio.sleep(2 ** (retry + 1))
            # Loop exhausted without a definitive answer (e.g. repeated 429s).
            return self._make_result(question_id, question, "ERROR: 达到最大重试次数")

    def parse_response(self, data: Dict) -> str:
        """Extract the reasoning trace and final answer from a DeepSeek
        chat-completions payload, formatted as "<think>...</think>\\n\\nanswer".
        """
        try:
            choices = data.get("choices", [])
            if not choices:
                return "ERROR: 无响应内容"
            message = choices[0].get("message", {})
            # DeepSeek V3 puts the chain-of-thought in "reasoning_content".
            reasoning_content = message.get("reasoning_content", "")
            content = message.get("content", "")
            if reasoning_content:
                return f"<think>\n{reasoning_content}\n</think>\n\n{content}"
            # No reasoning trace: return the bare answer.
            return content
        except Exception as e:
            return f"ERROR: 解析响应失败 - {str(e)}"

    async def process_batch(self, batch: List[Dict], batch_num: int,
                            max_concurrency: int = 50) -> List[Dict]:
        """Process one batch of questions concurrently.

        `max_concurrency` caps the number of in-flight HTTP requests
        (previously a hard-coded 50; the default preserves that behavior).
        Returns only the successful/handled records; raised exceptions are
        logged and dropped.
        """
        print(f"\n{'='*50}")
        print(f"开始处理第 {batch_num} 批,共 {len(batch)} 个问题")
        print(f"{'='*50}")
        results = []
        semaphore = asyncio.Semaphore(max_concurrency)
        async with httpx.AsyncClient(http2=True, timeout=httpx.Timeout(REQUEST_TIMEOUT)) as session:
            tasks = [
                self.process_single_question(item, session, semaphore)
                for item in batch
            ]
            # return_exceptions=True so one failed task cannot cancel the rest.
            batch_results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in batch_results:
            if isinstance(result, Exception):
                print(f"✗ 任务异常: {str(result)}")
            else:
                results.append(result)
        print(f"{batch_num} 批处理完成,成功 {len(results)}")
        return results

    def save_results(self, results: List[Dict], batch_num: int):
        """Write one batch's records to a timestamped JSON file; returns the path."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = OUTPUT_DIR / f"deepseek_v3_responses_batch{batch_num}_{timestamp}.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"📁 批次 {batch_num} 结果已保存: {output_file}")
        return output_file

    async def run(self):
        """Load the question file, process it in batches, and persist results."""
        print(f"开始处理任务...")
        print(f"输入文件: {INPUT_FILE}")
        print(f"输出目录: {OUTPUT_DIR}")
        print(f"模型: {MODEL_ID}")
        print(f"批次大小: {BATCH_SIZE}, 总批次: {TOTAL_BATCHES}")
        # Load the question list; abort early if it cannot be read.
        try:
            with open(INPUT_FILE, 'r', encoding='utf-8') as f:
                all_questions = json.load(f)
            print(f"✓ 成功加载 {len(all_questions)} 个问题")
        except Exception as e:
            print(f"✗ 读取文件失败: {e}")
            return
        # Make sure the output directory exists before any batch is saved.
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        all_results = []
        saved_files = []
        for batch_num in range(1, TOTAL_BATCHES + 1):
            start_idx = (batch_num - 1) * BATCH_SIZE
            end_idx = min(start_idx + BATCH_SIZE, len(all_questions))
            if start_idx >= len(all_questions):
                print(f"已处理完所有问题")
                break
            batch = all_questions[start_idx:end_idx]
            batch_start_time = time.time()
            batch_results = await self.process_batch(batch, batch_num)
            batch_time = time.time() - batch_start_time
            # Persist each batch immediately so a crash loses at most one batch.
            if batch_results:
                saved_files.append(self.save_results(batch_results, batch_num))
                all_results.extend(batch_results)
            print(f"批次 {batch_num} 耗时: {batch_time:.2f}")
            print(f"平均每个问题: {batch_time/len(batch):.2f}")
            # Brief pause between batches to avoid hammering the API.
            if batch_num < TOTAL_BATCHES:
                print(f"休息5秒后继续下一批...")
                await asyncio.sleep(5)
        # Write a combined file containing every successful batch.
        if all_results:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            final_output = OUTPUT_DIR / f"deepseek_v3_responses_all_{timestamp}.json"
            with open(final_output, 'w', encoding='utf-8') as f:
                json.dump(all_results, f, ensure_ascii=False, indent=2)
            print(f"\n📁 所有结果已汇总保存: {final_output}")
        # Final summary statistics.
        total_time = time.time() - self.start_time
        print(f"\n{'='*50}")
        print(f"处理完成统计:")
        print(f" - 总处理数: {self.processed_count}")
        print(f" - 失败数: {len([r for r in all_results if 'ERROR' in r.get('answer', '')])}")
        print(f" - 总耗时: {total_time:.2f}")
        print(f" - 平均耗时: {total_time/max(self.processed_count, 1):.2f}秒/问题")
        print(f" - 保存文件数: {len(saved_files)}")
        print(f"{'='*50}")
async def main():
    """Async entry point: construct a processor and run the full pipeline."""
    await DeepSeekProcessor().run()
if __name__ == "__main__":
    # Run the async entry point on a fresh event loop.
    asyncio.run(main())