# DeepSeek V3 batch question-answering script (distillation data generation).
import asyncio
import json
import os
import time
import traceback
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

import httpx
||
# --- API configuration ---
API_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
# SECURITY: prefer supplying the key via the ARK_API_KEY environment variable.
# The literal fallback preserves the original behavior, but credentials should
# not be committed to source control — rotate this key and rely on the env var.
API_KEY = os.environ.get("ARK_API_KEY", "3e96a682-919d-45c1-acb2-53bc4e9660d3")
MODEL_ID = "deepseek-v3-1-250821"

# --- File locations ---
INPUT_FILE = "/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题/questions_merged_final.json"
OUTPUT_DIR = Path("/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题")

# --- Concurrency / batching controls ---
BATCH_SIZE = 200        # questions per batch
TOTAL_BATCHES = 5       # number of batches to run
MAX_RETRIES = 3         # attempts per question before giving up
REQUEST_TIMEOUT = 300   # per-request timeout in seconds
||
class DeepSeekProcessor:
    """Batch driver that sends questions to the DeepSeek chat-completions API.

    Reads questions from INPUT_FILE, fans them out concurrently batch by
    batch, and writes per-batch plus aggregate JSON result files under
    OUTPUT_DIR. Failures never raise out of a question; they are encoded as
    "ERROR: ..." markers in the result's "answer" field.
    """

    # Maximum number of in-flight requests within one batch.
    MAX_CONCURRENT_REQUESTS = 50

    def __init__(self):
        # Common headers attached to every API request.
        self.headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }
        self.processed_count = 0   # questions answered successfully so far
        self.failed_items = []     # reserved for failure bookkeeping (not populated yet)
        self.start_time = time.time()

    def _make_result(self, question_id, question, answer: str) -> Dict:
        """Assemble the uniform result record used for every outcome (success or error)."""
        return {
            "id": question_id,
            "question": question,
            "answer": answer,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_ID
        }

    async def process_single_question(self, item: Dict, session: httpx.AsyncClient, semaphore: asyncio.Semaphore) -> Dict:
        """Send one question to the API, retrying transient failures.

        Always returns a result dict; errors are reported via an
        "ERROR: ..." prefix in the "answer" field rather than raising.
        """
        async with semaphore:
            question_id = item.get("id", "unknown")
            question = item.get("question", "")

            payload = {
                "model": MODEL_ID,
                "messages": [
                    {"role": "user", "content": question}
                ],
                "thinking": {"type": "enabled"},  # request the model's reasoning trace
                "stream": False  # plain (non-streaming) response
            }

            for retry in range(MAX_RETRIES):
                try:
                    response = await session.post(
                        f"{API_BASE_URL}/chat/completions",
                        json=payload,
                        headers=self.headers,
                        timeout=REQUEST_TIMEOUT
                    )

                    if response.status_code == 200:
                        formatted_answer = self.parse_response(response.json())
                        result = self._make_result(question_id, question, formatted_answer)
                        print(f"✓ 处理完成: {question_id} (第{self.processed_count + 1}个)")
                        self.processed_count += 1
                        return result

                    elif response.status_code in (429, 500, 502, 503, 504):
                        # Rate limiting / transient server errors: exponential backoff, then retry.
                        wait_time = 2 ** (retry + 1)
                        print(f"⚠ {question_id} 遇到错误 {response.status_code},等待 {wait_time}秒后重试...")
                        await asyncio.sleep(wait_time)

                    else:
                        # Permanent (non-retryable) HTTP error: give up immediately.
                        error_msg = f"HTTP {response.status_code}: {response.text[:200]}"
                        print(f"✗ {question_id} 请求失败: {error_msg}")
                        return self._make_result(question_id, question, f"ERROR: {error_msg}")

                # BUG FIX: httpx raises httpx.TimeoutException (whose subclasses do NOT
                # derive from asyncio.TimeoutError), so the original handler never fired
                # and timeouts fell through to the generic Exception branch.
                except (httpx.TimeoutException, asyncio.TimeoutError):
                    print(f"⚠ {question_id} 请求超时,重试 {retry + 1}/{MAX_RETRIES}")
                    if retry == MAX_RETRIES - 1:
                        return self._make_result(question_id, question, "ERROR: 请求超时")
                    # Back off before retrying (the original retried immediately).
                    await asyncio.sleep(2 ** (retry + 1))

                except Exception as e:
                    print(f"✗ {question_id} 发生异常: {str(e)}")
                    if retry == MAX_RETRIES - 1:
                        return self._make_result(question_id, question, f"ERROR: {str(e)}")
                    await asyncio.sleep(2 ** (retry + 1))

            # All retries were consumed by retryable HTTP errors.
            return self._make_result(question_id, question, "ERROR: 达到最大重试次数")

    def parse_response(self, data: Dict) -> str:
        """Extract the reasoning trace and final answer from an API response.

        Returns "<think>\\n{reasoning}\\n</think>\\n\\n{answer}" when reasoning
        content is present, the bare answer otherwise, or an "ERROR: ..."
        string when the payload cannot be parsed.
        """
        try:
            choices = data.get("choices", [])
            if not choices:
                return "ERROR: 无响应内容"

            message = choices[0].get("message", {})

            # DeepSeek V3 places the chain-of-thought in "reasoning_content".
            reasoning_content = message.get("reasoning_content", "")
            content = message.get("content", "")

            if reasoning_content:
                return f"<think>\n{reasoning_content}\n</think>\n\n{content}"
            # No reasoning trace: return the answer alone.
            return content

        except Exception as e:
            return f"ERROR: 解析响应失败 - {str(e)}"

    async def process_batch(self, batch: List[Dict], batch_num: int) -> List[Dict]:
        """Run one batch of questions concurrently and return the successful results."""
        print(f"\n{'='*50}")
        print(f"开始处理第 {batch_num} 批,共 {len(batch)} 个问题")
        print(f"{'='*50}")

        results = []
        # Cap concurrency so a 200-question batch does not open 200 sockets at once.
        semaphore = asyncio.Semaphore(self.MAX_CONCURRENT_REQUESTS)

        # NOTE(review): http2=True requires the 'h2' extra (httpx[http2]) — confirm
        # it is installed, otherwise AsyncClient construction raises ImportError.
        async with httpx.AsyncClient(http2=True, timeout=httpx.Timeout(REQUEST_TIMEOUT)) as session:
            tasks = [
                self.process_single_question(item, session, semaphore)
                for item in batch
            ]
            # return_exceptions=True keeps one failed task from cancelling the rest.
            batch_results = await asyncio.gather(*tasks, return_exceptions=True)

            for result in batch_results:
                if isinstance(result, Exception):
                    print(f"✗ 任务异常: {str(result)}")
                else:
                    results.append(result)

        print(f"第 {batch_num} 批处理完成,成功 {len(results)} 个")
        return results

    def save_results(self, results: List[Dict], batch_num: int):
        """Persist one batch's results to a timestamped JSON file; returns the path."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = OUTPUT_DIR / f"deepseek_v3_responses_batch{batch_num}_{timestamp}.json"

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"📁 批次 {batch_num} 结果已保存: {output_file}")
        return output_file

    async def run(self):
        """Main pipeline: load questions, process in batches, save and summarize."""
        print("开始处理任务...")
        print(f"输入文件: {INPUT_FILE}")
        print(f"输出目录: {OUTPUT_DIR}")
        print(f"模型: {MODEL_ID}")
        print(f"批次大小: {BATCH_SIZE}, 总批次: {TOTAL_BATCHES}")

        # Load the question list; abort cleanly if the file is unreadable.
        try:
            with open(INPUT_FILE, 'r', encoding='utf-8') as f:
                all_questions = json.load(f)
            print(f"✓ 成功加载 {len(all_questions)} 个问题")
        except Exception as e:
            print(f"✗ 读取文件失败: {e}")
            return

        # Make sure the output directory exists before the first save.
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

        all_results = []
        saved_files = []

        for batch_num in range(1, TOTAL_BATCHES + 1):
            start_idx = (batch_num - 1) * BATCH_SIZE
            end_idx = min(start_idx + BATCH_SIZE, len(all_questions))

            if start_idx >= len(all_questions):
                print("已处理完所有问题")
                break

            batch = all_questions[start_idx:end_idx]

            batch_start_time = time.time()
            batch_results = await self.process_batch(batch, batch_num)
            batch_time = time.time() - batch_start_time

            # Persist each batch immediately so a crash loses at most one batch.
            if batch_results:
                saved_file = self.save_results(batch_results, batch_num)
                saved_files.append(saved_file)
                all_results.extend(batch_results)

            print(f"批次 {batch_num} 耗时: {batch_time:.2f}秒")
            print(f"平均每个问题: {batch_time/len(batch):.2f}秒")

            # Brief pause between batches to ease pressure on the API.
            if batch_num < TOTAL_BATCHES:
                print("休息5秒后继续下一批...")
                await asyncio.sleep(5)

        # Write the aggregate file covering every batch.
        if all_results:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            final_output = OUTPUT_DIR / f"deepseek_v3_responses_all_{timestamp}.json"
            with open(final_output, 'w', encoding='utf-8') as f:
                json.dump(all_results, f, ensure_ascii=False, indent=2)
            print(f"\n📁 所有结果已汇总保存: {final_output}")

        # Final summary statistics.
        total_time = time.time() - self.start_time
        print(f"\n{'='*50}")
        print("处理完成统计:")
        print(f"  - 总处理数: {self.processed_count}")
        print(f"  - 失败数: {len([r for r in all_results if 'ERROR' in r.get('answer', '')])}")
        print(f"  - 总耗时: {total_time:.2f}秒")
        print(f"  - 平均耗时: {total_time/max(self.processed_count, 1):.2f}秒/问题")
        print(f"  - 保存文件数: {len(saved_files)}")
        print(f"{'='*50}")
||
async def main():
    """Entry point: build a processor instance and run the full pipeline."""
    await DeepSeekProcessor().run()


if __name__ == "__main__":
    # Launch the async entry point.
    asyncio.run(main())