import os
import json
import random
from typing import Dict, List, Optional

# Seed the RNG so sampling is reproducible
random.seed(42)

# Base directory holding the question datasets
base_dir = "/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题"


def load_json_file(filepath: str) -> List[Dict]:
    """Load a JSON file."""
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)


def clean_and_sample(data: List[Dict], sample_size: Optional[int] = None,
                     dataset_name: str = "") -> List[Dict]:
    """Keep only the id, question, and type fields, then optionally sample."""
    cleaned_data = [
        {
            "id": item.get("id"),
            "question": item.get("question"),
            "type": item.get("type"),
        }
        for item in data
    ]

    # Sample only when a size is given and the dataset exceeds it
    if sample_size and len(cleaned_data) > sample_size:
        print(f"{dataset_name}: sampling {sample_size} of {len(cleaned_data)} entries")
        cleaned_data = random.sample(cleaned_data, sample_size)
    else:
        print(f"{dataset_name}: keeping all {len(cleaned_data)} entries")

    return cleaned_data


def main():
    print("Merging datasets...\n")

    all_questions = []

    # (filename, sample size, display name); a sample size of None keeps everything
    datasets = [
        ("segmentfault_questions.json", 200, "SegmentFault"),
        ("human_value_questions.json", 200, "Human Value"),
        ("richang.json", None, "Daily Dialogue"),
        ("coig_zhihu.json", 200, "Zhihu"),
        ("ruozhiba_gpt4.json", 200, "Ruozhiba"),
    ]

    for filename, sample_size, name in datasets:
        try:
            data = load_json_file(os.path.join(base_dir, filename))
            all_questions.extend(clean_and_sample(data, sample_size, name))
        except Exception as e:
            print(f"Failed to process {name} data: {e}")

    # Renumber all questions, keeping each original ID for reference
    print("\nRenumbering all questions...")
    for idx, item in enumerate(all_questions):
        item['original_id'] = item['id']
        item['id'] = f"merged_{idx}"

    # Save the merged dataset
    output_file = os.path.join(base_dir, "questions_merged_final.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_questions, f, ensure_ascii=False, indent=2)

    # Tally how many questions each type contributed
    type_stats = {}
    for item in all_questions:
        q_type = item.get('type', 'unknown')
        type_stats[q_type] = type_stats.get(q_type, 0) + 1

    # Save the question text alone, one question per line
    # (missing questions become empty strings; embedded newlines are
    # flattened to spaces so each output line stays a single question)
    questions_only_file = os.path.join(base_dir, "questions_merged_only.txt")
    with open(questions_only_file, 'w', encoding='utf-8') as f:
        for item in all_questions:
            f.write((item['question'] or '').replace('\n', ' ') + '\n')

    # Print a summary
    print("\n" + "=" * 50)
    print("Dataset merge complete!")
    print(f"Total questions: {len(all_questions)}")
    print("\nDistribution by type:")
    for q_type, count in sorted(type_stats.items(), key=lambda x: x[1], reverse=True):
        print(f"  {q_type}: {count}")
    print("\nOutput files:")
    print(f"  Full data: {output_file}")
    print(f"  Question text only: {questions_only_file}")
    print("=" * 50)


if __name__ == "__main__":
    main()