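"""Prepare a mixed SFT dataset from DeepSeek distillation outputs.

Combines "thinking" records (answers wrapped in <think> tags, with a trigger
phrase appended to the question) and "no-thinking" records (plain answers)
into one shuffled file in the Alpaca-style instruction/input/output/history
format that LLaMA-Factory consumes.
"""
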
import json
import random
from pathlib import Path
from typing import List, Dict

# Set the random seed for reproducible shuffling and trigger selection
random.seed(42)

# Base directory for all input/output files
base_dir = Path("/Users/jojo/Desktop/软件所实习/微调和强化学习/蒸馏/数据集/问题")

# Input files
NO_THINKING_FILE = base_dir / "deepseek_v3_no_thinking_all_20250828_174824.json"
THINKING_FILE = base_dir / "deepseek_distill.json"

# Output files (no timestamp in the names)
OUTPUT_NO_THINKING = base_dir / "training_data_no_thinking.json"
OUTPUT_THINKING = base_dir / "training_data_with_thinking.json"
OUTPUT_MIXED = base_dir / "training_data_mixed_1500.json"

# Trigger phrases (in Chinese, e.g. "please think before answering")
# appended to the questions of thinking-mode records
TRIGGER_PHRASES = [
    "请思考后回答。",
    "请详细思考后回答。",
    "请仔细分析后回答。",
    "请深入思考一下。",
    "让我想想再回答。"
]
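
# A sketch of the record shapes this script assumes (inferred from the field
# accesses below rather than from any formal schema):
#   no-thinking input:  {"id": ..., "question": "...", "answer": "..."}
#   thinking input:     {"question": ..., "answer": ...} or, if already
#                       converted, {"instruction": ..., "output": ...}
#   training output:    {"instruction": ..., "input": "", "output": ..., "history": []}
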
def process_no_thinking_data():
    """Process the no-thinking data and convert it to the training format."""
    print("=" * 60)
    print("Processing no-thinking data...")

    # Load the data
    with open(NO_THINKING_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Loaded {len(data)} no-thinking records")

    # Convert to the training format
    training_data = []
    for item in data:
        # Skip records whose answer contains an error marker
        if "ERROR" in item.get("answer", ""):
            print(f"Skipping errored record: {item.get('id', 'unknown')}")
            continue

        training_item = {
            "instruction": item["question"],  # the question as-is, no trigger phrase
            "input": "",
            "output": item["answer"],  # direct answer, no <think> tags
            "history": []
        }
        training_data.append(training_item)

    # Save
    with open(OUTPUT_NO_THINKING, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, ensure_ascii=False, indent=2)

    print(f"✓ Processed {len(training_data)} no-thinking records")
    print(f"  Saved to: {OUTPUT_NO_THINKING}")

    # Verify that no output contains <think> tags
    has_think = False
    for item in training_data:
        if "<think>" in item["output"]:
            has_think = True
            print("  ⚠ Warning: found an output containing <think> tags")
            break

    if not has_think:
        print("  ✓ Verified: no output contains <think> tags")

    return training_data
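
# Design note (an assumption about intent, not stated in the data files):
# appending a trigger phrase only to thinking-mode questions gives the
# fine-tuned model a lexical cue for when to emit a <think> block, while
# untriggered questions map to direct answers.
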
def process_thinking_data():
    """Process the thinking data, appending a trigger phrase to each question."""
    print("=" * 60)
    print("Processing thinking data...")

    # Load the data
    with open(THINKING_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Loaded {len(data)} thinking records")

    # Append trigger phrases
    training_data = []
    for idx, item in enumerate(data):
        # Validate the record shape
        if not isinstance(item, dict):
            print(f"Skipping malformed record at index {idx}")
            continue

        # Pick a random trigger phrase
        trigger = random.choice(TRIGGER_PHRASES)

        # Extract the original question and answer
        if "instruction" in item:
            # Already in training format
            original_question = item["instruction"]
            original_answer = item.get("output", "")
        elif "question" in item:
            # Raw DeepSeek response format
            original_question = item["question"]
            original_answer = item.get("answer", "")
        else:
            print(f"Skipping record with unrecognized format at index {idx}")
            continue

        # Append the trigger phrase after the question
        training_item = {
            "instruction": f"{original_question}\n{trigger}",
            "input": "",
            "output": original_answer,  # full answer, including <think> tags
            "history": []
        }
        training_data.append(training_item)

    # Save
    with open(OUTPUT_THINKING, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, ensure_ascii=False, indent=2)

    print(f"✓ Processed {len(training_data)} thinking records")
    print(f"  Saved to: {OUTPUT_THINKING}")

    # Verify that every output contains <think> tags
    no_think_count = 0
    for item in training_data:
        if "<think>" not in item["output"]:
            no_think_count += 1

    if no_think_count == 0:
        print("  ✓ Verified: every output contains <think> tags")
    else:
        print(f"  ⚠ Warning: {no_think_count} records do not contain <think> tags")

    # Tally trigger-phrase usage
    trigger_stats = {}
    for item in training_data:
        for trigger in TRIGGER_PHRASES:
            if trigger in item["instruction"]:
                trigger_stats[trigger] = trigger_stats.get(trigger, 0) + 1
                break

    print("  Trigger phrase distribution:")
    for trigger, count in trigger_stats.items():
        print(f"    - {trigger}: {count} times")

    return training_data
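
# Interleaving both behaviors in a single shuffled file means fine-tuning
# sees thinking and no-thinking examples mixed together rather than in
# contiguous blocks, and the seeded shuffle keeps the mix reproducible.
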
def merge_datasets(no_thinking_data: List[Dict], thinking_data: List[Dict]):
    """Merge the two datasets into a shuffled mixed training set."""
    print("=" * 60)
    print("Building the mixed dataset...")

    # Concatenate the data
    all_data = []

    # Tag each record with its type (optional, for later analysis)
    for item in no_thinking_data:
        item_copy = item.copy()
        item_copy["_type"] = "no_thinking"
        all_data.append(item_copy)

    for item in thinking_data:
        item_copy = item.copy()
        item_copy["_type"] = "thinking"
        all_data.append(item_copy)

    # Shuffle
    random.shuffle(all_data)

    # Strip the internal tag (not needed in the final file)
    for item in all_data:
        if "_type" in item:
            del item["_type"]

    # Save the mixed dataset
    with open(OUTPUT_MIXED, 'w', encoding='utf-8') as f:
        json.dump(all_data, f, ensure_ascii=False, indent=2)

    print("✓ Mixed dataset created")
    print(f"  Total records: {len(all_data)}")
    print(f"  - no-thinking: {len(no_thinking_data)}")
    print(f"  - thinking: {len(thinking_data)}")
    print(f"  Saved to: {OUTPUT_MIXED}")

    return all_data
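
# Note: the length statistics below call len() on Python strings, so they
# count Unicode characters (one per CJK character), not model tokens.
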
def analyze_dataset(data: List[Dict], dataset_name: str):
    """Print summary statistics for a dataset."""
    print(f"\n{dataset_name} statistics:")

    # Guard against an empty dataset to avoid division by zero below
    if not data:
        print("  (empty dataset)")
        return

    # Average lengths
    instruction_lengths = [len(item["instruction"]) for item in data]
    output_lengths = [len(item["output"]) for item in data]

    print(f"  Question length: avg {sum(instruction_lengths)/len(instruction_lengths):.0f} chars")
    print(f"  Answer length: avg {sum(output_lengths)/len(output_lengths):.0f} chars")

    # Count records containing <think> tags
    with_think = sum(1 for item in data if "<think>" in item["output"])
    print(f"  Containing <think> tags: {with_think} ({with_think/len(data)*100:.1f}%)")

    # Count records containing a trigger phrase
    with_trigger = sum(1 for item in data if any(t in item["instruction"] for t in TRIGGER_PHRASES))
    print(f"  Containing a trigger phrase: {with_trigger} ({with_trigger/len(data)*100:.1f}%)")

def main():
    print("Preparing the mixed training dataset")
    print("Goal: 1500 records (1000 with thinking + 500 without)\n")

    # Process the no-thinking data
    no_thinking_data = process_no_thinking_data()

    # Process the thinking data
    thinking_data = process_thinking_data()

    # Merge the datasets
    mixed_data = merge_datasets(no_thinking_data, thinking_data)

    # Print summary statistics
    print("\n" + "=" * 60)
    print("Dataset analysis")
    analyze_dataset(no_thinking_data, "No-thinking dataset")
    analyze_dataset(thinking_data, "Thinking dataset")
    analyze_dataset(mixed_data, "Mixed dataset")

    print("\n" + "=" * 60)
    print("Done!")
    print("\nGenerated files:")
    # Report actual counts rather than hardcoded targets, since records
    # may be skipped during processing
    print(f"1. {OUTPUT_NO_THINKING.name} - {len(no_thinking_data)} no-thinking training records")
    print(f"2. {OUTPUT_THINKING.name} - {len(thinking_data)} thinking training records with trigger phrases")
    print(f"3. {OUTPUT_MIXED.name} - {len(mixed_data)} mixed training records (recommended)")
    print("\nNext steps:")
    print("1. Upload training_data_mixed_1500.json to AutoDL")
    print("2. Place it under /root/autodl-tmp/LLaMA-Factory/data/")
    print("3. Add the new dataset entry to dataset_info.json")
    print("4. Start fine-tuning with the new training config")


if __name__ == "__main__":
    main()
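
# Usage sketch (the script filename here is hypothetical; adjust to your setup):
#   python prepare_mixed_data.py
# A quick sanity check on the mixed file afterwards, run from the output
# directory, e.g.:
#   python -c "import json; print(len(json.load(open('training_data_mixed_1500.json', encoding='utf-8'))))"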