diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 4c03bfa74f..eb03742f39 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -1008,7 +1008,7 @@ def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> Li batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device) # Process labels and masks - labels = batch_encoded_inputs['labels'] + labels = batch_encoded_inputs.pop('labels') logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() batch_encoded_inputs.update({ 'completion_mask': @@ -1164,7 +1164,10 @@ def _get_per_token_logps(self, model, inputs): logits_to_keep = inputs['logits_to_keep'] input_ids = inputs['input_ids'] unwrapped_model = self.accelerator.unwrap_model(model) - parameters = inspect.signature(unwrapped_model.forward).parameters + if is_peft_model(unwrapped_model): + parameters = inspect.signature(unwrapped_model.base_model.model.forward).parameters + else: + parameters = inspect.signature(unwrapped_model.forward).parameters if not unwrapped_model.model_meta.is_multimodal and 'logits_to_keep' in parameters: # save memory return super()._get_per_token_logps(model, input_ids, inputs['attention_mask'], logits_to_keep)