[A Close Look at open_r1] A Deep Dive into the Scripts for Training and Evaluating Models and Generating Synthetic Data (src/open_r1)
The src/open_r1 directory mainly contains Python scripts for training and evaluating models and for generating synthetic data. Below we take a deep look at the main Python files.
configs.py
This file defines two dataclasses, GRPOConfig and SFTConfig, which inherit from trl.GRPOConfig and trl.SFTConfig respectively and add extra arguments for callbacks, benchmarks, and so on.
# Import the required modules
from dataclasses import dataclass, field
from typing import Optional

import trl


# GRPOConfig extends trl.GRPOConfig
@dataclass
class GRPOConfig(trl.GRPOConfig):
    """
    args for callbacks, benchmarks etc
    """

    # Benchmarks to run after training; defaults to an empty list
    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    # Callbacks to run during training; defaults to an empty list
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    # Optional system prompt used for benchmarking; defaults to None
    system_prompt: Optional[str] = field(
        default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
    )
    # Hub model branch to push the model to; defaults to "main"
    hub_model_revision: Optional[str] = field(
        default="main", metadata={"help": "The Hub model branch to push the model to."}
    )
    # Whether to overwrite the Hub revision; defaults to False
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    # Whether to push to a Hub revision/branch; defaults to False
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    # WandB entity to store runs under; defaults to None
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    # WandB project to store runs under; defaults to None
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )


# SFTConfig extends trl.SFTConfig
@dataclass
class SFTConfig(trl.SFTConfig):
    """
    args for callbacks, benchmarks etc
    """

    # Benchmarks to run after training; defaults to an empty list
    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    # Callbacks to run during training; defaults to an empty list
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    # Optional system prompt used for benchmarking; defaults to None
    system_prompt: Optional[str] = field(
        default=None,
        metadata={"help": "The optional system prompt to use for benchmarking."},
    )
    # Hub model branch to push the model to; defaults to "main"
    hub_model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The Hub model branch to push the model to."},
    )
    # Whether to overwrite the Hub revision; defaults to False
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    # Whether to push to a Hub revision/branch; defaults to False
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    # WandB entity to store runs under; defaults to None
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    # WandB project to store runs under; defaults to None
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )
sft.py
This file is a supervised fine-tuning (SFT) script for fine-tuning decoder language models.
1 Import the required modules
import logging
import os
import sys
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import SFTConfig
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.logging import init_wandb_training
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
2 Configure logging
logger = logging.getLogger(__name__)


def main(script_args, training_args, model_args):
    # Set the random seed for reproducibility
    set_seed(training_args.seed)

    # Configure logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a brief summary on each process
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Script parameters {script_args}")
    logger.info(f"Training parameters {training_args}")
3 Check for the last checkpoint
# Check for the last checkpoint
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
    logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

# Initialize WandB training if reporting to WandB
if "wandb" in training_args.report_to:
    init_wandb_training(training_args)
4 Load the dataset and tokenizer
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
tokenizer.pad_token = tokenizer.eos_token
5 Initialize the model kwargs
# Initialize the model kwargs
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
    model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    attn_implementation=model_args.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
6 Initialize the SFT Trainer
# Initialize the SFT Trainer
trainer = SFTTrainer(
    model=model_args.model_name_or_path,
    args=training_args,
    train_dataset=dataset[script_args.dataset_train_split],
    eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
    processing_class=tokenizer,
    peft_config=get_peft_config(model_args),
    callbacks=get_callbacks(training_args, model_args),
)
7 Training loop
# Training loop
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
8 Save the model and create a model card
# Save the model and create a model card
logger.info("*** Save model ***")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")

# Save everything else on the main process
kwargs = {
    "dataset_name": script_args.dataset_name,
    "tags": ["open-r1"],
}
if trainer.accelerator.is_main_process:
    trainer.create_model_card(**kwargs)
    # Restore the k,v cache for fast inference
    trainer.model.config.use_cache = True
    trainer.model.config.save_pretrained(training_args.output_dir)
9 Evaluate the model
# Evaluate the model
if training_args.do_eval:
    logger.info("*** Evaluate ***")
    metrics = trainer.evaluate()
    metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
10 Push to the Hub
# Push to the Hub
if training_args.push_to_hub:
    logger.info("Pushing to hub...")
    trainer.push_to_hub(**kwargs)


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
evaluate.py
This file defines custom LightEval evaluation tasks. Here is a deep dive into it:
1. Module imports and constant definitions
import random

from lighteval.metrics.dynamic_metrics import (
    ExprExtractionConfig,
    IndicesExtractionConfig,
    LatexExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
- The random module is imported to generate random numbers (used below to shuffle the GPQA answer choices).
- The lighteval modules provide the configuration classes and functions used to define the evaluation tasks and metrics.
latex_gold_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    fallback_mode="first_match",
    precision=5,
    gold_extraction_target=(LatexExtractionConfig(),),
    # Match boxed first before trying other regexes
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
    aggregation_function=max,
)
expr_gold_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    fallback_mode="first_match",
    precision=5,
    gold_extraction_target=(ExprExtractionConfig(),),
    # Match boxed first before trying other regexes
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
    aggregation_function=max,
)
gpqa_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
    pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
    precision=5,
)
- Three evaluation metrics are defined: latex_gold_metric, expr_gold_metric, and gpqa_metric. They are used below when configuring the evaluation tasks.
2. Prompt function definitions
def prompt_fn(line, task_name: str = None):
    """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
    return Doc(
        task_name=task_name,
        query=line["problem"],
        choices=[line["solution"]],
        gold_index=0,
    )


def aime_prompt_fn(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["problem"],
        choices=[line["answer"]],
        gold_index=0,
    )


def gpqa_prompt_fn(line, task_name: str = None):
    """Prompt template adapted from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14"""
    gold_index = random.randint(0, 3)
    choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]]
    choices.insert(gold_index, line["Correct Answer"])
    query_template = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}"
    query = query_template.format(A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"])

    return Doc(
        task_name=task_name,
        query=query,
        choices=["A", "B", "C", "D"],
        gold_index=gold_index,
        instruction=query,
    )
- Three prompt functions are defined: prompt_fn, aime_prompt_fn, and gpqa_prompt_fn, which build the prompts for the different tasks (a quick sketch of gpqa_prompt_fn follows below).
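To make the answer shuffling concrete, here is a minimal sketch of calling gpqa_prompt_fn on a made-up record; the field values are invented, but the column names mirror those read by the code above:
# Minimal sketch with a made-up GPQA-style record
line = {
    "Question": "Which planet is known as the Red Planet?",
    "Correct Answer": "Mars",
    "Incorrect Answer 1": "Venus",
    "Incorrect Answer 2": "Jupiter",
    "Incorrect Answer 3": "Mercury",
}
doc = gpqa_prompt_fn(line, task_name="custom|gpqa:diamond|0")
print(doc.gold_index)  # the random position (0-3) where "Mars" was inserted
print(doc.choices)     # ["A", "B", "C", "D"]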
3. Evaluation task configurations
aime24 = LightevalTaskConfig(
    name="aime24",
    suite=["custom"],
    prompt_function=aime_prompt_fn,
    hf_repo="HuggingFaceH4/aime_2024",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[expr_gold_metric],
    version=1,
)
# Part I from AIME 2025 exam: https://artofproblemsolving.com/wiki/index.php/2025_AIME_I?srsltid=AfmBOoof5gaaqlt3-l6LH7Tt6qmJZtl_2PQEDYlLFlMqhq9dLL8FMCRR
aime25_part1 = LightevalTaskConfig(
    name="aime25:part1",
    suite=["custom"],
    prompt_function=aime_prompt_fn,
    hf_repo="open-r1/aime_2025_1",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[expr_gold_metric],
    version=1,
)
math_500 = LightevalTaskConfig(
    name="math_500",
    suite=["custom"],
    prompt_function=prompt_fn,
    hf_repo="HuggingFaceH4/MATH-500",
    hf_subset="default",
    hf_avail_splits=["test"],
    evaluation_splits=["test"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[latex_gold_metric],
    version=1,
)
gpqa_diamond = LightevalTaskConfig(
    name="gpqa:diamond",
    suite=["custom"],
    prompt_function=gpqa_prompt_fn,
    hf_repo="Idavidrein/gpqa",
    hf_subset="gpqa_diamond",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,  # needed for reasoning models like R1
    metric=[gpqa_metric],
    stop_sequence=[],  # no stop sequence, will use eos token
    trust_dataset=True,
    version=1,
)
- Four evaluation tasks are defined: aime24, aime25_part1, math_500, and gpqa_diamond. Each is configured with a name, suite, prompt function, dataset information, evaluation metric, and so on.
4. Task table and module logic
# Add tasks to the table
TASKS_TABLE = []
TASKS_TABLE.append(aime24)
TASKS_TABLE.append(aime25_part1)
TASKS_TABLE.append(math_500)
TASKS_TABLE.append(gpqa_diamond)

# MODULE LOGIC
if __name__ == "__main__":
    print([t["name"] for t in TASKS_TABLE])
    print(len(TASKS_TABLE))
- The four evaluation tasks are appended to the TASKS_TABLE list.
- When the file is run directly as a script, it prints the name of each task in the table and the number of tasks.
generate.py
This is a Python script that builds and runs a Distilabel pipeline for generating responses. Let's walk through its parts:
1. Module imports
# Copyright 2025 The HuggingFace Team. All rights reserved.
# (copyright and license text omitted...)

from typing import Optional

from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import StepResources
from distilabel.steps.tasks import TextGeneration
- Optional: imported from typing; marks a value that may be of the given type or None.
- OpenAILLM: imported from distilabel.llms; creates an OpenAI-style language model client.
- Pipeline: imported from distilabel.pipeline; builds the data-processing pipeline.
- StepResources: imported from distilabel.steps; specifies the resources for a step.
- TextGeneration: imported from distilabel.steps.tasks; performs the text generation task.
2. The build_distilabel_pipeline function
def build_distilabel_pipeline(
    model: str,
    base_url: str = "http://localhost:8000/v1",
    prompt_column: Optional[str] = None,
    prompt_template: str = "{{ instruction }}",
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: int = 8192,
    num_generations: int = 1,
    input_batch_size: int = 64,
    client_replicas: int = 1,
    timeout: int = 900,
    retries: int = 0,
) -> Pipeline:
    generation_kwargs = {"max_new_tokens": max_new_tokens}

    if temperature is not None:
        generation_kwargs["temperature"] = temperature

    if top_p is not None:
        generation_kwargs["top_p"] = top_p

    with Pipeline().ray() as pipeline:
        TextGeneration(
            llm=OpenAILLM(
                base_url=base_url,
                api_key="something",
                model=model,
                timeout=timeout,
                max_retries=retries,
                generation_kwargs=generation_kwargs,
            ),
            template=prompt_template,
            input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
            input_batch_size=input_batch_size,
            num_generations=num_generations,
            group_generations=True,
            resources=StepResources(replicas=client_replicas),
        )

    return pipeline
- Parameters:
  - model: name of the model used for generation (required).
  - base_url: base URL of the language model server; defaults to http://localhost:8000/v1.
  - prompt_column: name of the dataset column to use as the prompt (optional).
  - prompt_template: template string for the prompt; defaults to {{ instruction }}.
  - temperature: sampling temperature for generation (optional).
  - top_p: top-p sampling parameter for generation (optional).
  - max_new_tokens: maximum number of new tokens to generate; defaults to 8192.
  - num_generations: number of generations per problem; defaults to 1.
  - input_batch_size: batch size for input processing; defaults to 64.
  - client_replicas: number of client replicas for parallel processing; defaults to 1.
  - timeout: request timeout in seconds; defaults to 900.
  - retries: number of retries for failed requests; defaults to 0.
- Functionality:
  - Builds a Distilabel pipeline containing a single TextGeneration step.
  - Configures the OpenAILLM instance from the input parameters.
  - Returns the constructed pipeline (a usage sketch follows below).
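As a quick illustration, here is a minimal sketch of calling the function directly. The dataset name and endpoint are placeholders, and it assumes a Ray-enabled environment plus a vLLM server exposing an OpenAI-compatible API at the given address:
from datasets import load_dataset

# Placeholder dataset and endpoint, for illustration only
dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train[:10]")

pipeline = build_distilabel_pipeline(
    model="deepseek-ai/DeepSeek-R1",
    base_url="http://localhost:8000/v1",  # OpenAI-compatible vLLM endpoint
    prompt_column="problem",
    temperature=0.6,
    max_new_tokens=8192,
    num_generations=4,
)
distiset = pipeline.run(dataset=dataset, use_cache=False)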
3. The main block
if __name__ == "__main__":
    import argparse

    from datasets import load_dataset

    parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
    # (a series of argument definitions omitted...)
    args = parser.parse_args()

    print("\nRunning with arguments:")
    for arg, value in vars(args).items():
        print(f"  {arg}: {value}")
    print()

    print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
    dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
    print("Dataset loaded!")

    pipeline = build_distilabel_pipeline(
        model=args.model,
        base_url=args.vllm_server_url,
        prompt_template=args.prompt_template,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens,
        num_generations=args.num_generations,
        input_batch_size=args.input_batch_size,
        client_replicas=args.client_replicas,
        timeout=args.timeout,
        retries=args.retries,
    )

    print("Running generation pipeline...")
    distiset = pipeline.run(
        dataset=dataset,
        dataset_batch_size=args.input_batch_size * 1000,
        use_cache=False,
    )
    print("Generation pipeline finished!")

    if args.hf_output_dataset:
        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
        print("Dataset pushed!")
- Functionality:
  - Parses the command-line arguments, including the dataset, model, and generation parameters.
  - Prints the run arguments.
  - Loads the specified Hugging Face dataset.
  - Calls build_distilabel_pipeline to construct the pipeline.
  - Runs the pipeline to generate text.
  - If an output dataset name is given, pushes the resulting dataset to the Hugging Face Hub.
grpo.py
This is the GRPO training script. Here is a deep dive:
1. Module imports
import logging
import os
import sys
from dataclasses import dataclass, field

import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint

from open_r1.configs import GRPOConfig
from open_r1.rewards import (
    accuracy_reward,
    format_reward,
    get_cosine_scaled_reward,
    get_repetition_penalty_reward,
    reasoning_steps_reward,
)
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
This imports the required modules: logging, OS utilities, data handling (datasets), the deep-learning stack (torch, transformers), and several project-specific modules.
2. Logging and constant definitions
logger = logging.getLogger(__name__)

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)
- logger is used to record log messages.
- SYSTEM_PROMPT is the system prompt used to build the conversational format.
3. Define the script-argument class GRPOScriptArguments
@dataclass
class GRPOScriptArguments(ScriptArguments):
    ...
This dataclass stores the script's arguments, including the list of reward functions, the cosine-scaling parameters, and the repetition-penalty parameters; each argument carries a detailed help string and a default value.
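The body is elided above. Based on how the fields are accessed later in the script (script_args.reward_funcs, script_args.cosine_*, script_args.repetition_*), a sketch of the class looks roughly like this; the field names follow those accesses, but the defaults shown are illustrative rather than authoritative:
@dataclass
class GRPOScriptArguments(ScriptArguments):
    # Sketch only: names inferred from accesses later in the script; defaults are illustrative
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "Reward functions: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'."},
    )
    cosine_min_value_wrong: float = field(default=0.0, metadata={"help": "Minimum reward for wrong answers."})
    cosine_max_value_wrong: float = field(default=-0.5, metadata={"help": "Maximum reward for wrong answers."})
    cosine_min_value_correct: float = field(default=0.5, metadata={"help": "Minimum reward for correct answers."})
    cosine_max_value_correct: float = field(default=1.0, metadata={"help": "Maximum reward for correct answers."})
    cosine_max_len: int = field(default=1000, metadata={"help": "Maximum length for cosine scaling."})
    repetition_n_grams: int = field(default=3, metadata={"help": "N-gram size for the repetition penalty."})
    repetition_max_penalty: float = field(default=-1.0, metadata={"help": "Maximum (negative) repetition penalty."})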
4. The main function
def main(script_args, training_args, model_args):
    ...
The main function is the core of the script; it performs the following steps:
4.1 Set the random seed
set_seed(training_args.seed)
This ensures the experiment is reproducible.
4.2 Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
Sets the log format, date format, and handlers, and configures the log level.
4.3 Log basic information
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
    + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Training parameters {training_args}")
Logs basic process information along with the model, script, and training parameters.
4.4 Check for a checkpoint
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
    logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
Checks whether an earlier checkpoint exists; if one does and no other checkpoint was explicitly specified, training resumes from it.
4.5 Initialize WandB training (if needed)
if "wandb" in training_args.report_to:
    init_wandb_training(training_args)
If WandB reporting is enabled, WandB training is initialized.
4.6 Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
Loads the specified dataset with the load_dataset function.
4.7 Get the reward functions
REWARD_FUNCS_REGISTRY = {
    "accuracy": accuracy_reward,
    "format": format_reward,
    "reasoning_steps": reasoning_steps_reward,
    "cosine": get_cosine_scaled_reward(
        min_value_wrong=script_args.cosine_min_value_wrong,
        max_value_wrong=script_args.cosine_max_value_wrong,
        min_value_correct=script_args.cosine_min_value_correct,
        max_value_correct=script_args.cosine_max_value_correct,
        max_len=script_args.cosine_max_len,
    ),
    "repetition_penalty": get_repetition_penalty_reward(
        ngram_size=script_args.repetition_n_grams,
        max_penalty=script_args.repetition_max_penalty,
    ),
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
Defines a registry of reward functions and selects the ones named in the script arguments.
4.8 Format the dataset as conversations
def make_conversation(example):
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ],
    }

dataset = dataset.map(make_conversation)
for split in dataset:
    if "messages" in dataset[split].column_names:
        dataset[split] = dataset[split].remove_columns("messages")
Formats every example in the dataset into a conversation and removes columns that are no longer needed, as the sketch below illustrates.
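Concretely, a made-up record {"problem": "What is 1 + 1?"} would be mapped to:
# Illustrative input/output of make_conversation (the record is made up)
example = {"problem": "What is 1 + 1?"}
print(make_conversation(example))
# {'prompt': [{'role': 'system', 'content': '<the SYSTEM_PROMPT string>'},
#             {'role': 'user', 'content': 'What is 1 + 1?'}]}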
4.9 Initialize the model kwargs
torch_dtype = (
    model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    attn_implementation=model_args.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,
)
training_args.model_init_kwargs = model_kwargs
Builds the model's keyword arguments from the model and training arguments.
4.10 Initialize the GRPO trainer
trainer = GRPOTrainer(
    model=model_args.model_name_or_path,
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=dataset[script_args.dataset_train_split],
    eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
    peft_config=get_peft_config(model_args),
    callbacks=get_callbacks(training_args, model_args),
)
Initializes the trainer via the GRPOTrainer class, passing in the model, reward functions, training arguments, datasets, and so on.
4.11 Train the model
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
Starts training (resuming from a checkpoint when available), logs the training metrics, and saves the trainer state.
4.12 Save the model and create a model card
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")

kwargs = {
    "dataset_name": script_args.dataset_name,
    "tags": ["open-r1"],
}
if trainer.accelerator.is_main_process:
    trainer.create_model_card(**kwargs)
    trainer.model.config.use_cache = True
    trainer.model.config.save_pretrained(training_args.output_dir)
Saves the trained model and creates a model card on the main process.
4.13 Evaluate the model (if requested)
if training_args.do_eval:
    logger.info("*** Evaluate ***")
    metrics = trainer.evaluate()
    metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
If evaluation is requested, evaluates the trained model and records the results.
4.14 Push the model to the Hugging Face Hub (if requested)
if training_args.push_to_hub:
    logger.info("Pushing to hub...")
    trainer.push_to_hub(**kwargs)
If requested, pushes the model to the Hugging Face Hub.
5. Entry point
if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
Parses the command-line arguments and config file, then calls main to start training.
rewards.py
This file contains a series of reward functions used for GRPO training. Here is a deep dive into its contents:
1. Module imports
import math
import re
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
- The math module is used for mathematical computation, e.g. math.cos to compute a cosine.
- The re module provides regular expressions, used for text matching in several reward functions.
- The imports from latex2sympy2_extended and math_verify handle parsing and verification of LaTeX expressions.
2. The accuracy_reward function
def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards
- Functionality: checks whether the model's completion matches the ground-truth answer; a match earns a reward of 1, otherwise 0. If the gold answer cannot be parsed, the example is skipped with a reward of 1.
- Parameters:
  - completions: the list of model completions; each element is a list of dicts, and the "content" of its first dict is taken as the completion text.
  - solution: the list of ground-truth answers.
- Steps:
  - Extract the contents from completions.
  - Iterate over each completion and its corresponding gold answer.
  - Parse the gold answer; if parsing succeeds, parse the completion and verify whether the two match.
  - Assign the reward based on the verification result (a usage sketch follows below).
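A minimal usage sketch with made-up inputs (the exact parsing behavior depends on the math_verify version):
# One completion whose boxed answer matches the gold solution
completions = [[{"content": r"<think>2 + 2 = 4</think><answer>\boxed{4}</answer>"}]]
solution = [r"\boxed{4}"]
print(accuracy_reward(completions, solution))  # expected: [1.0]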
3. The format_reward function
def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
- Functionality: checks whether the completion follows the expected format, i.e. a <think>...</think> block followed by an <answer>...</answer> block spanning the whole completion.
- Parameters:
  - completions: the list of model completions.
- Steps:
  - Extract the contents from completions.
  - Match each completion against the regular expression.
  - Assign a reward of 1 on a match and 0 otherwise (an example follows below).
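For example, with made-up completions (the pattern is exactly the one above):
# The first completion matches the <think>/<answer> format, the second does not
completions = [
    [{"content": "<think>Let me reason.</think><answer>42</answer>"}],
    [{"content": "The answer is 42."}],
]
print(format_reward(completions))  # [1.0, 0.0]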
4. The reasoning_steps_reward function
def reasoning_steps_reward(completions, **kwargs):
    r"""Reward function that checks for clear step-by-step reasoning.
    Regex pattern:
        Step \d+: - matches "Step 1:", "Step 2:", etc.
        ^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
        \n- - matches bullet points with hyphens
        \n\* - matches bullet points with asterisks
        First,|Second,|Next,|Finally, - matches transition words
    """
    pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [len(re.findall(pattern, content)) for content in completion_contents]

    # Magic number 3 to encourage 3 steps and more, otherwise partial reward
    return [min(1.0, count / 3) for count in matches]
- Functionality: checks whether the completion shows clear step-by-step reasoning by matching specific markers with a regular expression.
- Parameters:
  - completions: the list of model completions.
- Steps:
  - Extract the contents from completions.
  - Count the regex matches in each completion.
  - Reward 1 when there are 3 or more matches; otherwise give a partial reward of count / 3 (an example follows below).
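A quick sketch with a made-up completion that contains exactly three transition words, so the reward saturates at 1.0:
# "First," + "Next," + "Finally," -> 3 matches -> min(1.0, 3 / 3) == 1.0
completions = [[{"content": "First, expand the terms. Next, simplify. Finally, solve for x."}]]
print(reasoning_steps_reward(completions))  # [1.0]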
5. The get_cosine_scaled_reward function
def get_cosine_scaled_reward(
    min_value_wrong: float = -1.0,
    max_value_wrong: float = -0.5,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    max_len: int = 1000,
):
    def cosine_scaled_reward(completions, solution, **kwargs):
        """Reward function that scales based on completion length using a cosine schedule.

        Shorter correct solutions are rewarded more than longer ones.
        Longer incorrect solutions are penalized less than shorter ones.

        Args:
            completions: List of model completions
            solution: List of ground truth solutions

        This function is parameterized by the following arguments:
            min_value_wrong: Minimum reward for wrong answers
            max_value_wrong: Maximum reward for wrong answers
            min_value_correct: Minimum reward for correct answers
            max_value_correct: Maximum reward for correct answers
            max_len: Maximum length for scaling
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []

        for content, sol in zip(contents, solution):
            gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
            if len(gold_parsed) == 0:
                rewards.append(1.0)  # Skip unparseable examples
                print("Failed to parse gold solution: ", sol)
                continue

            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed=True,
                            units=True,
                        ),
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )

            is_correct = verify(answer_parsed, gold_parsed)
            gen_len = len(content)

            # Apply cosine scaling based on length
            progress = gen_len / max_len
            cosine = math.cos(progress * math.pi)

            if is_correct:
                min_value = min_value_correct
                max_value = max_value_correct
            else:
                # Swap min/max for incorrect answers
                min_value = max_value_wrong
                max_value = min_value_wrong

            reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
            rewards.append(float(reward))

        return rewards

    return cosine_scaled_reward
- Functionality: scales the reward with the completion length using a cosine schedule; shorter correct answers earn a higher reward, and longer wrong answers are penalized less than shorter ones.
- Parameters:
  - min_value_wrong: minimum reward for wrong answers; defaults to -1.0.
  - max_value_wrong: maximum reward for wrong answers; defaults to -0.5.
  - min_value_correct: minimum reward for correct answers; defaults to 0.5.
  - max_value_correct: maximum reward for correct answers; defaults to 1.0.
  - max_len: maximum length used for scaling; defaults to 1000.
- Steps:
  - Extract the contents from completions.
  - Iterate over each completion and its gold answer.
  - Parse the gold answer and the completion, and verify whether they match.
  - Compute the length-based progress and scale it with a cosine.
  - Compute the reward from the verification result and the scaled value (a worked example follows below).
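To make the schedule concrete, here is a worked example with the default parameters. The reward is computed as

reward = min_value + 0.5 * (max_value - min_value) * (1 + cos(pi * gen_len / max_len))

For a correct answer with the defaults (min_value_correct = 0.5, max_value_correct = 1.0): a very short completion (gen_len ≈ 0, cos ≈ 1) earns ≈ 1.0; a completion of length max_len (cos = -1) earns 0.5; and at the midpoint gen_len = 500, cos(pi/2) = 0, so the reward is 0.5 + 0.5 * 0.5 * 1 = 0.75.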
6. The get_repetition_penalty_reward function
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
    """
    Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
    Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

    Args:
        ngram_size: size of the n-grams
        max_penalty: Maximum (negative) penalty for wrong answers
    """
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")

    def zipngram(text: str, ngram_size: int):
        words = text.lower().split()
        return zip(*[words[i:] for i in range(ngram_size)])

    def repetition_penalty_reward(completions, **kwargs) -> float:
        """
        reward function that penalizes repetitions
        ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

        Args:
            completions: List of model completions
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        for completion in contents:
            if completion == "":
                rewards.append(0.0)
                continue
            if len(completion.split()) < ngram_size:
                rewards.append(0.0)
                continue

            ngrams = set()
            total = 0
            for ng in zipngram(completion, ngram_size):
                ngrams.add(ng)
                total += 1

            scaling = 1 - len(ngrams) / total
            reward = scaling * max_penalty
            rewards.append(reward)
        return rewards

    return repetition_penalty_reward
- Functionality: computes an N-gram repetition penalty that penalizes repeated content in the generation.
- Parameters:
  - ngram_size: the size of the n-grams.
  - max_penalty: the maximum (negative) penalty.
- Steps:
  - Raise an exception if max_penalty is positive.
  - Define zipngram, which yields the n-grams of a text.
  - Define repetition_penalty_reward, which, for each completion, computes the fraction of repeated n-grams and scales max_penalty by it (a worked example follows below).
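A worked sketch with a made-up input: for ngram_size=2 and max_penalty=-1.0, the text "the cat the cat" contains three bigrams, (the, cat), (cat, the), (the, cat), of which two are unique, so scaling = 1 - 2/3 = 1/3 and the reward is -1/3:
# Made-up completion: 3 bigrams, 2 unique -> scaling = 1/3 -> reward = -1/3
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
completions = [[{"content": "the cat the cat"}]]
print(reward_fn(completions))  # [-0.3333333333333333]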