
Update GRPOTrainer to be compatible with trl 0.17 #3969

Merged · 33 commits · Apr 30, 2025

Changes from 1 commit
update
hjh0119 committed Apr 27, 2025
commit bec8015d68f3ddc7741c21f67b326511eea7c419
4 changes: 2 additions & 2 deletions docs/source/Instruction/GRPO.md
@@ -6,7 +6,7 @@

Environment Setup
```bash
-pip install math_verify # reward function
+pip install math_verify==0.5.2 # reward function
pip install -U trl
```

@@ -137,7 +137,7 @@ A conversation between User and Assistant. The user asks a question, and the Ass

## Arguments and Execution Script
Arguments
-- num_generations: The number of samples per prompt, the G value in the paper; it needs to be divisible by per_device_batch_size * nproc_per_node
+- num_generations: The number of samples per prompt, the G value in the paper; it needs to be divisible by per_device_batch_size * gradient_accumulation_steps * nproc_per_node. Default is 8 (see the sketch after this list)
- max_completion_length: The maximum length of sampled generations. Default is 512
- ds3_gather_for_generation: This parameter applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, improving generation speed. However, disabling it allows training models that exceed single-GPU VRAM, albeit with slower generation. Disabling it is not compatible with vLLM generation. Default is True
- reward_funcs: Reward functions that score the model's generations. Four rule-based functions are built in: accuracy, format, cosine, and repetition; see swift/plugin/orm.py for details
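To make the divisibility rule above concrete, here is a minimal sketch with made-up numbers; it only illustrates how the quantities relate, and the actual check is performed inside the trainer:

```python
# Illustrative numbers only; they are not defaults from the docs above
# (except num_generations = 8).
per_device_train_batch_size = 4      # completions per device per forward pass
gradient_accumulation_steps = 2
nproc_per_node = 2                   # number of training processes (GPUs)
num_generations = 8                  # G in the GRPO paper

# Completions contributing to one optimizer step across all devices:
effective_completions = per_device_train_batch_size * gradient_accumulation_steps * nproc_per_node
# 4 * 2 * 2 = 16 completions -> 16 / 8 = 2 complete prompt groups of G samples,
# so no prompt's group of generations is split across optimizer steps.
prompt_groups_per_step = effective_completions // num_generations
print(effective_completions, prompt_groups_per_step)  # 16 2
```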
4 changes: 3 additions & 1 deletion docs/source/Instruction/命令行参数.md
@@ -383,7 +383,9 @@ The reward model arguments will be used in PPO and GRPO.


#### GRPO Arguments
-- num_generations: The G value in the GRPO algorithm. Default is 8
+- per_device_train_batch_size: The training batch size per device; in GRPO, this refers to the batch size of completions.
+- per_device_eval_batch_size: The evaluation batch size per device; in GRPO, this refers to the batch size of completions.
+- num_generations: The number of samples per prompt, the G value in the paper; it needs to be divisible by per_device_batch_size * gradient_accumulation_steps * nproc_per_node. Default is 8
- max_completion_length: The maximum generation length in the GRPO algorithm. Default is 512
- ds3_gather_for_generation: This parameter applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, improving generation speed. However, disabling it allows training models that exceed single-GPU VRAM, albeit with slower generation. Disabling it is not compatible with vLLM generation. Default is True
- reward_funcs: Reward functions for the GRPO algorithm; options are `accuracy`, `format`, `cosine` and `repetition`, see swift/plugin/orm.py. You can also customize your own reward functions in the plugin. Default is `[]`
2 changes: 1 addition & 1 deletion docs/source_en/Instruction/Command-line-parameters.md
@@ -395,7 +395,7 @@ The meanings of the following parameters can be referenced [here](https://huggin


#### GRPO Arguments
-- num_generations: The G value in the GRPO algorithm, default is 8.
+- num_generations: The number of samples for each prompt, referred to as the G value in the paper; it needs to be divisible by per_device_batch_size * gradient_accumulation_steps * nproc_per_node. Default is 8.
- max_completion_length: The maximum generation length in the GRPO algorithm, default is 512.
- ds3_gather_for_generation: This parameter applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, improving generation speed. However, disabling this option allows training models that exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible with vLLM generation. The default is True.
- reward_funcs: Reward functions in the GRPO algorithm; options include `accuracy`, `format`, `cosine` and `repetition`, as seen in `swift/plugin/orm.py`. You can also customize your own reward functions in the plugin (a sketch follows below). Default is `[]`.
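A hypothetical custom reward function, sketched under the assumption that swift/plugin/orm.py exposes an `ORM` base class and an `orms` registry dict as the built-in rewards suggest; check the file in your installed version before relying on these names:

```python
from typing import List

# Assumed interface: ORM base class and `orms` registry in swift/plugin/orm.py.
from swift.plugin.orm import ORM, orms


class LengthPenaltyORM(ORM):
    """Toy reward: prefer completions that stay under a character budget."""

    def __call__(self, completions: List[str], **kwargs) -> List[float]:
        # Return one score per completion; higher is better.
        return [1.0 if len(c) <= 1024 else 0.0 for c in completions]


# Register under a name that could then be selected via reward_funcs.
orms['length_penalty'] = LengthPenaltyORM
```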
2 changes: 1 addition & 1 deletion docs/source_en/Instruction/GRPO.md
@@ -142,7 +142,7 @@ In addition to rule-based reward functions, this framework also supports using r
## Arguments and Execution Script
Arguments

-- num_generations: The number of samples for each prompt, referred to as the G value in the paper; it needs to be divisible by per_device_batch_size * nproc_per_node.
+- num_generations: The number of samples for each prompt, referred to as the G value in the paper; it needs to be divisible by per_device_batch_size * gradient_accumulation_steps * nproc_per_node. Default is 8.
- max_completion_length: The maximum length for sampling generation, default is 512.
- ds3_gather_for_generation: This parameter applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, improving generation speed. However, disabling this option allows training models that exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible with vLLM generation. The default is True.
- reward_funcs: Reward functions to score the results generated by the model. Includes built-in accuracy, format, cosine and repetition rule-based functions, detailed in the swift/plugin/orm.py file.
9 changes: 5 additions & 4 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -949,14 +949,15 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions):

        return inputs, rewards, rewards_per_func, completions

-    def _encode_and_prepare_inputs(self, batch):
+    def _encode_and_prepare_inputs(self, batch, logits_to_keep=None):
        template = self.template
        with self._template_context(template):
            batch_encoded_inputs = [template.encode(infer_request) for infer_request in batch]
            batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device)

        labels = batch_encoded_inputs.pop('labels')
-        logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item()
+        if logits_to_keep is None:
+            logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item()
        batch_encoded_inputs['logits_to_keep'] = logits_to_keep
        batch_encoded_inputs['completion_mask'] = labels[:, -logits_to_keep:] != -100
        return batch_encoded_inputs
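As a reading aid for the `logits_to_keep` expression above, here is a small self-contained sketch with a toy label tensor; the values are illustrative, but the -100 masking of prompt tokens follows the code in this hunk:

```python
import torch

# Toy labels: prompt positions are masked with -100, completion tokens keep real ids.
labels = torch.tensor([
    [-100, -100, -100, 11, 12, 13],   # prompt length 3, completion length 3
    [-100, -100, 21, 22, 23, 24],     # prompt length 2, completion length 4
])

# Index of the first non -100 entry per row, i.e. each sample's prompt length.
prompt_len = torch.ne(labels, -100).int().argmax(-1)           # tensor([3, 2])
# Keep enough trailing logits to cover the longest completion in the batch.
logits_to_keep = (labels.shape[-1] - prompt_len).max().item()  # 4
completion_mask = labels[:, -logits_to_keep:] != -100
# Row 0 -> [False, True, True, True]; row 1 -> [True, True, True, True]
```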
@@ -1148,10 +1149,10 @@ def _get_per_token_logps(self, model, encoded_inputs, raw_inputs, batch_size=Non
        batch_size = batch_size or input_ids.size(0)
        effective_batch_size = input_ids.size(0)
        all_logps = []
+        logits_to_keep = encoded_inputs['logits_to_keep']
        for i in range(0, effective_batch_size, batch_size):
            raw_inputs_batch = raw_inputs[i:i + batch_size]
-            encoded_inputs_batch = self._encode_and_prepare_inputs(raw_inputs_batch)
-            logits_to_keep = encoded_inputs_batch['logits_to_keep']
+            encoded_inputs_batch = self._encode_and_prepare_inputs(raw_inputs_batch, logits_to_keep=logits_to_keep)
            input_ids = encoded_inputs_batch['input_ids']
            inputs = {
                k: v
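One plausible reading of why `logits_to_keep` is now fixed before the loop: every micro-batch re-encode then produces per-token tensors with the same trailing width, so the chunks accumulated in `all_logps` can be combined without shape mismatches. A shape-only sketch (tensor contents are random placeholders):

```python
import torch

# With a shared logits_to_keep, per-token log-prob chunks from different
# micro-batches agree on the last dimension and can be concatenated.
logits_to_keep = 4
logps_chunk_1 = torch.randn(2, logits_to_keep)   # micro-batch of 2 completions
logps_chunk_2 = torch.randn(3, logits_to_keep)   # micro-batch of 3 completions
all_logps = torch.cat([logps_chunk_1, logps_chunk_2], dim=0)
print(all_logps.shape)  # torch.Size([5, 4])
```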