Group Relative Policy Optimization (GRPO) for the Countdown task, using basic transformers-library functionality and PyTorch's native DDP. Why? Let's start with some blah blah.
If you've stumbled upon this repository, you might wonder: aren't there enough other repositories doing the exact same thing, namely applying Group Relative Policy Optimization to the Countdown task? And you're definitely right: since the popular TinyZero project, plenty of other implementations and blog posts have appeared. However, many of them build on bigger "RL for LLM" frameworks such as veRL, use the convenience GRPO Trainer from Hugging Face (e.g., here), or use vLLM for generation with DeepSpeed wrapped around it (e.g., simple_GRPO).
I wanted to make an implementation that relies only on "standard" (I don't even know why I'm using this word) Hugging Face functionality (tokenizers, models, and sequence generation, but obviously not the convenience GRPO Trainer) and PyTorch's standard DistributedDataParallel (DDP) module to allow training the model on multiple GPUs.
The repository that comes closest to this is GRPO-Zero, and I like it a lot. It's a super clean implementation with minimal dependencies, and my code is loosely based on it, with some explicit one-to-one borrows (like the reward functions), so thank you, GRPO-Zero! The difference here is that we support multi-GPU training with DDP with (hopefully) minimal added complexity, and we allow ourselves to load and use the models with the transformers library (in contrast to GRPO-Zero, where the network is set up manually).
Finally, I want to give a detailed walkthrough of the code here. So consider this README a writeup that will hopefully be helpful to anyone getting into the topic who wants some guidance while reading through the code.
We will start with the general (file) structure, the dataset, what the code is doing on a high level, and how to run it. Then comes a tour of the steps in GRPO (with some modifications coming from Dr. GRPO), followed by a tour of the code in main.py to help with understanding. Finally, some results (spoiler: we solve the problem in around 90% of all cases with Qwen-2.5-3B).
The overall goal of the code is to take a pretrained LLM (like other repositories, we tried Qwen2.5 models, but feel free to use others) and let it learn with GRPO how to solve the Countdown task.
An instance of the Countdown task gives you 3 or 4 integers and a target integer, and the task is to build an arithmetic expression (using +, -, *, / and parentheses only) that evaluates to the target value. For example, for numbers = [48, 74, 90, 67] and target = 64, a solution would be 67 - 48 / (90 - 74). To get a list of these examples to train on, we use the dataset by Jiayi-Pan on Hugging Face, which consists of 490k examples of the form {'target': int, 'nums': list[int]}. Everything related to the Countdown task is found in the file countdown.py, which is based on GRPO-Zero's implementation:
- Since we don't want to feed the raw numbers and targets to the LLM, we embed items from the dataset within a prompt template. This is done in `generate_prompt`.
- The `reward_function` takes in a completion (an opening `<think>` tag is added before the completion) and gives a reward of the form `0.1 * format_reward_function + answer_reward_function`.
- Here, `format_reward_function` rewards that the completion is of the form

  ```
  <think>Model reasoning</think>
  <answer>Arithmetic expression</answer>
  ```

- `answer_reward_function` rewards if all given numbers are used exactly once (and nothing more) and the target value is reached.
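To make this concrete, here is a minimal sketch of how such a reward could be computed. This is not the repository's exact code (that follows GRPO-Zero); the helper names and regex details are illustrative assumptions.

```python
import re


def format_reward(completion: str) -> float:
    # The prompt already opens a <think> tag, so we prepend it before checking
    # that the completion looks like <think>...</think> followed by <answer>...</answer>.
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>"
    return 1.0 if re.match(pattern, "<think>" + completion, flags=re.DOTALL) else 0.0


def answer_reward(completion: str, nums: list[int], target: int) -> float:
    # Extract the expression, check that it uses exactly the given numbers, then evaluate it.
    match = re.search(r"<answer>(.*?)</answer>", completion, flags=re.DOTALL)
    if match is None:
        return 0.0
    expr = match.group(1).strip()
    if not re.fullmatch(r"[\d+\-*/() .]+", expr):  # only digits, operators, parentheses
        return 0.0
    if sorted(int(n) for n in re.findall(r"\d+", expr)) != sorted(nums):
        return 0.0
    try:
        value = eval(expr)  # acceptable here: the character whitelist above rules out names
    except Exception:
        return 0.0
    return 1.0 if abs(value - target) < 1e-6 else 0.0


def reward_function(completion: str, nums: list[int], target: int) -> float:
    return 0.1 * format_reward(completion) + answer_reward(completion, nums, target)
```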
We use DDP for multiple GPUs on a single node, i.e., multiple processes run the same code but use different devices and occasionally synchronize. We refer to each process as a "rank" with a consecutive number (if you run on a machine with 4 GPUs, you will have 4 ranks: rank 0, rank 1, rank 2, rank 3). This is just to fix terminology for later; if you're unfamiliar with DDP, check out PyTorch's tutorial.
This means that we start by loading the model, the tokenizer, and the dataset. We then map every item in the dataset to the prompt template. The dataset is distributed between the ranks for training and validation (so that not every process uses the same data). The rest of the script is just a loop over the training steps, where at each step each rank takes a few items from the dataset, samples completions for them, and uses them to update the model. We'll get into the details of that later. During evaluation, we do "best-of-n" sampling, that is, for each item in the validation dataset, we sample several completions and take the one with the highest reward as the final answer.
Sidenote: Since we only have one holdout dataset to evaluate the training performance every x steps, and because I'm lazy, both here and in the code I'm using the words validation/evaluation interchangeably.
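If the rank/device bookkeeping is new to you, this is roughly what the setup looks like in a torchrun-launched script. It's a sketch under assumed names, not the repository's exact code (the actual logic lives in main.py).

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun sets the environment variables that init_process_group (env:// init) reads.
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)

# Each rank keeps only its own slice of the data, e.g. with the datasets library:
# train_split = dataset.shard(num_shards=world_size, index=rank)

# Wrap the model so gradients are synchronized across ranks during backward:
# model = DDP(model.to(device), device_ids=[rank])
```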
Apart from countdown.py, which is described above, there are:

- `main.py`: All training and generation happens in here, for easier understanding without needing to hop between files.
- `logger.py`: Tiny logger that does two things. First, it's a wrapper around `print`, because we only want to print from rank 0. Second, it's a wrapper around logging to MLFlow (if you want). If you want to log to MLFlow – which is completely optional – create a `.env` file that contains the environment variables for the MLFlow configuration (see the MLFlow docs), and change the logger instantiation to `log = Logger(rank=rank, do_mlflow=True)`.
- `config.yaml`: Config file, surprise. Every parameter is commented; I hope the explanations are clear. If not, please let me know.
To set up the project, assuming you're using uv, run

```
$ uv sync
```

But there's also a plain requirements.txt file, if preferred.
We run main.py with torchrun to kick off DDP. Using uv:

```
$ OMP_NUM_THREADS=$NUM_THREADS uv run torchrun --nproc-per-node=$NUM_PROCS main.py
```

Here:

- `$NUM_PROCS` is the number of ranks you want to spawn. Note that each rank (starting from zero) will use the device `cuda:{rank}`. If I have two GPUs (`cuda:0` and `cuda:1`), I'll set `--nproc-per-node=2`.
- `$NUM_THREADS` is a fitting value for your hardware. I usually set it to the number of available CPU threads divided by the number of processes (`--nproc-per-node`). So if I have 64 threads and 2 GPUs, I use a value of 32.
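For example, on a machine with 2 GPUs and 64 CPU threads, the launch would be:

```
$ OMP_NUM_THREADS=32 uv run torchrun --nproc-per-node=2 main.py
```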
And the training begins! (main.py does start initially with a round of validation, though, to get a baseline).
We use the GRPO algorithm with a few modifications from Dr. GRPO ("GRPO done right"). These are the changes that we use:
- No KL-divergence term. In particular, this means that we don't need to keep a reference policy, which lifts some computational burden.
- Advantages are not scaled with the inverse standard deviation.
- The loss across a batch of sequences is divided by the size of the batch times the maximum completion length.
- Since our implementation is "sample completions for batch -> training step using all completions -> repeat for new batch", this means that we don't do multiple updates for a batch of generations. Thus, we don't need importance sampling weights / clipping, and we end up with an advantage weighted REINFORCE update.
The above points will be clear in the loss formula below.
For each training step, each rank does the following:
- Take $N$ questions/prompts $q_1, \dots, q_N$ from the training dataset.
- For each question $q_i$, we sample $G$ completions/answers $a_{i,1}, \dots, a_{i,G}$. Here, $G$ is the group size, and since we will train on all generated completions, note that the full batch size $B$ for the training step will be $B = G \cdot N$.

  Sidenote: In `config.yaml`, $N$ is set via `training.num_samples_per_step` and $G$ is set by passing the tuple $[G, G]$ to `training.k_best_of_n` (we generalize GRPO to $k$-best-of-$n$ sampling, more on that later).
- For each completion $a_{i,j}$, we compute its real-valued reward $r_{i,j}$ and its advantage $A_{i,j}$ based on the rewards of its group:

  $$A_{i,j} := r_{i,j} - \frac{1}{G}\sum_{k=1}^{G} r_{i,k}.$$

- Let $a_{i,j}[t]$ be the $t$-th token of answer $a_{i,j}$, which has length $|a_{i,j}|$. Let $L$ be the maximum completion length (given in `config.yaml` by `training.max_new_gen_tokens`). Then the training loss for the batch is given by

  $$\mathcal{L}(\theta) = - \frac{1}{B \cdot L}\sum_{i=1}^{N}\sum_{j=1}^{G}\sum_{t=1}^{|a_{i,j}|} A_{i,j}\cdot \log \pi_\theta\!\left( a_{i,j}[t] \mid q_i,\, a_{i,j}[<t] \right),$$

  where $a_{i,j}[<t]$ denotes the completion up to (excluding) the $t$-th token.

  Sidenote: Giving credit where credit's due, the above loss is actually a form of self-critical REINFORCE that was introduced in a different optimization context at NeurIPS 2020 in the POMO paper.
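To connect the formula to code, here is a minimal PyTorch sketch of the advantage and loss computation, assuming we already have per-token log-probabilities of the sampled completions and a padding mask; tensor names and shapes are illustrative, not the repository's exact ones.

```python
import torch

# Shapes (illustrative): rewards is (N, G); logprobs and completion_mask are (N * G, L),
# where L is the maximum completion length and padded positions are masked out.
def grpo_loss(rewards: torch.Tensor,
              logprobs: torch.Tensor,
              completion_mask: torch.Tensor) -> torch.Tensor:
    N, G = rewards.shape
    B, L = logprobs.shape  # B == N * G

    # Advantage: reward minus the mean reward of its group (no std scaling, per Dr. GRPO).
    advantages = rewards - rewards.mean(dim=1, keepdim=True)   # (N, G)
    advantages = advantages.reshape(B, 1)                      # broadcast over tokens

    # Advantage-weighted log-likelihood, summed over real (non-padded) tokens only.
    per_token = advantages * logprobs * completion_mask        # (B, L)

    # Normalize by batch size times max completion length (not per-sequence length).
    return -per_token.sum() / (B * L)
```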
During evaluation, we sample for each question a set of answers (in the config, given by validation.best_of_n) and take as final answer the one with the highest reward.
Another possibility to let the model self-improve over time is to fine-tune it only on the best answers among the sampled ones. That is, instead of taking all sampled completions into the update, we keep only the highest-reward completions per question and fine-tune on those.
This training can be enabled in config.yaml by setting training.learning_type: "sil" instead of "grpo" ("sil" stands for "self-improvement learning", as this is sometimes referred to). Also, set training.k_best_of_n to a tuple [k,n], meaning that we finetune on the best k out of n sampled completions per question.
(This is also the reason why we set training.k_best_of_n: [G,G] for GRPO, since in GRPO we consider all G sampled completions of a group.)
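Here is a minimal sketch of what the k-best-of-n selection could look like; the function name and data layout are assumptions for illustration. With k = n (i.e., [G,G]) everything is kept, which recovers the GRPO case, and k = 1 corresponds to the best-of-n selection used during evaluation.

```python
def select_k_best(completions: list[str], rewards: list[float], k: int) -> list[str]:
    """Keep the k completions with the highest reward for one question."""
    order = sorted(range(len(completions)), key=lambda i: rewards[i], reverse=True)
    return [completions[i] for i in order[:k]]

# Example: with k_best_of_n = [2, 8] we would sample 8 completions per question
# and fine-tune only on the 2 highest-reward ones.
```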
Sometimes when I read through code, I'm really grateful when there's a rough overview of what is happening where and why. In this section, I want to provide exactly that. Luckily, I only need to go over main.py, since we talked about the other files (config, logger, and countdown) before, and everything else happens in this one file.
- The `main` method builds up everything and then calls `run_training`. Note that we wrap the model in a `DDP` call, such that gradients are synchronized during training. In particular, in `main`, the dataset is sliced depending on the rank and mapped to the prompt template. We often call `dist.barrier()` to synchronize the processes.
- We have a dataclass `Trajectory` which we use to represent a prompt/question together with a completion/answer.
- In `run_training`, we iterate over the training steps and start off with a round of evaluation (to get a baseline performance for the model). I'll talk about the `evaluate` method that's called there in a second. If we get a model with a better accuracy on the validation set, the model checkpoint is saved. We don't actually do anything with the saved checkpoint, we just save it, so there's no continue-training functionality.
- Let's talk about the `evaluate` method. Here, we call the heart of the code, `sample_trajectories`, to get best-of-$n$ answers and their rewards for the questions. Since every rank has its own slice of the validation dataset, we collect all results at rank 0 only, which calls `dist.recv_object_list` for all other ranks. Meanwhile, the other ranks send their results via `dist.send_object_list`.
- The method `sample_trajectories` does the main work, I'd say, by sampling completions for a set of questions. One thing to note: for sampling, we call the `.generate` convenience method provided by the transformers library. Since most models auto-loaded from Hugging Face come with a predefined generation config, we make sure to build a fresh generation config with only the things we really need in the method `get_generation_config`. For example, it took stupid me three hours to realize that the default sampling temperature in the Qwen models is set to 0.7, and I was wondering why the model stops learning so quickly.
- Back to the main training loop. While the wrangling of the tensors and their shapes should be commented in the file and hopefully understandable, one important thing: while each rank has its own part of the batch, within a rank we also accumulate gradients over multiple micro-batches. This is why we call `model.no_sync()` for all but the last micro-batch, to keep DDP from all-reducing the gradients on every micro-batch. Also note that each rank only needs to take care of averaging across its own micro-batches, since DDP averages the gradients across ranks.
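Here is a minimal sketch of that gradient accumulation pattern with DDP's `no_sync()`; the variable names (`micro_batches`, `compute_loss`) are placeholders, not the repository's actual ones.

```python
import contextlib

def train_step(model, micro_batches, optimizer):
    """model is the DDP-wrapped module; micro_batches is this rank's list of micro-batches."""
    optimizer.zero_grad()
    num_micro = len(micro_batches)
    for i, micro_batch in enumerate(micro_batches):
        # Skip the gradient all-reduce on all but the last micro-batch.
        ctx = model.no_sync() if i < num_micro - 1 else contextlib.nullcontext()
        with ctx:
            loss = compute_loss(model, micro_batch)   # placeholder for the actual loss computation
            (loss / num_micro).backward()             # average over this rank's micro-batches
    # The last backward (outside no_sync) triggers DDP's all-reduce, averaging across ranks.
    optimizer.step()
```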
I made some runs on a machine with 4 H100s (so 4 ranks), but you don't need a similar setup.
The Qwen2.5-3B base model starts at around 50% accuracy with best-of-n sampling and improves steadily over GRPO training, eventually solving around 90% of the validation instances (the spoiler from the beginning).
With the same setup, but k-best-of-n finetuning / self-improvement ("sil") instead of GRPO, the model also improves over training; see the plots below for the comparison.
Below are two accuracy plots of these runs, with evaluation every ten training steps:
The better the model gets, the more its reasoning traces turn into plain trial-and-error. But I mean, it's only doing this countdown thing, and it's getting good at what it's doing, so who am I to judge. Here are two sample outputs from the 3B model:
Numbers = [72, 27, 36], target = 29
<think>We need to find an expression using 72, 27, and 36 that equals 29.
Let's try different combinations of the operations:
1. (72 - 36) - 27 = 36 - 27 = 9 (Not 29)
2. (72 - 36) + 27 = 36 + 27 = 63 (Not 29)
3. (72 + 36) - 27 = 108 - 27 = 81 (Not 29)
4. 72 / (36 - 27) = 72 / 9 = 8 (Not 29)
5. (36 / 27) + 72 = (4/3) + 72 = 1.33 + 72 = 73.33 (Not 29)
6. 72 - (36 / 27) = 72 - 4/3 = 72 - 1.333... = 70.67 (Not 29)
7. (27 + 36) - 72 = 63 - 72 = -9 (Not 29)
8. 72 / 36 + 27 = 2 + 27 = 29
</think>
<answer> 72 / 36 + 27 </answer><|im_end|>

Numbers = [48, 74, 90, 67], target = 64
<think>We need to find an expression using 48, 74, 90, and 67 that equals 64.
Let's try different combinations of the operations:
(90 - 74) + 48 + 67 = 16 + 48 + 67 = 64 + 67 = 131 (Not 64)
(90 - 74) - (67 - 48) = 16 - 19 = -3 (Not 64)
(67 - 48) + (90 - 74) = 19 + 16 = 35 (Not 64)
(90 + 74) - (67 + 48) = 164 - 115 = 49 (Not 64)
(48 * 74) / (90 - 67) = 3672 / 23 = 159 (Not 64)
(90 - 67) * (74 - 48) = 23 * 26 = 599 (Not 64)
(74 - 67) * (90 - 48) = 7 * 42 = 294 (Not 64)
(90 + 48) - (74 + 67) = 138 - 141 = -3 (Not 64)
48 + (90 / (74 - 67)) = 48 + (90 / 7) = 48 + 12.857... (Not 64)
(90 / (74 - 48)) + 67 = (90 / 26) + 67 = 3.461... + 67 (Not 64)
(48 + 74) - (90 / 67) = 122 - 1.342... (Not 64)
48 * (90 / (74 + 67)) = 48 * (90 / 141) = 48 * (30 / 47) = 360 / 47 (Not 64)
Let's try: (74 - 90) + (48 - 67) = -16 + (-19) = -35 (Not 64)
(90 - 74) + (67 - 48) = 16 + 19 = 35 (Not 64)
(90 + 67) - (74 - 48) = 157 - 26 = 131 (Not 64)
48 + 74 - (90 / 67) = 122 - 1.342... (Not 64)
(74 - 67) + (90 / 48) = 7 + 1.875 = 8.875 (Not 64)
(67 - 48) + (90 / 74) = 19 + 1.205... (Not 64)
(90 - 67) - (74 - 48) = 23 - 26 = -3 (Not 64)
48 / (74 - 90) + 67 = 48 / (-16) + 67 = -3 + 67 = 64
</think>
<answer> 48 / (74 - 90) + 67 </answer><|im_end|>