llm_learn/蒸馏/数据集/问题/prepare_mixed_training_data.py

import json
import random
from pathlib import Path
from typing import List, Dict

# Fix the random seed so trigger assignment and shuffling are reproducible
random.seed(42)

# Base directory (original path kept as-is)
base_dir = Path("/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题")

# Input files
NO_THINKING_FILE = base_dir / "deepseek_v3_no_thinking_all_20250828_174824.json"
THINKING_FILE = base_dir / "deepseek_distill.json"

# Output files (no timestamp in the names)
OUTPUT_NO_THINKING = base_dir / "training_data_no_thinking.json"
OUTPUT_THINKING = base_dir / "training_data_with_thinking.json"
OUTPUT_MIXED = base_dir / "training_data_mixed_1500.json"

# Trigger phrases appended to "thinking" questions. These are Chinese
# variants of "please think before answering" and are functional training
# data, so they stay in Chinese.
TRIGGER_PHRASES = [
    "请思考后回答。",
    "请详细思考后回答。",
    "请仔细分析后回答。",
    "请深入思考一下。",
    "让我想想再回答。"
]


def process_no_thinking_data():
    """Convert the no-thinking data into the training format."""
    print("=" * 60)
    print("Processing no-thinking data...")

    # Load the raw data
    with open(NO_THINKING_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Loaded {len(data)} no-thinking records")

    # Convert to training format
    training_data = []
    for item in data:
        # Skip records whose answer contains an error marker
        if "ERROR" in item.get("answer", ""):
            print(f"Skipping errored record: {item.get('id', '<no id>')}")
            continue
        training_item = {
            "instruction": item["question"],  # question used verbatim, no trigger phrase
            "input": "",
            "output": item["answer"],         # direct answer, no <think> tags
            "history": []
        }
        training_data.append(training_item)

    # Save the converted data
    with open(OUTPUT_NO_THINKING, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, ensure_ascii=False, indent=2)
    print(f"✓ Processed {len(training_data)} no-thinking records")
    print(f"  Saved to: {OUTPUT_NO_THINKING}")

    # Verify that no output contains a <think> tag
    has_think = False
    for item in training_data:
        if "<think>" in item["output"]:
            has_think = True
            print("  ⚠ Warning: found an output containing a <think> tag")
            break
    if not has_think:
        print("  ✓ Verified: no output contains a <think> tag")
    return training_data


def process_thinking_data():
    """Add a trigger phrase to each thinking record."""
    print("=" * 60)
    print("Processing thinking data...")

    # Load the raw data
    with open(THINKING_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Loaded {len(data)} thinking records")

    # Append a trigger phrase to each question
    training_data = []
    for idx, item in enumerate(data):
        # Skip malformed records
        if not isinstance(item, dict):
            print(f"Skipping malformed record at index {idx}")
            continue

        # Pick a random trigger phrase
        trigger = random.choice(TRIGGER_PHRASES)

        # Extract the original question and answer
        if "instruction" in item:
            # Already in training format
            original_question = item["instruction"]
            original_answer = item.get("output", "")
        elif "question" in item:
            # DeepSeek response format
            original_question = item["question"]
            original_answer = item.get("answer", "")
        else:
            print(f"Skipping record with unrecognized format at index {idx}")
            continue

        # Append the trigger phrase after the question
        training_item = {
            "instruction": f"{original_question}\n{trigger}",
            "input": "",
            "output": original_answer,  # full answer including <think> tags
            "history": []
        }
        training_data.append(training_item)

    # Save the converted data
    with open(OUTPUT_THINKING, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, ensure_ascii=False, indent=2)
    print(f"✓ Processed {len(training_data)} thinking records")
    print(f"  Saved to: {OUTPUT_THINKING}")

    # Verify that every output contains a <think> tag
    no_think_count = sum(1 for item in training_data if "<think>" not in item["output"])
    if no_think_count == 0:
        print("  ✓ Verified: every output contains a <think> tag")
    else:
        print(f"  ⚠ Warning: {no_think_count} records have no <think> tag")

    # Report how often each trigger phrase was used
    trigger_stats = {}
    for item in training_data:
        for trigger in TRIGGER_PHRASES:
            if trigger in item["instruction"]:
                trigger_stats[trigger] = trigger_stats.get(trigger, 0) + 1
                break
    print("  Trigger phrase distribution:")
    for trigger, count in trigger_stats.items():
        print(f"    - {trigger}: {count}")
    return training_data


def merge_datasets(no_thinking_data: List[Dict], thinking_data: List[Dict]):
    """Merge the two datasets into one shuffled mixed dataset."""
    print("=" * 60)
    print("Building the mixed dataset...")

    # Tag each record with its source type (internal only, handy for analysis)
    all_data = []
    for item in no_thinking_data:
        item_copy = item.copy()
        item_copy["_type"] = "no_thinking"
        all_data.append(item_copy)
    for item in thinking_data:
        item_copy = item.copy()
        item_copy["_type"] = "thinking"
        all_data.append(item_copy)

    # Shuffle the combined data
    random.shuffle(all_data)

    # Strip the internal tag before saving
    for item in all_data:
        item.pop("_type", None)

    # Save the mixed dataset
    with open(OUTPUT_MIXED, 'w', encoding='utf-8') as f:
        json.dump(all_data, f, ensure_ascii=False, indent=2)
    print("✓ Mixed dataset created")
    print(f"  Total records: {len(all_data)}")
    print(f"  - no-thinking: {len(no_thinking_data)}")
    print(f"  - thinking: {len(thinking_data)}")
    print(f"  Saved to: {OUTPUT_MIXED}")
    return all_data


def analyze_dataset(data: List[Dict], dataset_name: str):
    """Print summary statistics for a dataset."""
    print(f"\n{dataset_name} statistics:")
    if not data:
        # Guard against division by zero below
        print("  (empty dataset)")
        return

    # Average lengths
    instruction_lengths = [len(item["instruction"]) for item in data]
    output_lengths = [len(item["output"]) for item in data]
    print(f"  Question length: avg {sum(instruction_lengths) / len(instruction_lengths):.0f} chars")
    print(f"  Answer length: avg {sum(output_lengths) / len(output_lengths):.0f} chars")

    # Records containing a <think> tag
    with_think = sum(1 for item in data if "<think>" in item["output"])
    print(f"  With <think> tag: {with_think} ({with_think / len(data) * 100:.1f}%)")

    # Records containing a trigger phrase
    with_trigger = sum(1 for item in data if any(t in item["instruction"] for t in TRIGGER_PHRASES))
    print(f"  With trigger phrase: {with_trigger} ({with_trigger / len(data) * 100:.1f}%)")


def main():
    print("Preparing the mixed training dataset")
    print("Target: 1,500 records (1,000 with thinking + 500 without)\n")

    # Process the no-thinking data
    no_thinking_data = process_no_thinking_data()

    # Process the thinking data
    thinking_data = process_thinking_data()

    # Merge the two datasets
    mixed_data = merge_datasets(no_thinking_data, thinking_data)

    # Print summary statistics
    print("\n" + "=" * 60)
    print("Dataset analysis")
    analyze_dataset(no_thinking_data, "No-thinking dataset")
    analyze_dataset(thinking_data, "Thinking dataset")
    analyze_dataset(mixed_data, "Mixed dataset")

    print("\n" + "=" * 60)
    print("Done!")
    print("\nGenerated files:")
    print(f"1. {OUTPUT_NO_THINKING.name} - 500 no-thinking training records")
    print(f"2. {OUTPUT_THINKING.name} - 1,000 thinking records with trigger phrases")
    print(f"3. {OUTPUT_MIXED.name} - 1,500 mixed training records (recommended)")
    print("\nNext steps:")
    print("1. Upload training_data_mixed_1500.json to AutoDL")
    print("2. Place it under /root/autodl-tmp/LLaMA-Factory/data/")
    print("3. Add the new dataset entry to dataset_info.json")
    print("4. Start fine-tuning with the new training config")


if __name__ == "__main__":
    main()
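
# For step 3 of the next steps, a plausible dataset_info.json entry using
# LLaMA-Factory's Alpaca-style column mapping might look like the sketch
# below. The key "mixed_1500" is an assumed name; adjust it and the file
# name to your setup.
#
#   "mixed_1500": {
#     "file_name": "training_data_mixed_1500.json",
#     "columns": {
#       "prompt": "instruction",
#       "query": "input",
#       "response": "output",
#       "history": "history"
#     }
#   }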