# Paste header from the source file viewer: 124 lines, 4.5 KiB, Python.
import json
import os
import random
from collections import Counter
from typing import Dict, List, Optional

# Fix the global random seed so the sampled subsets are reproducible across runs.
random.seed(42)

# Directory that holds the source question files and receives the merged output.
# Defaults to the original author's local path; set QUESTIONS_BASE_DIR to run
# the script on another machine without editing the code.
base_dir = os.environ.get(
    "QUESTIONS_BASE_DIR",
    "/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题",
)

def load_json_file(filepath: str) -> List[Dict]:
    """Read *filepath* as UTF-8 text and return its parsed JSON content.

    The annotation assumes the file holds a list of objects; the parsed value
    is returned exactly as ``json.load`` produced it, without validation.
    """
    with open(filepath, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed

def clean_and_sample(
    data: List[Dict],
    sample_size: Optional[int] = None,
    dataset_name: str = "",
    rng: Optional[random.Random] = None,
) -> List[Dict]:
    """Keep only ``id``/``question``/``type`` of each record and optionally subsample.

    Args:
        data: Source records; each must be dict-like (supports ``.get``).
        sample_size: Maximum number of records to keep. ``None`` keeps every
            record. When there are more records than this, a random subset of
            exactly ``sample_size`` records is drawn without replacement.
        dataset_name: Label used in the progress message.
        rng: Random generator used for sampling. Defaults to the module-level
            ``random`` functions, so results follow the global seed.

    Returns:
        A new list of cleaned records; absent fields are set to ``None``.

    Raises:
        ValueError: If ``sample_size`` is negative.
    """
    if sample_size is not None and sample_size < 0:
        raise ValueError(f"sample_size must be non-negative, got {sample_size}")

    cleaned_data = [
        {
            "id": item.get("id"),
            "question": item.get("question"),
            "type": item.get("type"),
        }
        for item in data
    ]

    # Compare against None explicitly: sample_size=0 means "keep zero records",
    # not "keep everything" as a plain truthiness test would imply.
    if sample_size is not None and len(cleaned_data) > sample_size:
        print(f"{dataset_name}: 从{len(cleaned_data)}条中随机采样{sample_size}条")
        sampler = random if rng is None else rng
        cleaned_data = sampler.sample(cleaned_data, sample_size)
    else:
        print(f"{dataset_name}: 保留全部{len(cleaned_data)}条数据")

    return cleaned_data

def _load_dataset(filename: str, sample_size: Optional[int], dataset_name: str) -> List[Dict]:
    """Load one source file from ``base_dir`` and clean/sample it; return [] on failure."""
    try:
        data = load_json_file(os.path.join(base_dir, filename))
        return clean_and_sample(data, sample_size, dataset_name)
    except Exception as e:  # best-effort: one broken source must not abort the whole merge
        print(f"处理{dataset_name}数据失败: {e}")
        return []


def _renumber(questions: List[Dict]) -> None:
    """Give every record a global ``merged_<index>`` id, keeping the old one as ``original_id``."""
    for idx, item in enumerate(questions):
        item["original_id"] = item["id"]  # keep the source id for traceability
        item["id"] = f"merged_{idx}"


def _count_types(questions: List[Dict]) -> Counter:
    """Count records per ``type``; a missing or ``None`` type is counted as ``'unknown'``."""
    # Cleaned records always carry a "type" key (possibly None), so a .get()
    # default would never apply; map None explicitly instead.
    return Counter(
        "unknown" if item.get("type") is None else item["type"] for item in questions
    )


def _question_line(item: Dict) -> str:
    """Render one record for the plain-text file, terminated by a newline.

    A missing question becomes an empty line so line N still corresponds to
    ``merged_N``; non-string values are converted with ``str``.
    """
    # NOTE(review): a question containing its own newlines still spans several
    # lines and breaks the line-to-id alignment; decide whether to escape them.
    question = item.get("question")
    text = "" if question is None else str(question)
    return text + "\n"


def main():
    """Merge the source question sets into one renumbered dataset.

    Each source file is read from ``base_dir``, trimmed to id/question/type and
    randomly sampled to 200 records (the daily-conversation set is kept whole).
    A source that fails to load is reported and skipped. Writes
    ``questions_merged_final.json`` (full records) and
    ``questions_merged_only.txt`` (one question per line, same order) into
    ``base_dir`` and prints a per-type summary.
    """
    print("开始整合数据集...\n")

    # (label used in messages, file name under base_dir, sample size; None keeps all).
    # The order fixes both the merged numbering and the random-sampling sequence.
    sources = (
        ("SegmentFault", "segmentfault_questions.json", 200),
        ("Human Value", "human_value_questions.json", 200),
        ("日常对话", "richang.json", None),
        ("知乎", "coig_zhihu.json", 200),
        ("弱智吧", "ruozhiba_gpt4.json", 200),
    )

    all_questions = []
    for dataset_name, filename, sample_size in sources:
        all_questions.extend(_load_dataset(filename, sample_size, dataset_name))

    print("\n重新编号所有问题...")
    _renumber(all_questions)

    output_file = os.path.join(base_dir, "questions_merged_final.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_questions, f, ensure_ascii=False, indent=2)

    type_stats = _count_types(all_questions)

    questions_only_file = os.path.join(base_dir, "questions_merged_only.txt")
    with open(questions_only_file, "w", encoding="utf-8") as f:
        f.writelines(_question_line(item) for item in all_questions)

    print("\n" + "=" * 50)
    print("数据整合完成!")
    print(f"总计问题数:{len(all_questions)}")
    print("\n各类型分布:")
    # most_common() orders by count descending, ties kept in first-seen order.
    for q_type, count in type_stats.most_common():
        print(f" {q_type}: {count}条")
    print("\n输出文件:")
    print(f" 完整数据:{output_file}")
    print(f" 纯问题文本:{questions_only_file}")
    print("=" * 50)

if __name__ == "__main__":
    # Run the merge only when executed as a script, never on import.
    main()