llm_learn/蒸馏/数据集/问题/merge_question_datasets.py
import os
import json
import random
from typing import Dict, List, Optional

# Set the random seed so sampling is reproducible
random.seed(42)

# Base directory for the question datasets (machine-specific path)
base_dir = "/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题"


def load_json_file(filepath: str) -> List[Dict]:
    """Load a JSON file."""
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)
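
# Note: each source file below is assumed to be a JSON array of objects
# carrying at least "id", "question", and "type" keys; the example values
# here are hypothetical, not taken from the actual datasets.
#   [{"id": "sf_001", "question": "...", "type": "tech_qa"}, ...]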


def clean_and_sample(data: List[Dict], sample_size: Optional[int] = None,
                     dataset_name: str = "") -> List[Dict]:
    """Keep only the id, question, and type fields, then optionally sample."""
    cleaned_data = []
    for item in data:
        cleaned_item = {
            "id": item.get("id"),
            "question": item.get("question"),
            "type": item.get("type")
        }
        cleaned_data.append(cleaned_item)
    # Sample only when a size is given and there is more data than requested
    if sample_size and len(cleaned_data) > sample_size:
        print(f"{dataset_name}: randomly sampling {sample_size} of {len(cleaned_data)} items")
        cleaned_data = random.sample(cleaned_data, sample_size)
    else:
        print(f"{dataset_name}: keeping all {len(cleaned_data)} items")
    return cleaned_data
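
# Illustrative call (hypothetical values): extra fields such as "answer"
# are dropped, and only id/question/type survive.
#   clean_and_sample([{"id": 1, "question": "...", "type": "qa", "answer": "..."}])
#   -> [{"id": 1, "question": "...", "type": "qa"}]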


def main():
    print("Merging question datasets...\n")
    all_questions = []

    # 1. SegmentFault - random sample of 200
    try:
        segmentfault_file = os.path.join(base_dir, "segmentfault_questions.json")
        segmentfault_data = load_json_file(segmentfault_file)
        segmentfault_cleaned = clean_and_sample(segmentfault_data, 200, "SegmentFault")
        all_questions.extend(segmentfault_cleaned)
    except Exception as e:
        print(f"Failed to process SegmentFault data: {e}")

    # 2. Human Value - random sample of 200
    try:
        human_value_file = os.path.join(base_dir, "human_value_questions.json")
        human_value_data = load_json_file(human_value_file)
        human_value_cleaned = clean_and_sample(human_value_data, 200, "Human Value")
        all_questions.extend(human_value_cleaned)
    except Exception as e:
        print(f"Failed to process Human Value data: {e}")

    # 3. Daily conversation (richang) - keep all items
    try:
        richang_file = os.path.join(base_dir, "richang.json")
        richang_data = load_json_file(richang_file)
        richang_cleaned = clean_and_sample(richang_data, None, "Daily Conversation")
        all_questions.extend(richang_cleaned)
    except Exception as e:
        print(f"Failed to process daily conversation data: {e}")

    # 4. Zhihu - random sample of 200
    try:
        zhihu_file = os.path.join(base_dir, "coig_zhihu.json")
        zhihu_data = load_json_file(zhihu_file)
        zhihu_cleaned = clean_and_sample(zhihu_data, 200, "Zhihu")
        all_questions.extend(zhihu_cleaned)
    except Exception as e:
        print(f"Failed to process Zhihu data: {e}")

    # 5. Ruozhiba - random sample of 200
    try:
        ruozhiba_file = os.path.join(base_dir, "ruozhiba_gpt4.json")
        ruozhiba_data = load_json_file(ruozhiba_file)
        ruozhiba_cleaned = clean_and_sample(ruozhiba_data, 200, "Ruozhiba")
        all_questions.extend(ruozhiba_cleaned)
    except Exception as e:
        print(f"Failed to process Ruozhiba data: {e}")

    # Renumber all questions sequentially
    print("\nRenumbering all questions...")
    for idx, item in enumerate(all_questions):
        item['original_id'] = item['id']  # keep the original ID for reference
        item['id'] = f"merged_{idx}"

    # Save the merged data
    output_file = os.path.join(base_dir, "questions_merged_final.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_questions, f, ensure_ascii=False, indent=2)
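
    # After renumbering, a merged record looks like this (values hypothetical):
    #   {"id": "merged_0", "original_id": "sf_001", "question": "...", "type": "qa"}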

    # Tally questions per type
    type_stats = {}
    for item in all_questions:
        q_type = item.get('type', 'unknown')
        type_stats[q_type] = type_stats.get(q_type, 0) + 1

    # Save the question text only, one question per line
    questions_only_file = os.path.join(base_dir, "questions_merged_only.txt")
    with open(questions_only_file, 'w', encoding='utf-8') as f:
        for item in all_questions:
            f.write(item['question'] + '\n')

    # Print the summary
    print("\n" + "=" * 50)
    print("Dataset merge complete!")
    print(f"Total questions: {len(all_questions)}")
    print("\nDistribution by type:")
    for q_type, count in sorted(type_stats.items(), key=lambda x: x[1], reverse=True):
        print(f"  {q_type}: {count}")
    print("\nOutput files:")
    print(f"  Full data: {output_file}")
    print(f"  Question text only: {questions_only_file}")
    print("=" * 50)


if __name__ == "__main__":
    main()
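
# To run (assuming the machine-specific base_dir above has been adjusted
# for your environment):
#   python merge_question_datasets.py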