llm_learn/蒸馏/数据集/问题/源文件/download_coig_cqia_subsets.py
2025-10-16 08:46:13 +08:00

97 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
from datasets import load_dataset
# Output directory for all downloaded question files.
# NOTE(review): hard-coded absolute path to one user's machine — adjust before running elsewhere.
save_dir = "/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题"
os.makedirs(save_dir, exist_ok=True)  # create the directory tree; no error if it already exists
print("开始下载COIG-CQIA数据集的指定子集...")
# 1. 下载SegmentFault子集
print("\n1. 下载SegmentFault子集...")
try:
segmentfault = load_dataset('m-a-p/COIG-CQIA', 'segmentfault', split='train')
segmentfault_questions = []
for idx, item in enumerate(segmentfault):
# COIG-CQIA格式通常包含instruction和output字段
question_entry = {
"id": f"segmentfault_{idx}",
"question": item.get('instruction', item.get('input', '')),
"answer": item.get('output', ''), # 保留答案供参考
"type": "technical_qa"
}
segmentfault_questions.append(question_entry)
# 保存SegmentFault数据
segmentfault_file = os.path.join(save_dir, "segmentfault_questions.json")
with open(segmentfault_file, 'w', encoding='utf-8') as f:
json.dump(segmentfault_questions, f, ensure_ascii=False, indent=2)
print(f"SegmentFault: 成功提取{len(segmentfault_questions)}个问题")
print(f"数据已保存至: {segmentfault_file}")
except Exception as e:
print(f"SegmentFault下载失败: {e}")
# 2. 下载human_value子集
print("\n2. 下载human_value子集...")
try:
human_value = load_dataset('m-a-p/COIG-CQIA', 'human_value', split='train')
human_value_questions = []
for idx, item in enumerate(human_value):
question_entry = {
"id": f"human_value_{idx}",
"question": item.get('instruction', item.get('input', '')),
"answer": item.get('output', ''), # 保留答案供参考
"type": "value_alignment"
}
human_value_questions.append(question_entry)
# 保存human_value数据
human_value_file = os.path.join(save_dir, "human_value_questions.json")
with open(human_value_file, 'w', encoding='utf-8') as f:
json.dump(human_value_questions, f, ensure_ascii=False, indent=2)
print(f"human_value: 成功提取{len(human_value_questions)}个问题")
print(f"数据已保存至: {human_value_file}")
except Exception as e:
print(f"human_value下载失败: {e}")
# 3. Merge whatever subset files were successfully downloaded above.
print("\n3. 合并数据...")
all_coig_questions = []
source_files = ('segmentfault_questions.json', 'human_value_questions.json')
for name in source_files:
    path = os.path.join(save_dir, name)
    if not os.path.exists(path):
        continue  # that subset failed to download; skip it
    with open(path, 'r', encoding='utf-8') as fh:
        all_coig_questions.extend(json.load(fh))

if all_coig_questions:
    # Combined JSON, keeping ids/answers/types.
    merged_file = os.path.join(save_dir, "coig_cqia_selected.json")
    with open(merged_file, 'w', encoding='utf-8') as fh:
        json.dump(all_coig_questions, fh, ensure_ascii=False, indent=2)
    # Plain-text dump: one question per line, nothing else.
    questions_only_file = os.path.join(save_dir, "coig_cqia_questions_only.txt")
    with open(questions_only_file, 'w', encoding='utf-8') as fh:
        fh.writelines(entry['question'] + '\n' for entry in all_coig_questions)
    print(f"\n总计提取{len(all_coig_questions)}个问题")
    print(f"合并数据已保存至: {merged_file}")
    print(f"纯问题文本已保存至: {questions_only_file}")
# 4. Tell the user which data sources still need manual work.
closing_lines = (
    "\n" + "=" * 50,
    "下载完成!",
    "\n你提到的其他数据源:",
    "1. 知乎数据(更全版本)- 需要你自己添加",
    "2. 弱智吧数据(更全版本)- 需要你自己添加",
    "3. 100条日常对话 - 已经有了",
    "\n建议将所有数据整合成统一格式后再调用R1 API生成推理数据。",
    "=" * 50,
)
for line in closing_lines:
    print(line)