Fix grpo eval when gas > 1 #4057

Merged: 2 commits merged on May 1, 2025
33 changes: 1 addition & 32 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -959,37 +959,6 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions):

return inputs, rewards, rewards_per_func, completions

def _encode_and_prepare_inputs(self, batch):
"""Process input batch into model-ready format with gradient accumulation support.

Args:
batch: Input data batch with shape [gas*bs, ...], where gas is gradient
accumulation steps and bs is batch size

Returns:
List of encoded inputs with shape [gas, bs, ...] ready for model forward pass
"""
template = self.template
ga_batch_encoded_inputs = []

with self._template_context(template):
gas = self.args.gradient_accumulation_steps
mode = 'train' if self.model.training else 'eval'
bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size
for i in range(gas):
start_idx = i * bs
end_idx = (i + 1) * bs
batch_encoded_inputs = [template.encode(infer_request) for infer_request in batch[start_idx:end_idx]]

batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device)
labels = batch_encoded_inputs.pop('labels')
last_non_padding = torch.ne(labels, -100).int().argmax(-1)
logits_to_keep = (labels.shape[-1] - last_non_padding).max().item()
batch_encoded_inputs.update({'completion_mask': labels[:, -logits_to_keep:] != -100})

ga_batch_encoded_inputs.append(batch_encoded_inputs)
return ga_batch_encoded_inputs
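The least obvious part of the removed `_encode_and_prepare_inputs` is how `logits_to_keep` and `completion_mask` are derived from the `-100` label padding. Below is a minimal, self-contained sketch of that derivation on a toy tensor; the tensor values are illustrative, not taken from the trainer.

```python
import torch

# Toy labels: -100 marks prompt/padding positions, other values are completion tokens.
labels = torch.tensor([
    [-100, -100, 11, 12, 13],
    [-100,   14, 15, 16, -100],
])

# Index of the first non -100 token per row (argmax over the boolean mask
# returns the first True position).
first_completion = torch.ne(labels, -100).int().argmax(-1)

# Keep enough trailing positions to cover the longest completion in the batch.
logits_to_keep = (labels.shape[-1] - first_completion).max().item()

# True where a real completion token sits inside the kept window.
completion_mask = labels[:, -logits_to_keep:] != -100

print(logits_to_keep)   # 4
print(completion_mask)  # [[False, True, True, True], [True, True, True, False]]
```

Because `argmax` returns the first non-padding index per row, taking the max over rows keeps a trailing window wide enough for the longest completion in the batch, and the mask then flags which positions in that window hold real completion tokens.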

def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> List[InputsType]:
"""
Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training.
@@ -1021,7 +990,7 @@ def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> List[InputsType]:

mode = 'train' if self.model.training else 'eval'
bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size
- gas = self.args.gradient_accumulation_steps
+ gas = self.args.gradient_accumulation_steps if mode == 'train' else 1

assert len(inputs) == bs * gas, f'Expected {bs * gas} inputs, got {len(inputs)}'
gas_chunks = [inputs[i * bs:(i + 1) * bs] for i in range(gas)]
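
For context on the one-line fix above: during evaluation there is no gradient accumulation, so only `per_device_eval_batch_size` requests arrive, and chunking them with the training `gradient_accumulation_steps` made the `len(inputs) == bs * gas` assertion fail whenever gas > 1. A small sketch of the corrected chunking, using hypothetical batch sizes that are not taken from the PR:

```python
# Hypothetical sizes for illustration only; not taken from the PR.
per_device_train_batch_size = 4
per_device_eval_batch_size = 8
gradient_accumulation_steps = 2

def chunk_inputs(inputs, mode):
    bs = per_device_train_batch_size if mode == 'train' else per_device_eval_batch_size
    # The fix: gradient accumulation only applies while training.
    gas = gradient_accumulation_steps if mode == 'train' else 1
    assert len(inputs) == bs * gas, f'Expected {bs * gas} inputs, got {len(inputs)}'
    return [inputs[i * bs:(i + 1) * bs] for i in range(gas)]

# Train: 8 requests split into 2 accumulation chunks of 4.
print([len(c) for c in chunk_inputs(list(range(8)), 'train')])  # [4, 4]
# Eval: 8 requests stay in one chunk; with the old expression gas would be 2,
# the assert would expect 16 inputs, and evaluation would fail.
print([len(c) for c in chunk_inputs(list(range(8)), 'eval')])   # [8]
```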