Update GRPOTrainer to be compatible with trl 0.17 #3969

Merged: 33 commits, Apr 30, 2025
Changes from 1 commit
update
hjh0119 committed Apr 29, 2025
commit 6656359022c6ede4ebaaf29d255d91a6aad854e9
2 changes: 1 addition & 1 deletion examples/train/grpo/multi_node/Qwen2_5_32B_full.sh
@@ -8,7 +8,7 @@

 # NODE1 for vLLM Server
 CUDA_VISIBLE_DEVICES=0,1 \
-swift deploy \
+swift rollout \
     --model Qwen/Qwen2.5-32B-Instruct \
     --infer_backend vllm \
     --use_async_engine false \
13 changes: 6 additions & 7 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -771,7 +771,9 @@ def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]:
         # for example, 2 workers, 6 inputs, 0/2/4 dispatch to the first worker
         # 1/3/5 dispatch to the second worker
         # trying to shuffle and average the length
-        distributed_idx = round_robin(len(all_inputs), get_node_setting()[1] * self.args.num_infer_workers)
+        nnodes = get_node_setting()[1]
+        num_workers = 1 if self.is_external_vllm else nnodes
+        distributed_idx = round_robin(len(all_inputs), num_workers * self.args.num_infer_workers)
         if self.infer_rank >= 0:
             _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]]
             if self.args.async_generate:
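The comment in this hunk explains the dispatch scheme by example. As a standalone illustration, here is a minimal sketch of the indexing pattern such a round-robin split produces (the real round_robin helper is swift's own utility; this version only mimics the behavior the comment describes):

```python
def round_robin(num_inputs: int, num_workers: int) -> list[list[int]]:
    # Hand out input indices cyclically so each worker sees a similar mix of
    # positions and, on average, a similar total sequence length.
    buckets: list[list[int]] = [[] for _ in range(num_workers)]
    for idx in range(num_inputs):
        buckets[idx % num_workers].append(idx)
    return buckets

# 2 workers, 6 inputs: worker 0 gets 0/2/4, worker 1 gets 1/3/5,
# matching the example in the comment above.
assert round_robin(6, 2) == [[0, 2, 4], [1, 3, 5]]
```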
@@ -865,7 +867,7 @@ def _generate_and_score_completions(self, inputs: InputsType) -> InputsType:
         # Log metrics
         messages = [inputs[i]['messages'][:-1] for i in range(len(inputs))]

-        self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func)  # TODO
+        self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func)

         return batch_encoded_inputs
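A side note on the messages[:-1] slice above: assuming each messages list ends with the sampled completion, the slice keeps only the prompt side for logging while the completion is passed separately via completions. A toy illustration with hypothetical data:

```python
# Hypothetical input; only the [:-1] slicing behavior is being illustrated.
inputs = [{'messages': [
    {'role': 'user', 'content': 'What is 2 + 2?'},
    {'role': 'assistant', 'content': '4'},  # sampled completion (assumption)
]}]
messages = [inputs[i]['messages'][:-1] for i in range(len(inputs))]
assert messages == [[{'role': 'user', 'content': 'What is 2 + 2?'}]]
```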

@@ -1047,13 +1049,10 @@ def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> Li
                 batch_advantages
             })

-        # Compute log probabilities
         with torch.no_grad():
-            # Old policy logps
             batch_encoded_inputs['old_per_token_logps'] = (
                 self._get_per_token_logps(self.model, batch_encoded_inputs) if self.old_policy else None)

-            # Reference policy logps
             if self.beta == 0.0:
                 ref_per_token_logps = None
             elif self.ref_model is not None:
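For context on why both computations are optional: in the standard GRPO objective (as formulated in TRL; this is not code from this trainer), the old-policy log-probs only feed the importance-sampling ratio, and the reference log-probs only feed a KL penalty scaled by beta, so each can be skipped when old_policy is false or beta == 0.0. A hedged sketch:

```python
import torch

def grpo_per_token_loss(per_token_logps, old_per_token_logps, ref_per_token_logps,
                        advantages, beta, epsilon=0.2):
    # per_token_logps: (batch, seq_len); advantages: (batch,) group-relative advantages.
    advantages = advantages.unsqueeze(-1)
    if old_per_token_logps is None:
        # Fully on-policy case: the sampling policy equals the current policy.
        old_per_token_logps = per_token_logps.detach()
    ratio = torch.exp(per_token_logps - old_per_token_logps)
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    loss = -torch.min(ratio * advantages, clipped * advantages)
    if beta != 0.0 and ref_per_token_logps is not None:
        # k3 estimator of KL(pi || pi_ref); the whole term vanishes when beta == 0.0.
        per_token_kl = (torch.exp(ref_per_token_logps - per_token_logps)
                        - (ref_per_token_logps - per_token_logps) - 1)
        loss = loss + beta * per_token_kl
    return loss
```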
@@ -1413,9 +1412,9 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non

         logs = {**logs, **metrics}
         if version.parse(transformers.__version__) >= version.parse('4.47.0.dev0'):
-            Trainer.log(self, logs, start_time)
+            super().log(logs, start_time)
         else:  # transformers<=4.46
-            Trainer.log(self, logs)
+            super().log(logs)
         self._metrics[mode].clear()

         if self.accelerator.is_main_process and self.log_completions:
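One plausible reason for replacing Trainer.log(self, ...) with super().log(...) is that super() follows the full method resolution order, so any mixin sitting between GRPOTrainer and transformers.Trainer also sees the call, whereas naming Trainer directly skips it. A self-contained toy example (class names are hypothetical, not swift's actual hierarchy):

```python
class Trainer:
    def log(self, logs):
        print('Trainer.log:', logs)

class LoggingMixin:  # hypothetical mixin between the subclass and Trainer
    def log(self, logs):
        print('mixin sees the logs first')
        super().log(logs)

class MyGRPOTrainer(LoggingMixin, Trainer):
    def log(self, logs):
        super().log(logs)            # visits LoggingMixin, then Trainer
        # Trainer.log(self, logs) would bypass LoggingMixin entirely

MyGRPOTrainer().log({'loss': 0.1})
```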