llm_learn/蒸馏/数据集/问题/extract_500_questions.py
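"""Extract 500 questions (100 per source) for generating non-thinking answers.

Summary derived from the code below: the script samples from five question
datasets (SegmentFault, Human Value, daily conversation, Zhihu, Ruozhiba),
keeps only the id/question/type fields, renumbers the items as no_think_<idx>,
and writes the merged JSON plus a plain-text list of the question texts.
"""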
import os
import json
import random
from typing import Dict, List, Optional

# Set a fixed random seed for reproducibility
random.seed(42)

# Base directory containing the question datasets
base_dir = "/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题"


def load_json_file(filepath: str) -> List[Dict]:
    """Load a JSON file."""
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)
def clean_and_sample(data: List[Dict], sample_size: Optional[int] = None, dataset_name: str = "") -> List[Dict]:
    """Keep only the id, question, and type fields, and optionally sample."""
    cleaned_data = []
    for item in data:
        cleaned_item = {
            "id": item.get("id"),
            "question": item.get("question"),
            "type": item.get("type")
        }
        cleaned_data.append(cleaned_item)

    # Sample only when requested and there are more items than needed
    if sample_size and len(cleaned_data) > sample_size:
        print(f"{dataset_name}: randomly sampling {sample_size} out of {len(cleaned_data)} items")
        cleaned_data = random.sample(cleaned_data, sample_size)
    else:
        actual_size = len(cleaned_data)
        note = ", fewer than requested, using all of them" if sample_size and actual_size < sample_size else ""
        print(f"{dataset_name}: {actual_size} items{note}")
    return cleaned_data
def main():
    print("Starting extraction of 500 questions for generating non-thinking answers...\n")
    print("Target: 100 questions per type, 500 in total\n")

    all_questions = []
    # 1. SegmentFault - 100 questions
    try:
        segmentfault_file = os.path.join(base_dir, "segmentfault_questions.json")
        segmentfault_data = load_json_file(segmentfault_file)
        segmentfault_cleaned = clean_and_sample(segmentfault_data, 100, "SegmentFault")
        all_questions.extend(segmentfault_cleaned)
    except Exception as e:
        print(f"Failed to process SegmentFault data: {e}")

    # 2. Human Value - 100 questions
    try:
        human_value_file = os.path.join(base_dir, "human_value_questions.json")
        human_value_data = load_json_file(human_value_file)
        human_value_cleaned = clean_and_sample(human_value_data, 100, "Human Value")
        all_questions.extend(human_value_cleaned)
    except Exception as e:
        print(f"Failed to process Human Value data: {e}")

    # 3. Daily conversation - 100 questions (use all if the source has fewer than 100)
    try:
        richang_file = os.path.join(base_dir, "richang.json")
        richang_data = load_json_file(richang_file)
        richang_cleaned = clean_and_sample(richang_data, 100, "Daily conversation")
        all_questions.extend(richang_cleaned)
    except Exception as e:
        print(f"Failed to process daily conversation data: {e}")

    # 4. Zhihu - 100 questions
    try:
        zhihu_file = os.path.join(base_dir, "coig_zhihu.json")
        zhihu_data = load_json_file(zhihu_file)
        zhihu_cleaned = clean_and_sample(zhihu_data, 100, "Zhihu")
        all_questions.extend(zhihu_cleaned)
    except Exception as e:
        print(f"Failed to process Zhihu data: {e}")

    # 5. Ruozhiba - 100 questions
    try:
        ruozhiba_file = os.path.join(base_dir, "ruozhiba_gpt4.json")
        ruozhiba_data = load_json_file(ruozhiba_file)
        ruozhiba_cleaned = clean_and_sample(ruozhiba_data, 100, "Ruozhiba")
        all_questions.extend(ruozhiba_cleaned)
    except Exception as e:
        print(f"Failed to process Ruozhiba data: {e}")
    # Renumber all questions
    print("\nRenumbering all questions...")
    for idx, item in enumerate(all_questions):
        original_id = item['id']
        item['original_id'] = original_id  # keep the original ID for reference
        item['id'] = f"no_think_{idx}"  # mark the item as non-thinking data

    # Save the merged data
    output_file = os.path.join(base_dir, "questions_for_no_thinking_500.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_questions, f, ensure_ascii=False, indent=2)

    # Collect per-type statistics
    type_stats = {}
    for item in all_questions:
        q_type = item.get('type', 'unknown')
        type_stats[q_type] = type_stats.get(q_type, 0) + 1

    # Save the question texts only (for quick inspection)
    questions_only_file = os.path.join(base_dir, "questions_no_thinking_only.txt")
    with open(questions_only_file, 'w', encoding='utf-8') as f:
        for item in all_questions:
            f.write(item['question'] + '\n')

    # Print the statistics
    print("\n" + "=" * 50)
    print("Extraction complete!")
    print(f"Total questions: {len(all_questions)}")
    print("\nDistribution by type:")
    for q_type, count in sorted(type_stats.items(), key=lambda x: x[1], reverse=True):
        print(f"  {q_type}: {count}")
    print("\nOutput files:")
    print(f"  Full data: {output_file}")
    print(f"  Question texts only: {questions_only_file}")
    print("\nNote: these questions will be used to generate answers WITHOUT a thinking process")
    print("=" * 50)
if __name__ == "__main__":
    main()
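
# For reference, each record in questions_for_no_thinking_500.json has the shape
# below; the values shown are placeholders, not real data from the datasets:
# {
#     "id": "no_think_0",
#     "question": "<original question text>",
#     "type": "<type field from the source dataset>",
#     "original_id": "<id from the source dataset>"
# }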