[grpo] support gen rm #4151

Merged: 25 commits, May 11, 2025
Changes from 1 commit
update reward funcs prepare
hjh0119 committed May 9, 2025
commit 567a88b16271318c960bc0aff20955ab7d26fb7b
swift/trainers/rlhf_trainer/grpo_trainer.py: 6 additions & 3 deletions
```diff
@@ -30,6 +30,7 @@
 from transformers.trainer_utils import seed_worker
 from trl import GRPOTrainer as HFGRPOTrainer
 from trl.extras.profiling import profiling_decorator
+from trl.models import prepare_deepspeed
 from trl.trainer.grpo_trainer import nanmax, nanmin
 
 from swift.llm import InferRequest, MultiModelKeys, RequestConfig, RowPreprocessor, get_model_arch, to_device
```
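For reference, `prepare_deepspeed` is now imported at module level instead of locally from `trl.models.utils`. It wraps a frozen model as an eval-mode DeepSpeed engine and returns it, which is why the new code below assigns the result back. A minimal usage sketch; the accelerator setup and the checkpoint name are placeholders, not part of this diff:

```python
from accelerate import Accelerator
from transformers import AutoModelForSequenceClassification
from trl.models import prepare_deepspeed

accelerator = Accelerator()
# Placeholder checkpoint; any scalar-output reward model would do here.
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "my-org/my-reward-model", num_labels=1)

if accelerator.state.deepspeed_plugin is not None:
    # Returns the model wrapped as an eval-mode DeepSpeed engine;
    # discarding the return value would leave the original model unwrapped.
    reward_model = prepare_deepspeed(reward_model, accelerator)
```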
```diff
@@ -325,9 +326,11 @@ def __init__(self,
 
         self.model_accepts_loss_kwargs = False
         for i, reward_func in enumerate(self.reward_funcs):
-            if isinstance(reward_func, PreTrainedModel) and is_deepspeed_zero3_enabled():
-                from trl.models.utils import prepare_deepspeed
-                prepare_deepspeed(reward_func, self.accelerator)  # Does not wrap DeepSpeedEngine
+            if self.is_deepspeed_enabled:
+                self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
+            else:
+                self.reward_funcs[i] = self.accelerator.prepare_model(
+                    reward_func, evaluation_mode=True, device_placement=True)
 
         # Multi-step
         self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
```
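The behavioral change in this hunk: the old code called `prepare_deepspeed` only under ZeRO-3 and discarded its return value (hence the old `# Does not wrap DeepSpeedEngine` note), while the new code keeps the wrapped engine and adds a non-DeepSpeed fallback via `accelerator.prepare_model`. A minimal sketch of the resulting pattern; the standalone function, its `trainer` argument, and the `isinstance(reward_func, PreTrainedModel)` guard are assumptions for illustration, not a verbatim copy of the file:

```python
from transformers import PreTrainedModel
from trl.models import prepare_deepspeed


def prepare_reward_funcs(trainer):
    """Place model-based reward functions on the right backend (sketch)."""
    for i, reward_func in enumerate(trainer.reward_funcs):
        if not isinstance(reward_func, PreTrainedModel):
            # Plain Python reward functions need no device preparation.
            continue
        if trainer.is_deepspeed_enabled:
            # Keep the returned DeepSpeed engine; dropping it would leave
            # the original, unwrapped model in the list.
            trainer.reward_funcs[i] = prepare_deepspeed(reward_func, trainer.accelerator)
        else:
            # Without DeepSpeed, let Accelerate handle eval-mode device placement.
            trainer.reward_funcs[i] = trainer.accelerator.prepare_model(
                reward_func, evaluation_mode=True, device_placement=True)
```

Compared with the previous version, this also covers DeepSpeed stages other than ZeRO-3 (anything for which `is_deepspeed_enabled` is true), and non-DeepSpeed runs now get explicit device placement instead of no preparation at all.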