
easy-GRPO

Group Relative Policy Optimization (GRPO) for the Countdown task, using basic transformers-library functionality and PyTorch's native DDP. Why? Let's start with some blah blah.

Yet another GRPO for Countdown? Motivation for this repo

If you've stumbled upon this repository, you might wonder: aren't there enough other repositories that do the exact same thing, namely applying Group Relative Policy Optimization to the Countdown task? And you're definitely right: since the popular TinyZero project, plenty of other implementations and blog posts have come up. However, many of them build on bigger "RL for LLM" frameworks such as veRL, use the convenience GRPO Trainer from Hugging Face (e.g., here), or use vLLM for generation with DeepSpeed wrapped around it (e.g., simple_GRPO).

I wanted to make an implementation that relies only on "standard" (I don't even know why I'm using this word) Hugging Face functionality (tokenizers, models, and sequence generation, but obviously not the convenience GRPO Trainer) and PyTorch's standard DistributedDataParallel (DDP) module to allow training the model on multiple GPUs.

The repository that comes closest to this is GRPO-Zero, and I like it a lot. It's a super clean implementation with minimal dependencies, and my code is loosely based on it in parts, with some explicit one-to-one borrows (like the reward functions), so thank you GRPO-Zero! The difference here is that this repository supports multi-GPU training with DDP with (hopefully) minimal added complexity, and we allow ourselves to load and use the models with the transformers library (in contrast to GRPO-Zero, where the network is set up manually).

Finally, I want to give a detailed walkthrough of the code here. So consider this README a writeup that will hopefully be helpful to anyone getting into the topic who wants some guidance while reading through the code.

Structure of this README

We will start with the general (file) structure, the dataset, what the code does on a high level, and how to run it. Then comes a tour of the steps in GRPO (with some modifications coming from Dr. GRPO), followed by a tour of the code in main.py to help understanding. Finally, some results (spoiler: we solve the problem in around 90% of all cases with Qwen-2.5-3B).

General structure and flow

The overall goal of the code is to take a pretrained LLM (like other repositories, we tried Qwen2.5 models, but feel free to use others) and let it learn with GRPO how to solve the Countdown task.

Countdown

In an instance of the Countdown task, you are given 3 or 4 integers and a target integer, and the task is to build an arithmetic expression (using +, -, *, / and parentheses only) that evaluates to the target value. For example, for numbers = [48, 74, 90, 67]; target = 64, a solution would be 67 - 48 / (90 - 74). To get a list of these examples to train on, we use the dataset by Jiayi-Pan on Hugging Face, which consists of 490k examples of the form {'target': int, 'nums': list[int]}. Everything related to the Countdown task is found in the file countdown.py, which is based on GRPO-Zero's implementation:

  • Since we don't want to feed the raw numbers and targets to the LLM, we embed items in the dataset within a prompt template. This is done in generate_prompt.
  • The reward_function takes in a completion (an opening "<think>"-tag is added before the completion) and gives a reward of the form 0.1 * format_reward_function + answer_reward_function.
  • Here, format_reward_function rewards that the completion is of the form
<think>Model reasoning</think>
<answer>Arithmetic expression</answer>
  • answer_reward_function rewards completions in which all given numbers are used exactly once (and no others) and the expression evaluates to the target value (a rough sketch of the reward structure follows below).
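
To make the reward structure concrete, here is a minimal, hypothetical sketch of what these functions compute; the actual implementations live in countdown.py (adapted from GRPO-Zero) and differ in their details, e.g. in how the prepended <think> tag is handled.

```python
# Hypothetical sketch of the reward shape described above; the real functions
# live in countdown.py (based on GRPO-Zero) and differ in detail.
import re

def format_reward_function(completion: str) -> float:
    # 1.0 if the completion contains a <think>...</think> block followed by an
    # <answer>...</answer> block, else 0.0.
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    return 1.0 if re.search(pattern, completion, re.DOTALL) else 0.0

def answer_reward_function(completion: str, nums: list[int], target: int) -> float:
    # 1.0 if the answer expression uses every given number exactly once and
    # evaluates to the target, else 0.0.
    match = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    if match is None:
        return 0.0
    expr = match.group(1).strip()
    if not re.fullmatch(r"[\d+\-*/(). ]+", expr):
        return 0.0
    if sorted(int(tok) for tok in re.findall(r"\d+", expr)) != sorted(nums):
        return 0.0
    try:
        value = eval(expr)  # expr only contains digits, operators, and parentheses
    except Exception:
        return 0.0
    return 1.0 if abs(value - target) < 1e-6 else 0.0

def reward_function(completion: str, nums: list[int], target: int) -> float:
    return 0.1 * format_reward_function(completion) + answer_reward_function(completion, nums, target)
```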

High-level flow

We use DDP for multiple GPUs on a single node, i.e., multiple processes run the same code but use different devices and occasionally synchronize. We refer to each process as a "rank", numbered consecutively (if you run on a machine with 4 GPUs, you will have 4 ranks: rank 0, rank 1, rank 2, rank 3). This is just to fix terminology; if you're unfamiliar with DDP, check out PyTorch's tutorial.

Concretely, we start by loading the model, the tokenizer, and the dataset. We then map every item in the dataset to the prompt template. The dataset is split between the ranks for training and validation (so that not every process uses the same data). The rest of the script is a loop over the training steps, where at each step each rank takes a few items from the dataset, samples completions for them, and uses them to update the model; we'll get into the details of that later. During evaluation, we do "best-of-n" sampling: for each item in the validation dataset, we sample $n$ completions and take the one with the highest reward. After a round of evaluation, we print the first three completions.
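
As a rough orientation, the per-rank flow looks something like this (a sketch with made-up names; the real code is in main.py and discussed in the walkthrough below):

```python
# Rough per-rank skeleton (hypothetical names, not the exact code in main.py).
for step in range(num_training_steps):
    if step % eval_every == 0:
        evaluate(model, val_slice)                            # best-of-n sampling, results gathered at rank 0

    prompts = take_items(train_slice, num_samples_per_step)   # a few prompts for this rank
    trajectories = sample_trajectories(model, prompts)        # G completions per prompt, plus rewards
    loss = compute_loss(model, trajectories)                  # GRPO or "sil" loss (see below)
    loss.backward()                                           # DDP averages gradients across ranks
    optimizer.step()
    optimizer.zero_grad()
```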

Sidenote: Since we only have one holdout dataset to evaluate the training performance every x steps, and because I'm lazy, both here and in the code I'm using the words validation/evaluation interchangeably.

Files

Apart from countdown.py, which is described above, there is

  • main.py: All training and generation happens in here, for easier understanding without needing to hop between files.
  • logger.py: Tiny logger that does two things. First, it's a wrapper around print, because we only want to print from rank 0. Second, it's a wrapper around logging to MLFlow (if you want). If you want to log to MLFlow – which is completely optional – create a .env file that contains the environment variables for MLFlow configuration (see the MLFlow docs), and set do_mlflow=True when constructing the logger: log = Logger(rank=rank, do_mlflow=True).
  • config.yaml: Config file, surprise. Every parameter is commented, I hope the explanations are clear. If not, please let me know.

How to run

To set up the project, assuming you're using uv, run

$ uv sync

But there's also a plain requirements.txt file, if preferred.

We run main.py with torchrun to kick off DDP. Using uv:

$ OMP_NUM_THREADS=$NUM_THREADS uv run torchrun --nproc-per-node=$NUM_PROCS main.py

Here:

  • $NUM_PROCS is the number of ranks you want to spawn. Note that each rank (starting from zero) will use the device cuda:{rank}. If I have two GPUs (cuda:0 and cuda:1), I'll set --nproc-per-node=2.
  • $NUM_THREADS is a fitting value for your hardware. I usually simply set it to the number of available CPU threads divided by the number of processes (--nproc-per-node). So if I have 64 threads and 2 GPUs, I use a value of 32.

And the training begins! (main.py does start initially with a round of validation, though, to get a baseline).

Algorithm

We use the GRPO algorithm with a few modifications from Dr. GRPO ("GRPO done right"). These are the changes that we use:

  • No KL-divergence term. In particular, this means that we don't need to keep a reference policy, which lifts some computational burden.
  • Advantages are not scaled with the inverse standard deviation.
  • The loss across a batch of sequences is divided by the size of the batch times the maximum completion length.
  • Since our implementation follows "sample completions for a batch -> one training step using all completions -> repeat with a new batch", we don't do multiple updates per batch of generations. Thus, we don't need importance sampling weights / clipping, and we end up with an advantage-weighted REINFORCE update.

The above points will be clear in the loss formula below.

GRPO data generation and objective

For each training step, each rank does the following:

  1. Take $N$ questions/prompts $q_1, \dots, q_N$ from the training dataset.

  2. For each question $q_i$, we sample $G$ completions/answers $a_{i,1}, \dots, a_{i,G}$. Here, $G$ is the group size, and since we will train on all generated completions, note that the full batch size $B$ for the training step will be $B = G \cdot N$.

    Sidenote: In config.yaml, $N$ is set via training.num_samples_per_step and $G$ is set by passing the tuple $[G, G]$ to training.k_best_of_n (we generalize GRPO to $k$-best-of-$n$ sampling, more on that later).

  3. For each completion $a_{i,j}$, we compute its real-valued reward $r_{i,j}$, and its advantage $A_{i,j}$ based on the reward for each group:

    $$\displaystyle A_{i,j} := r_{i,j} - \frac{1}{G}\sum_{k=1}^{G}r_{i,k}.$$

  4. Let $a_{i,j}[t]$ be the $t$-th token of answer $a_{i,j}$ with length $|a_{i,j}|$. Let $L$ be the maximum completion length (given in config.yaml by training.max_new_gen_tokens). Then the training loss for the batch is given by:

    $$\displaystyle \mathcal{L}(\theta) = - \frac{1}{B \cdot L}\sum_{i=1}^{N}{\sum_{j=1}^{G}\sum_{t=1}^{|a_{i,j}|}{A_{i,j}\cdot \log\ \pi_\theta\left( a_{i,j}[t]\ |\ q_i, a_{i,j}[<t] \right)}},$$

    where $a_{i,j}[<t]$ denotes the completion up to (excluding) the $t$-th token (see the code sketch after this list).

    Sidenote: Giving credit where credit's due, the above loss is actually a form of self-critical REINFORCE that was introduced in a different optimization context at NeurIPS 2020 in the POMO paper.
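
For concreteness, here is a minimal sketch of the advantage computation and this loss, assuming we already have per-token log-probabilities of the sampled completion tokens and a padding mask; the tensor names are made up, and the actual implementation is in main.py.

```python
# Minimal sketch of the Dr.-GRPO-style loss above (made-up tensor names, not
# the exact code in main.py). Assumes completions of the same question are
# stored contiguously, so rewards can be reshaped into groups of size G.
import torch

def grpo_loss(logps: torch.Tensor,     # (N*G, L_max) per-token log-probs of completion tokens
              mask: torch.Tensor,      # (N*G, L_max) 1 for real completion tokens, 0 for padding
              rewards: torch.Tensor,   # (N*G,) reward r_{i,j} per completion
              group_size: int,         # G, completions per question
              max_new_tokens: int      # L, maximum completion length
              ) -> torch.Tensor:
    # Advantage: reward minus the mean reward of its group (no std scaling).
    r = rewards.view(-1, group_size)                       # (N, G)
    adv = (r - r.mean(dim=1, keepdim=True)).view(-1, 1)    # (N*G, 1)

    # Advantage-weighted sum of log-probs over all completion tokens,
    # normalized by batch size times max completion length (B * L).
    batch_size = logps.shape[0]
    return -(adv * logps * mask).sum() / (batch_size * max_new_tokens)
```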

During evaluation, we sample for each question a set of answers (in the config, given by validation.best_of_n) and take as final answer the one with the highest reward.
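
A hypothetical helper for this selection could look like this (the names are made up):

```python
# Hypothetical best-of-n selection used during evaluation: for each question,
# keep the sampled answer with the highest reward.
def best_of_n(completions: list[str], rewards: list[float]) -> str:
    best_idx = max(range(len(rewards)), key=rewards.__getitem__)
    return completions[best_idx]
```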

Alternative training: k-best-of-n imitation

Another possibility to let the model self-improve over time is to fine-tune it only on the best answers among the sampled ones. That is, instead of taking all $G$ answers per question for training (weighted by their advantage), we only take the top $k$ answers ranked by their reward (e.g., $k=1$ or $2$) and fine-tune the model on these high-ranking answers. In the loss above, this simply means that we only sum over the top $k$ answers (not all $G$) and set the advantage for each answer to $A_{i,j} := 1$; with that, the loss above essentially becomes a cross-entropy loss on the highest-ranking samples.

This training can be enabled in config.yaml by setting training.learning_type: "sil" instead of "grpo" ("sil" stands for "self-improvement learning", as this is sometimes referred to). Also, set training.k_best_of_n to a tuple [k,n], meaning that we finetune on the best $k$ answers over $n$ samples per question.

(This is also the reason why we set training.k_best_of_n: [G,G] for GRPO, since in GRPO we consider all $G$ of $G$ answers.)
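
A minimal sketch of this top-$k$ selection, again assuming the $n$ sampled answers of each question are stored contiguously (made-up names, not the exact code):

```python
# Hypothetical sketch of the k-best-of-n ("sil") selection described above.
import torch

def select_top_k(rewards: torch.Tensor, k: int, n: int) -> torch.Tensor:
    # rewards: (N*n,) rewards for the n sampled answers per question, grouped
    # contiguously. Returns flat indices of the k highest-reward answers of
    # each group; these are then trained on with a fixed advantage of 1,
    # which turns the GRPO loss into a cross-entropy loss on those samples.
    r = rewards.view(-1, n)                               # (N, n)
    topk = r.topk(k, dim=1).indices                       # (N, k) indices within each group
    offsets = torch.arange(r.shape[0]).unsqueeze(1) * n   # start index of each group
    return (topk + offsets).flatten()                     # flat indices into (N*n,)
```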

Code walkthrough

Sometimes when I read through code, I'm really grateful when there's a rough overview of what is happening where and why. In this section, I want to do the same thing. Luckily, I only have to go over main.py, since we already covered the other files (config, logger, and countdown) above, and everything else happens in this one file.

  1. The main method builds everything up and then calls run_training. Note that we wrap the model in DDP so that gradients are synchronized during training. In particular, in main, the dataset is sliced depending on the rank and mapped to the prompt template. We often call dist.barrier() to synchronize the processes.
  2. We have a dataclass Trajectory which we use to represent a prompt/question together with a completion/answer.
  3. In run_training, we iterate over the training steps, starting off with a round of evaluation (to get a baseline performance for the model). I'll talk about the evaluate method that's called here in a second. If we get a model with a better accuracy on the validation set, the model checkpoint is saved. We don't actually do anything with the saved checkpoint; we just save it, so there's no continue-training functionality.
  4. Let's talk about the evaluate method. Here, we call sample_trajectories (the heart of the code) to get best-of-$n$ answers and their rewards for the questions. Since every rank has its own slice of the validation dataset, we collect all results at rank 0 only, which calls dist.recv_object_list for every other rank, while the other ranks send their results via dist.send_object_list.
  5. The method sample_trajectories does the main work, I'd say, by sampling completions for a set of questions. One thing to note: for sampling, we call the .generate convenience method provided by the transformers library. Since most models auto-loaded from Hugging Face come with a predefined generation config, we build a fresh generation config with only the things we actually need in the method get_generation_config. For example, it took stupid me three hours to realize that the default sampling temperature in the Qwen models is set to 0.7, and I was wondering why the model stops learning so quickly.
  6. Back to the main training loop. While the wrangling of the tensors and their shapes should be commented in the file and hopefully understandable, one important thing: while each rank has its own part of the batch, within a rank we also accumulate gradients over multiple micro-batches. This is why we call model.no_sync() for all but the last micro-batch, to keep DDP from all-reducing the gradients across ranks (see the sketch after this list). Also note that each rank only needs to take care of averaging across its own micro-batches, since DDP averages the gradients across ranks.
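
The gradient-accumulation pattern from point 6 looks roughly like this (a sketch with made-up names, not the exact code in main.py):

```python
# Sketch of DDP gradient accumulation with model.no_sync() (made-up names).
import contextlib

optimizer.zero_grad()
for i, micro_batch in enumerate(micro_batches):
    # Suppress DDP's gradient all-reduce for all but the last micro-batch.
    last = (i == len(micro_batches) - 1)
    ctx = contextlib.nullcontext() if last else model.no_sync()
    with ctx:
        loss = compute_loss(model, micro_batch) / len(micro_batches)
        loss.backward()   # on the last micro-batch, DDP all-reduces across ranks
optimizer.step()
```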

Results

I made some runs on a machine with 4 H100s (so 4 ranks), but you don't need a similar setup.

The Qwen2.5-3B base model starts at around 50% accuracy with best-of-$n$ sampling with $n=16$. We reach around 92% accuracy after ~200 training steps of GRPO, where at each step one rank considers 32 questions and samples 16 answers for each (so for all 4 ranks combined, this makes a batch size of $4 \cdot 32 \cdot 16 = 2048$). We evaluate every 10 steps, using best-of-$n$ sampling with $n=16$, and the total size of the validation set is 128. The maximum completion length is 1024 tokens. The training time for 200 steps was ~11 hours with this setup.

With the same setup, but k-best-of-n finetuning / self-improvement ("sil") with $k = 2$, $n = 32$ and 64 questions per rank and step (so the batch size is $4 \cdot 2 \cdot 64 = 512$), we reach around the same performance (minimally better) in around 300 steps, but with a higher sampling burden.

Below are two accuracy plots of these runs, with evaluation every ten training steps:

Reasoning traces

The better the model gets, the more the reasoning traces turn into plain trial and error. But I mean, it's only doing this Countdown thing, and it's getting good at what it's doing, so who am I to judge. Here are two sample outputs from the 3B model:

Numbers = [72, 27, 36], target = 29

<think>We need to find an expression using 72, 27, and 36 that equals 29.
Let's try different combinations of the operations:
1. (72 - 36) - 27 = 36 - 27 = 9 (Not 29)
2. (72 - 36) + 27 = 36 + 27 = 63 (Not 29)
3. (72 + 36) - 27 = 108 - 27 = 81 (Not 29)
4. 72 / (36 - 27) = 72 / 9 = 8 (Not 29)
5. (36 / 27) + 72 = (4/3) + 72 = 1.33 + 72 = 73.33 (Not 29)
6. 72 - (36 / 27) = 72 - 4/3 = 72 - 1.333... = 70.67 (Not 29)
7. (27 + 36) - 72 = 63 - 72 = -9 (Not 29)
8. 72 / 36 + 27 = 2 + 27 = 29
</think>
<answer> 72 / 36 + 27 </answer><|im_end|>

Numbers = [48, 74, 90, 67], target = 64

<think>We need to find an expression using 48, 74, 90, and 67 that equals 64.
Let's try different combinations of the operations:
(90 - 74) + 48 + 67 = 16 + 48 + 67 = 64 + 67 = 131 (Not 64)
(90 - 74) - (67 - 48) = 16 - 19 = -3 (Not 64)
(67 - 48) + (90 - 74) = 19 + 16 = 35 (Not 64)
(90 + 74) - (67 + 48) = 164 - 115 = 49 (Not 64)
(48 * 74) / (90 - 67) = 3672 / 23 = 159 (Not 64)
(90 - 67) * (74 - 48) = 23 * 26 = 599 (Not 64)
(74 - 67) * (90 - 48) = 7 * 42 = 294 (Not 64)
(90 + 48) - (74 + 67) = 138 - 141 = -3 (Not 64)
48 + (90 / (74 - 67)) = 48 + (90 / 7) = 48 + 12.857... (Not 64)
(90 / (74 - 48)) + 67 = (90 / 26) + 67 = 3.461... + 67 (Not 64)
(48 + 74) - (90 / 67) = 122 - 1.342... (Not 64)
48 * (90 / (74 + 67)) = 48 * (90 / 141) = 48 * (30 / 47) = 360 / 47 (Not 64)
Let's try: (74 - 90) + (48 - 67) = -16 + (-19) = -35 (Not 64)
(90 - 74) + (67 - 48) = 16 + 19 = 35 (Not 64)
(90 + 67) - (74 - 48) = 157 - 26 = 131 (Not 64)
48 + 74 - (90 / 67) = 122 - 1.342... (Not 64)
(74 - 67) + (90 / 48) = 7 + 1.875 = 8.875 (Not 64)
(67 - 48) + (90 / 74) = 19 + 1.205... (Not 64)
(90 - 67) - (74 - 48) = 23 - 26 = -3 (Not 64)
48 / (74 - 90) + 67 = 48 / (-16) + 67 = -3 + 67 = 64
</think>
<answer> 48 / (74 - 90) + 67 </answer><|im_end|>
