llm_learn/蒸馏/数据集下载.py
2025-10-16 08:46:13 +08:00

104 lines
3.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
from datasets import load_dataset
# Output directory: every per-dataset JSON file and the merged results below
# are written underneath this path.
save_dir = "/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题"
# Create the directory (and any missing parents); an existing one is fine.
os.makedirs(name=save_dir, exist_ok=True)
# 1. Download C-Eval (Chinese comprehensive-ability benchmark).
#
# NOTE: the previous call passed 'val' as the dataset *config* name. Per the
# dataset card, ceval/ceval-exam exposes one config per subject (for example
# 'computer_network') and 'val' is a split inside each config, so that call
# is expected to be rejected and to end up in the except branch. This version
# walks the subject configs and reads each one's 'val' split until 1000
# questions have been collected.
print("开始下载C-Eval中文数据集...")
try:
    # Imported here so the block is self-contained; `datasets` is already a
    # dependency of this script.
    from datasets import get_dataset_config_names

    ceval_limit = 1000
    ceval_questions = []
    for ceval_subject in get_dataset_config_names('ceval/ceval-exam'):
        if len(ceval_questions) >= ceval_limit:
            break
        ceval = load_dataset('ceval/ceval-exam', ceval_subject, split='val')
        for item in ceval:
            if len(ceval_questions) >= ceval_limit:
                break
            ceval_questions.append({
                # Running index across all subjects, same "ceval_<n>" scheme.
                "id": f"ceval_{len(ceval_questions)}",
                "question": item['question'],
                # Use the config name when the row has no 'subject' field.
                "subject": item.get('subject', ceval_subject),
            })
    # Save the C-Eval questions.
    ceval_file = os.path.join(save_dir, "ceval_questions_1000.json")
    with open(ceval_file, 'w', encoding='utf-8') as f:
        json.dump(ceval_questions, f, ensure_ascii=False, indent=2)
    print(f"C-Eval: 成功提取{len(ceval_questions)}个问题")
except Exception as e:
    # Best-effort: report the failure and let the remaining datasets proceed.
    print(f"C-Eval下载失败: {e}")
# 2. Download CMMLU (Chinese multi-subject understanding benchmark).
#
# NOTE(review): the previous call requested an 'all' config. The dataset card
# for haonan-li/cmmlu documents one config per subject (for example
# 'agronomy') with 'dev' and 'test' splits and does not list an 'all' config;
# confirm against the card. This version walks the available configs and
# reads each one's 'test' split until 1000 questions have been collected,
# which works whether or not an aggregate config exists.
print("\n开始下载CMMLU中文数据集...")
try:
    # Imported here so the block is self-contained; `datasets` is already a
    # dependency of this script.
    from datasets import get_dataset_config_names

    cmmlu_limit = 1000
    cmmlu_questions = []
    for cmmlu_subject in get_dataset_config_names('haonan-li/cmmlu'):
        if len(cmmlu_questions) >= cmmlu_limit:
            break
        cmmlu = load_dataset('haonan-li/cmmlu', cmmlu_subject, split='test')
        for item in cmmlu:
            if len(cmmlu_questions) >= cmmlu_limit:
                break
            cmmlu_questions.append({
                # Running index across all subjects, same "cmmlu_<n>" scheme.
                "id": f"cmmlu_{len(cmmlu_questions)}",
                "question": item['Question'],
                # Use the config name when the row has no 'Subject' field.
                "subject": item.get('Subject', cmmlu_subject),
            })
    # Save the CMMLU questions.
    cmmlu_file = os.path.join(save_dir, "cmmlu_questions_1000.json")
    with open(cmmlu_file, 'w', encoding='utf-8') as f:
        json.dump(cmmlu_questions, f, ensure_ascii=False, indent=2)
    print(f"CMMLU: 成功提取{len(cmmlu_questions)}个问题")
except Exception as e:
    # Best-effort: report the failure and let the remaining datasets proceed.
    print(f"CMMLU下载失败: {e}")
# 3. Download BELLE instruction data (Chinese instruction-tuning corpus).
print("\n开始下载BELLE中文指令数据...")
try:
    belle = load_dataset('BelleGroup/train_0.5M_CN', split='train[:1000]')
    # The prompt text of each BELLE record is read from its 'instruction' key.
    belle_questions = [
        {
            "id": f"belle_{position}",
            "question": record['instruction'],
            "type": "instruction",
        }
        for position, record in enumerate(belle)
    ]
    # Persist the extracted BELLE questions as pretty-printed UTF-8 JSON.
    belle_file = os.path.join(save_dir, "belle_questions_1000.json")
    with open(belle_file, 'w', encoding='utf-8') as out:
        json.dump(belle_questions, out, ensure_ascii=False, indent=2)
    print(f"BELLE: 成功提取{len(belle_questions)}个问题")
except Exception as err:
    # Best-effort: report the failure and carry on to the merge step.
    print(f"BELLE下载失败: {err}")
# 4. Merge every per-dataset question file into a single collection.
print("\n合并所有中文问题...")
all_questions = []
# Read back whichever per-dataset files exist on disk; missing ones are skipped.
for filename in (
    'ceval_questions_1000.json',
    'cmmlu_questions_1000.json',
    'belle_questions_1000.json',
):
    filepath = os.path.join(save_dir, filename)
    if not os.path.exists(filepath):
        continue
    with open(filepath, 'r', encoding='utf-8') as src:
        all_questions += json.load(src)
# Write the merged outputs only when at least one question was collected.
if all_questions:
    merged_file = os.path.join(save_dir, "chinese_questions_all.json")
    with open(merged_file, 'w', encoding='utf-8') as dst:
        json.dump(all_questions, dst, ensure_ascii=False, indent=2)
    # Plain-text companion file: the question text of each entry followed by
    # a newline.
    questions_only_file = os.path.join(save_dir, "chinese_questions_only.txt")
    with open(questions_only_file, 'w', encoding='utf-8') as dst:
        dst.writelines(entry['question'] + '\n' for entry in all_questions)
    print(f"\n总计提取{len(all_questions)}个中文问题")
    print(f"合并数据已保存至: {merged_file}")
    print(f"纯问题文本已保存至: {questions_only_file}")
print("\n下载完成!")