The src/open_r1 directory mainly contains Python scripts for training and evaluating models and for generating synthetic data. Below we take a deep dive into several of its main Python files.

configs.py

This file defines two dataclasses, GRPOConfig and SFTConfig, which inherit from trl.GRPOConfig and trl.SFTConfig respectively and add a number of extra parameters for callbacks, benchmarks, and so on.

# Import required modules
from dataclasses import dataclass, field
from typing import Optional
import trl

# Define the GRPOConfig class, inheriting from trl.GRPOConfig
@dataclass
class GRPOConfig(trl.GRPOConfig):
    """
    args for callbacks, benchmarks etc
    """
    # List of benchmark names to run after training; defaults to an empty list
    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    # List of callback names to run during training; defaults to an empty list
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    # Optional system prompt used for benchmarking; defaults to None
    system_prompt: Optional[str] = field(
        default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
    )
    # Hub model branch (revision) to push to; defaults to "main"
    hub_model_revision: Optional[str] = field(
        default="main", metadata={"help": "The Hub model branch to push the model to."}
    )
    # Whether to overwrite the Hub revision; defaults to False
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    # Whether to push to a Hub revision/branch; defaults to False
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    # WandB entity to store runs under; defaults to None
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    # WandB project to store runs under; defaults to None
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )

# Define the SFTConfig class, inheriting from trl.SFTConfig
@dataclass
class SFTConfig(trl.SFTConfig):
    """
    args for callbacks, benchmarks etc
    """
    # List of benchmark names to run after training; defaults to an empty list
    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    # List of callback names to run during training; defaults to an empty list
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    # Optional system prompt used for benchmarking; defaults to None
    system_prompt: Optional[str] = field(
        default=None,
        metadata={"help": "The optional system prompt to use for benchmarking."},
    )
    # Hub model branch (revision) to push to; defaults to "main"
    hub_model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The Hub model branch to push the model to."},
    )
    # Whether to overwrite the Hub revision; defaults to False
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    # Whether to push to a Hub revision/branch; defaults to False
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    # WandB entity to store runs under; defaults to None
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    # WandB project to store runs under; defaults to None
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )
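To see how these extended configs are consumed, here is a minimal usage sketch; it simply mirrors the parser setup that sft.py (below) uses, and the fields printed at the end are chosen for illustration only.

# Usage sketch, mirroring the __main__ block of sft.py shown below.
from trl import ModelConfig, ScriptArguments, TrlParser

from open_r1.configs import SFTConfig

if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    # The extra fields defined above now sit next to the regular trl.SFTConfig ones.
    print(training_args.benchmarks, training_args.callbacks, training_args.hub_model_revision)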

sft.py

This file is a supervised fine-tuning (SFT) script used to fine-tune decoder language models.

1 Import required modules
import logging
import os
import sys
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import SFTConfig
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.logging import init_wandb_training
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
2 Configure logging
logger = logging.getLogger(__name__)

def main(script_args, training_args, model_args):
    # Set the random seed for reproducibility
    set_seed(training_args.seed)

    # Configure logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a short summary on each process
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Script parameters {script_args}")
    logger.info(f"Training parameters {training_args}")
3 Check for the last checkpoint
    # Check for the last checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    # Initialize WandB training if reporting to WandB
    if "wandb" in training_args.report_to:
        init_wandb_training(training_args)
4 Load the dataset and tokenizer
    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )
    tokenizer.pad_token = tokenizer.eos_token
5 Initialize model kwargs
    # Initialize model kwargs
    logger.info("*** Initializing model kwargs ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    training_args.model_init_kwargs = model_kwargs
6 Initialize the SFT Trainer
    # Initialize the SFT Trainer
    trainer = SFTTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
        callbacks=get_callbacks(training_args, model_args),
    )
7 Training loop
    # Training loop
    logger.info("*** Train ***")
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
8 Save the model and create a model card
    # Save the model and create a model card
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on the main process
    kwargs = {
        "dataset_name": script_args.dataset_name,
        "tags": ["open-r1"],
    }
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore the k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)
9 Evaluate the model
    # Evaluate the model
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
10 Push to the Hub
    # Push to the Hub
    if training_args.push_to_hub:
        logger.info("Pushing to hub...")
        trainer.push_to_hub(**kwargs)


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)

evaluate.py

This file defines custom LightEval evaluation tasks. Below is a deep analysis of the file:

1. Module imports and constant definitions

import random

from lighteval.metrics.dynamic_metrics import (
    ExprExtractionConfig,
    IndicesExtractionConfig,
    LatexExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
  • The random module is imported to generate random numbers.
  • Several configuration classes and functions are imported from the lighteval modules to define the evaluation tasks and metrics.
latex_gold_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    fallback_mode="first_match",
    precision=5,
    gold_extraction_target=(LatexExtractionConfig(),),
    # Match boxed first before trying other regexes
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
    aggregation_function=max,
)

expr_gold_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    fallback_mode="first_match",
    precision=5,
    gold_extraction_target=(ExprExtractionConfig(),),
    # Match boxed first before trying other regexes
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
    aggregation_function=max,
)

gpqa_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
    pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
    precision=5,
)
  • Three evaluation metrics are defined: latex_gold_metric, expr_gold_metric, and gpqa_metric, which are used below when configuring the evaluation tasks.

2. Prompt function definitions

def prompt_fn(line, task_name: str = None):
    """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
    return Doc(
        task_name=task_name,
        query=line["problem"],
        choices=[line["solution"]],
        gold_index=0,
    )


def aime_prompt_fn(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["problem"],
        choices=[line["answer"]],
        gold_index=0,
    )


def gpqa_prompt_fn(line, task_name: str = None):
    """Prompt template adapted from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14"""
    gold_index = random.randint(0, 3)
    choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]]
    choices.insert(gold_index, line["Correct Answer"])
    query_template = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}"
    query = query_template.format(A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"])

    return Doc(
        task_name=task_name,
        query=query,
        choices=["A", "B", "C", "D"],
        gold_index=gold_index,
        instruction=query,
    )
  • Three prompt functions are defined: prompt_fn, aime_prompt_fn, and gpqa_prompt_fn, which build the prompts for the different tasks (a short example follows).
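To make the GPQA flow concrete, here is a hedged example that feeds a made-up record through gpqa_prompt_fn; only the field names follow the dataset schema used above, while the question text and task name are placeholders.

# Hypothetical record; only the keys match the GPQA schema referenced above.
sample = {
    "Question": "Which particle mediates the electromagnetic force?",
    "Correct Answer": "Photon",
    "Incorrect Answer 1": "Gluon",
    "Incorrect Answer 2": "W boson",
    "Incorrect Answer 3": "Graviton",
}
doc = gpqa_prompt_fn(sample, task_name="gpqa:diamond")
print(doc.query)       # the formatted multiple-choice prompt built from the template above
print(doc.choices)     # ["A", "B", "C", "D"]
print(doc.gold_index)  # position of "Photon", chosen at random by random.randint(0, 3)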

3. Evaluation task configuration

aime24 = LightevalTaskConfig(
    name="aime24",
    suite=["custom"],
    prompt_function=aime_prompt_fn,
    hf_repo="HuggingFaceH4/aime_2024",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[expr_gold_metric],
    version=1,
)
# Part I from AIME 2025 exam: https://artofproblemsolving.com/wiki/index.php/2025_AIME_I?srsltid=AfmBOoof5gaaqlt3-l6LH7Tt6qmJZtl_2PQEDYlLFlMqhq9dLL8FMCRR
aime25_part1 = LightevalTaskConfig(
    name="aime25:part1",
    suite=["custom"],
    prompt_function=aime_prompt_fn,
    hf_repo="open-r1/aime_2025_1",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[expr_gold_metric],
    version=1,
)
math_500 = LightevalTaskConfig(
    name="math_500",
    suite=["custom"],
    prompt_function=prompt_fn,
    hf_repo="HuggingFaceH4/MATH-500",
    hf_subset="default",
    hf_avail_splits=["test"],
    evaluation_splits=["test"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[latex_gold_metric],
    version=1,
)
gpqa_diamond = LightevalTaskConfig(
    name="gpqa:diamond",
    suite=["custom"],
    prompt_function=gpqa_prompt_fn,
    hf_repo="Idavidrein/gpqa",
    hf_subset="gpqa_diamond",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,  # needed for reasoning models like R1
    metric=[gpqa_metric],
    stop_sequence=[],  # no stop sequence, will use eos token
    trust_dataset=True,
    version=1,
)
  • Four evaluation tasks are defined: aime24, aime25_part1, math_500, and gpqa_diamond. Each one is configured with a name, suite, prompt function, dataset information, evaluation metric, and so on.

4. Task table definition and module logic

# Add tasks to the table
TASKS_TABLE = []
TASKS_TABLE.append(aime24)
TASKS_TABLE.append(aime25_part1)
TASKS_TABLE.append(math_500)
TASKS_TABLE.append(gpqa_diamond)

# MODULE LOGIC
if __name__ == "__main__":
    print([t["name"] for t in TASKS_TABLE])
    print(len(TASKS_TABLE))
  • The four evaluation tasks defined above are appended to the TASKS_TABLE list.
  • When the file is run directly as a script, it prints the name of each task in the table and the total number of tasks.

generate.py

This is a Python script that builds and runs a Distilabel pipeline to generate responses. Below we walk through each part of the file:

1. Module imports

# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright and license notice omitted...

from typing import Optional

from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import StepResources
from distilabel.steps.tasks import TextGeneration
  • Optional: imported from typing, used to indicate that a value may be of the given type or None.
  • OpenAILLM: imported from distilabel.llms, used to create an OpenAI-style language model client.
  • Pipeline: imported from distilabel.pipeline, used to build the data-processing pipeline.
  • StepResources: imported from distilabel.steps, used to specify the resources for a step.
  • TextGeneration: imported from distilabel.steps.tasks, used to perform the text generation task.

2. The build_distilabel_pipeline function

def build_distilabel_pipeline(
    model: str,
    base_url: str = "http://localhost:8000/v1",
    prompt_column: Optional[str] = None,
    prompt_template: str = "{{ instruction }}",
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: int = 8192,
    num_generations: int = 1,
    input_batch_size: int = 64,
    client_replicas: int = 1,
    timeout: int = 900,
    retries: int = 0,
) -> Pipeline:
    generation_kwargs = {"max_new_tokens": max_new_tokens}

    if temperature is not None:
        generation_kwargs["temperature"] = temperature

    if top_p is not None:
        generation_kwargs["top_p"] = top_p

    with Pipeline().ray() as pipeline:
        TextGeneration(
            llm=OpenAILLM(
                base_url=base_url,
                api_key="something",
                model=model,
                timeout=timeout,
                max_retries=retries,
                generation_kwargs=generation_kwargs,
            ),
            template=prompt_template,
            input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
            input_batch_size=input_batch_size,
            num_generations=num_generations,
            group_generations=True,
            resources=StepResources(replicas=client_replicas),
        )

    return pipeline
  • Parameters

    • model: name of the model used for generation (required).
    • base_url: base URL of the language model server, defaults to http://localhost:8000/v1.
    • prompt_column: name of the dataset column that contains the prompt (optional).
    • prompt_template: template string for the prompt, defaults to {{ instruction }}.
    • temperature: sampling temperature for generation (optional).
    • top_p: top-p sampling parameter for generation (optional).
    • max_new_tokens: maximum number of new tokens to generate, defaults to 8192.
    • num_generations: number of generations per prompt, defaults to 1.
    • input_batch_size: batch size for processing inputs, defaults to 64.
    • client_replicas: number of client replicas for parallel processing, defaults to 1.
    • timeout: request timeout in seconds, defaults to 900.
    • retries: number of retries for failed requests, defaults to 0.
  • Behavior

    • Builds a Distilabel pipeline containing a single TextGeneration step.
    • Configures the OpenAILLM instance from the input parameters.
    • Returns the constructed pipeline (see the usage sketch below).
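The sketch below shows one plausible way to call the function programmatically; it mirrors the __main__ block that follows, and the dataset name, prompt column, and sampling values are placeholder assumptions (a vLLM-compatible server is assumed to be serving the model at base_url).

# Placeholder usage sketch; dataset, column, and sampling values are assumptions.
from datasets import load_dataset

dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train[:16]")
pipeline = build_distilabel_pipeline(
    model="deepseek-ai/DeepSeek-R1",
    base_url="http://localhost:8000/v1",
    prompt_column="problem",
    temperature=0.6,
    top_p=0.95,
    num_generations=2,
)
distiset = pipeline.run(dataset=dataset, use_cache=False)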

3. The main section

if __name__ == "__main__":
    import argparse

    from datasets import load_dataset

    parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
    # A series of argument definitions omitted...

    args = parser.parse_args()

    print("\nRunning with arguments:")
    for arg, value in vars(args).items():
        print(f"  {arg}: {value}")
    print()

    print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
    dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
    print("Dataset loaded!")

    pipeline = build_distilabel_pipeline(
        model=args.model,
        base_url=args.vllm_server_url,
        prompt_template=args.prompt_template,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens,
        num_generations=args.num_generations,
        input_batch_size=args.input_batch_size,
        client_replicas=args.client_replicas,
        timeout=args.timeout,
        retries=args.retries,
    )

    print("Running generation pipeline...")
    distiset = pipeline.run(
        dataset=dataset,
        dataset_batch_size=args.input_batch_size * 1000,
        use_cache=False,
    )
    print("Generation pipeline finished!")

    if args.hf_output_dataset:
        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
        print("Dataset pushed!")
  • Behavior
    • Parses the command-line arguments, including dataset, model, and generation parameters.
    • Prints the run arguments.
    • Loads the specified Hugging Face dataset.
    • Calls build_distilabel_pipeline to build the pipeline.
    • Runs the pipeline to generate text.
    • If an output dataset name is specified, pushes the generated dataset to the Hugging Face Hub.

grpo.py

This is the Python file for the GRPO training script. A deep analysis follows:

1. Module imports

import logging
import os
import sys
from dataclasses import dataclass, field

import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint

from open_r1.configs import GRPOConfig
from open_r1.rewards import (
    accuracy_reward,
    format_reward,
    get_cosine_scaled_reward,
    get_repetition_penalty_reward,
    reasoning_steps_reward,
)
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

Many required modules are imported here, covering logging, OS utilities, data processing (datasets), the deep-learning stack (torch and transformers), and several project-specific modules.

2. Logging and constant definitions

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)
  • logger is used to record log messages.
  • SYSTEM_PROMPT is the system prompt used to enforce the conversational <think>/<answer> format.

3. The GRPOScriptArguments script-argument class

@dataclass
class GRPOScriptArguments(ScriptArguments):
    ...

This is a dataclass that stores the script's arguments, including the list of reward functions, the cosine-scaling parameters, and the repetition-penalty parameters; each field carries a detailed help string and a default value. A partial sketch of the fields referenced later in this script follows.
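Since the class body is elided above, the following is only a partial, non-authoritative sketch: it lists just the fields that the reward-function registry in section 4.7 reads, and the defaults shown are illustrative (the cosine values mirror the defaults of get_cosine_scaled_reward in rewards.py), not necessarily the repository's actual values.

# Partial sketch only -- field names are taken from the registry in 4.7 below,
# and the defaults are illustrative assumptions.
@dataclass
class GRPOScriptArguments(ScriptArguments):
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions (keys of the reward registry)."},
    )
    cosine_min_value_wrong: float = field(default=-1.0)
    cosine_max_value_wrong: float = field(default=-0.5)
    cosine_min_value_correct: float = field(default=0.5)
    cosine_max_value_correct: float = field(default=1.0)
    cosine_max_len: int = field(default=1000)
    repetition_n_grams: int = field(default=3)
    repetition_max_penalty: float = field(default=-1.0)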

4. The main function

def main(script_args, training_args, model_args):
    ...

The main function is the core of the script and performs the following steps:

4.1 Set the random seed
set_seed(training_args.seed)

This ensures the experiment is reproducible.

4.2 Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

Sets the log format, date format, and handlers, and configures the log level.

4.3 Log basic information
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
    + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Training parameters {training_args}")

Logs basic process information along with the model, script, and training parameters.

4.4 Check for checkpoints
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
    logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

Checks whether a previous checkpoint exists; if one is found and no other checkpoint was explicitly specified, training resumes from it.

4.5 Initialize WandB training (if needed)
if "wandb" in training_args.report_to:
    init_wandb_training(training_args)

If WandB reporting is enabled, WandB training is initialized.

4.6 Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

Loads the specified dataset with the load_dataset function.

4.7 Select the reward functions
REWARD_FUNCS_REGISTRY = {
    "accuracy": accuracy_reward,
    "format": format_reward,
    "reasoning_steps": reasoning_steps_reward,
    "cosine": get_cosine_scaled_reward(
        min_value_wrong=script_args.cosine_min_value_wrong,
        max_value_wrong=script_args.cosine_max_value_wrong,
        min_value_correct=script_args.cosine_min_value_correct,
        max_value_correct=script_args.cosine_max_value_correct,
        max_len=script_args.cosine_max_len,
    ),
    "repetition_penalty": get_repetition_penalty_reward(
        ngram_size=script_args.repetition_n_grams,
        max_penalty=script_args.repetition_max_penalty,
    ),
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

Defines a registry of reward functions and selects the ones requested via the script arguments.

4.8 Format the dataset as conversations
def make_conversation(example):
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ],
    }

dataset = dataset.map(make_conversation)
for split in dataset:
    if "messages" in dataset[split].column_names:
        dataset[split] = dataset[split].remove_columns("messages")

Formats every example in the dataset as a conversation and removes columns that are no longer needed, for example:
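example = {"problem": "What is 7 * 8?"}  # made-up problem text
make_conversation(example)
# {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},
#             {"role": "user", "content": "What is 7 * 8?"}]}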

4.9 Initialize model kwargs
torch_dtype = (
    model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    attn_implementation=model_args.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,
)
training_args.model_init_kwargs = model_kwargs

Builds the model keyword arguments from the model and training parameters.

4.10 Initialize the GRPO trainer
trainer = GRPOTrainer(
    model=model_args.model_name_or_path,
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=dataset[script_args.dataset_train_split],
    eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
    peft_config=get_peft_config(model_args),
    callbacks=get_callbacks(training_args, model_args),
)

Initializes the trainer with the GRPOTrainer class, passing in the model, reward functions, training arguments, datasets, and so on.

4.11 Train the model
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
    checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

Starts training (resuming from a checkpoint when one is available), logs the training metrics, and saves the trainer state.

4.12 Save the model and create a model card
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")

kwargs = {
    "dataset_name": script_args.dataset_name,
    "tags": ["open-r1"],
}
if trainer.accelerator.is_main_process:
    trainer.create_model_card(**kwargs)
    trainer.model.config.use_cache = True
    trainer.model.config.save_pretrained(training_args.output_dir)

Saves the trained model and creates a model card on the main process.

4.13 Evaluate the model (if needed)
if training_args.do_eval:
    logger.info("*** Evaluate ***")
    metrics = trainer.evaluate()
    metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

If evaluation is enabled, the trained model is evaluated and the results are logged.

4.14 Push the model to the Hugging Face Hub (if needed)
if training_args.push_to_hub:
    logger.info("Pushing to hub...")
    trainer.push_to_hub(**kwargs)

If pushing to the Hugging Face Hub is enabled, the push is performed.

5. Program entry point

if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)

Parses the command-line arguments and the config file, then calls main to start training.

rewards.py

This file contains a set of reward functions used for GRPO training. A deep analysis of its contents follows:

1. Module imports

import math
import re

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
  • The math module is used for mathematical computation, e.g. math.cos for the cosine function.
  • The re module provides regular expressions, used for text matching in several reward functions.
  • The classes and functions imported from latex2sympy2_extended and math_verify are used to parse and verify LaTeX expressions.

2. The accuracy_reward function

def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards
  • Purpose: checks whether the model's completion matches the ground-truth answer, rewarding 1 for a match and 0 otherwise. If the gold answer cannot be parsed, the example is skipped by assigning a reward of 1.
  • Parameters
    • completions: list of model completions; each element is a list of message dicts, and the "content" of its first element is taken as the actual completion text.
    • solution: list of ground-truth answers.
  • Steps
    1. Extract the content from completions.
    2. Iterate over each completion and its corresponding gold answer.
    3. Parse the gold answer; if parsing succeeds, parse the completion as well and verify whether the two match.
    4. Assign the reward based on the verification result (see the sketch below).
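A minimal sketch of the expected input shapes (made-up values; each completion is assumed to be a list of chat messages, following the conversational convention used throughout this repo):

completions = [
    [{"role": "assistant", "content": r"<think>2 + 2 = 4</think><answer>\boxed{4}</answer>"}],
    [{"role": "assistant", "content": r"<think>guessing</think><answer>\boxed{5}</answer>"}],
]
solution = [r"\boxed{4}", r"\boxed{4}"]
accuracy_reward(completions, solution)  # expected: [1.0, 0.0]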

3. The format_reward function

def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
  • Purpose: checks whether the completion follows the expected format, i.e. it starts with a <think>...</think> block followed by an <answer>...</answer> block that ends the text.
  • Parameters
    • completions: list of model completions.
  • Steps
    1. Extract the content from completions.
    2. Match each completion against the regular expression.
    3. Assign a reward of 1.0 on a match and 0.0 otherwise, as illustrated below.
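A quick illustration with made-up strings:

good = [[{"role": "assistant", "content": "<think>reasoning</think>\n<answer>42</answer>"}]]
bad = [[{"role": "assistant", "content": "The answer is 42."}]]
format_reward(good)  # [1.0]
format_reward(bad)   # [0.0]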

4. The reasoning_steps_reward function

def reasoning_steps_reward(completions, **kwargs):
    r"""Reward function that checks for clear step-by-step reasoning.
    Regex pattern:
        Step \d+: - matches "Step 1:", "Step 2:", etc.
        ^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
        \n- - matches bullet points with hyphens
        \n\* - matches bullet points with asterisks
        First,|Second,|Next,|Finally, - matches transition words
    """
    pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [len(re.findall(pattern, content)) for content in completion_contents]

    # Magic number 3 to encourage 3 steps and more, otherwise partial reward
    return [min(1.0, count / 3) for count in matches]
  • Purpose: checks whether the completion contains clear step-by-step reasoning by matching specific markers with a regular expression.
  • Parameters
    • completions: list of model completions.
  • Steps
    1. Extract the content from completions.
    2. Count the markers matched by the regular expression in each completion.
    3. Assign the reward based on the count: 3 or more markers earn the full reward of 1.0, fewer earn a proportional partial reward (see the example below).
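For instance, a made-up completion with two matched markers ("First," and "Next,") earns a partial reward of 2/3:

sample = [[{"role": "assistant", "content": "First, compute the area. Next, divide it by two."}]]
reasoning_steps_reward(sample)  # [0.666...], i.e. min(1.0, 2 / 3)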

5. The get_cosine_scaled_reward function

def get_cosine_scaled_reward(
    min_value_wrong: float = -1.0,
    max_value_wrong: float = -0.5,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    max_len: int = 1000,
):
    def cosine_scaled_reward(completions, solution, **kwargs):
        """Reward function that scales based on completion length using a cosine schedule.

        Shorter correct solutions are rewarded more than longer ones.
        Longer incorrect solutions are penalized less than shorter ones.

        Args:
            completions: List of model completions
            solution: List of ground truth solutions

        This function is parameterized by the following arguments:
            min_value_wrong: Minimum reward for wrong answers
            max_value_wrong: Maximum reward for wrong answers
            min_value_correct: Minimum reward for correct answers
            max_value_correct: Maximum reward for correct answers
            max_len: Maximum length for scaling
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []

        for content, sol in zip(contents, solution):
            gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
            if len(gold_parsed) == 0:
                rewards.append(1.0)  # Skip unparseable examples
                print("Failed to parse gold solution: ", sol)
                continue

            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed=True,
                            units=True,
                        ),
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )

            is_correct = verify(answer_parsed, gold_parsed)
            gen_len = len(content)

            # Apply cosine scaling based on length
            progress = gen_len / max_len
            cosine = math.cos(progress * math.pi)

            if is_correct:
                min_value = min_value_correct
                max_value = max_value_correct
            else:
                # Swap min/max for incorrect answers
                min_value = max_value_wrong
                max_value = min_value_wrong

            reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
            rewards.append(float(reward))

        return rewards

    return cosine_scaled_reward
  • Purpose: computes the reward with a cosine schedule over completion length; shorter correct answers earn a higher reward, and longer incorrect answers are penalized less than shorter ones.
  • Parameters
    • min_value_wrong: minimum reward for wrong answers, defaults to -1.0.
    • max_value_wrong: maximum reward for wrong answers, defaults to -0.5.
    • min_value_correct: minimum reward for correct answers, defaults to 0.5.
    • max_value_correct: maximum reward for correct answers, defaults to 1.0.
    • max_len: maximum length used for scaling, defaults to 1000.
  • Steps
    1. Extract the content from completions.
    2. Iterate over each completion and its corresponding gold answer.
    3. Parse the gold answer and the completion, and verify whether they match.
    4. Compute the progress from the completion length and scale it with the cosine function.
    5. Compute the reward from the verification result and the scaled value (a worked example follows).
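A small worked example of the schedule: the helper below simply restates the formula above with the default parameter values hard-coded, so the numbers are easy to verify by hand.

import math

def scaled(gen_len, correct, max_len=1000):
    # Same formula as in cosine_scaled_reward, with the default min/max values inlined.
    cosine = math.cos(gen_len / max_len * math.pi)
    min_v, max_v = (0.5, 1.0) if correct else (-0.5, -1.0)  # note the swap for wrong answers
    return min_v + 0.5 * (max_v - min_v) * (1.0 + cosine)

scaled(250, correct=True)    # ~0.93  (short and correct: close to max_value_correct)
scaled(900, correct=True)    # ~0.51  (long and correct: close to min_value_correct)
scaled(250, correct=False)   # ~-0.93 (short and wrong: penalized hardest)
scaled(900, correct=False)   # ~-0.51 (long and wrong: penalized least)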

6. The get_repetition_penalty_reward function

def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
    """
    Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
    Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

    Args:
    ngram_size: size of the n-grams
    max_penalty: Maximum (negative) penalty for wrong answers
    """
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")

    def zipngram(text: str, ngram_size: int):
        words = text.lower().split()
        return zip(*[words[i:] for i in range(ngram_size)])

    def repetition_penalty_reward(completions, **kwargs) -> float:
        """
        reward function the penalizes repetitions
        ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

        Args:
            completions: List of model completions
        """

        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        for completion in contents:
            if completion == "":
                rewards.append(0.0)
                continue
            if len(completion.split()) < ngram_size:
                rewards.append(0.0)
                continue

            ngrams = set()
            total = 0
            for ng in zipngram(completion, ngram_size):
                ngrams.add(ng)
                total += 1

            scaling = 1 - len(ngrams) / total
            reward = scaling * max_penalty
            rewards.append(reward)
        return rewards

    return repetition_penalty_reward
  • Purpose: computes an N-gram repetition penalty that penalizes repetitive generations.
  • Parameters
    • ngram_size: size of the n-grams.
    • max_penalty: maximum (negative) penalty.
  • Steps
    1. Raise an error if max_penalty is positive.
    2. Define the zipngram helper that yields the n-grams of a text.
    3. Define repetition_penalty_reward, which iterates over the completions, computes the fraction of repeated n-grams, and scales max_penalty by it (a worked example follows).
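A worked example with illustrative parameters (ngram_size=2, max_penalty=-0.1):

reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-0.1)
sample = [[{"role": "assistant", "content": "the cat the cat the cat"}]]
# 5 bigrams in total, only 2 unique ("the cat", "cat the") -> scaling = 1 - 2/5 = 0.6
reward_fn(sample)  # [0.6 * -0.1] ≈ [-0.06]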