# Download selected subsets of the COIG-CQIA dataset, extract their
# questions into JSON/plain-text files, and merge them for later use.
import os
import json
from datasets import load_dataset

# Directory where all extracted question files are written.
# Defaults to the original author's local path; override with the
# COIG_SAVE_DIR environment variable so the script is portable.
save_dir = os.environ.get(
    "COIG_SAVE_DIR",
    "/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题",
)
os.makedirs(save_dir, exist_ok=True)

print("开始下载COIG-CQIA数据集的指定子集...")
# 1. 下载SegmentFault子集
|
||
print("\n1. 下载SegmentFault子集...")
|
||
try:
|
||
segmentfault = load_dataset('m-a-p/COIG-CQIA', 'segmentfault', split='train')
|
||
|
||
segmentfault_questions = []
|
||
for idx, item in enumerate(segmentfault):
|
||
# COIG-CQIA格式通常包含instruction和output字段
|
||
question_entry = {
|
||
"id": f"segmentfault_{idx}",
|
||
"question": item.get('instruction', item.get('input', '')),
|
||
"answer": item.get('output', ''), # 保留答案供参考
|
||
"type": "technical_qa"
|
||
}
|
||
segmentfault_questions.append(question_entry)
|
||
|
||
# 保存SegmentFault数据
|
||
segmentfault_file = os.path.join(save_dir, "segmentfault_questions.json")
|
||
with open(segmentfault_file, 'w', encoding='utf-8') as f:
|
||
json.dump(segmentfault_questions, f, ensure_ascii=False, indent=2)
|
||
print(f"SegmentFault: 成功提取{len(segmentfault_questions)}个问题")
|
||
print(f"数据已保存至: {segmentfault_file}")
|
||
|
||
except Exception as e:
|
||
print(f"SegmentFault下载失败: {e}")
|
||
|
||
# 2. 下载human_value子集
|
||
print("\n2. 下载human_value子集...")
|
||
try:
|
||
human_value = load_dataset('m-a-p/COIG-CQIA', 'human_value', split='train')
|
||
|
||
human_value_questions = []
|
||
for idx, item in enumerate(human_value):
|
||
question_entry = {
|
||
"id": f"human_value_{idx}",
|
||
"question": item.get('instruction', item.get('input', '')),
|
||
"answer": item.get('output', ''), # 保留答案供参考
|
||
"type": "value_alignment"
|
||
}
|
||
human_value_questions.append(question_entry)
|
||
|
||
# 保存human_value数据
|
||
human_value_file = os.path.join(save_dir, "human_value_questions.json")
|
||
with open(human_value_file, 'w', encoding='utf-8') as f:
|
||
json.dump(human_value_questions, f, ensure_ascii=False, indent=2)
|
||
print(f"human_value: 成功提取{len(human_value_questions)}个问题")
|
||
print(f"数据已保存至: {human_value_file}")
|
||
|
||
except Exception as e:
|
||
print(f"human_value下载失败: {e}")
|
||
|
||
# 3. 合并所有下载的数据
|
||
print("\n3. 合并数据...")
|
||
all_coig_questions = []
|
||
|
||
for filename in ['segmentfault_questions.json', 'human_value_questions.json']:
|
||
filepath = os.path.join(save_dir, filename)
|
||
if os.path.exists(filepath):
|
||
with open(filepath, 'r', encoding='utf-8') as f:
|
||
data = json.load(f)
|
||
all_coig_questions.extend(data)
|
||
|
||
# 保存合并后的数据
|
||
if all_coig_questions:
|
||
merged_file = os.path.join(save_dir, "coig_cqia_selected.json")
|
||
with open(merged_file, 'w', encoding='utf-8') as f:
|
||
json.dump(all_coig_questions, f, ensure_ascii=False, indent=2)
|
||
|
||
# 只保存纯问题文本
|
||
questions_only_file = os.path.join(save_dir, "coig_cqia_questions_only.txt")
|
||
with open(questions_only_file, 'w', encoding='utf-8') as f:
|
||
for item in all_coig_questions:
|
||
f.write(item['question'] + '\n')
|
||
|
||
print(f"\n总计提取{len(all_coig_questions)}个问题")
|
||
print(f"合并数据已保存至: {merged_file}")
|
||
print(f"纯问题文本已保存至: {questions_only_file}")
|
||
|
||
# 4. 提示用户关于其他数据
|
||
print("\n" + "="*50)
|
||
print("下载完成!")
|
||
print("\n你提到的其他数据源:")
|
||
print("1. 知乎数据(更全版本)- 需要你自己添加")
|
||
print("2. 弱智吧数据(更全版本)- 需要你自己添加")
|
||
print("3. 100条日常对话 - 已经有了")
|
||
print("\n建议将所有数据整合成统一格式后,再调用R1 API生成推理数据。")
|
||
print("="*50) |