
[grpo] code refactor #4097


Merged: 51 commits, May 13, 2025

Changes from 1 commit. Full commit list (all commits by hjh0119):
edc1fd1  init (May 6, 2025)
07a1040  fix default (May 6, 2025)
0303461  fix (May 7, 2025)
854f357  fix seed (May 7, 2025)
7df2b5d  fix (May 7, 2025)
fda82ee  wip (May 7, 2025)
5d8d4a2  wip multi turn (May 7, 2025)
ac52340  multi turn (May 7, 2025)
578a365  fix comment (May 7, 2025)
9a49fb5  fix peft model inspect and labels (May 7, 2025)
5579c3e  fix multi turn (May 7, 2025)
7de8aab  update multi turn (May 7, 2025)
438f1f7  multi turn not remove response (May 8, 2025)
d69a9ae  fix (May 8, 2025)
451fd02  fix multi turn concat response (May 8, 2025)
c3a1aa9  fix multi turn message check (May 8, 2025)
300610e  fix infer (May 8, 2025)
fd08ccd  external async generate (May 8, 2025)
9da6242  clean argument check (May 8, 2025)
8a22c9b  fix async generate (May 8, 2025)
8ba0330  fix server infer to list (May 8, 2025)
0926a3c  fix server infer (May 8, 2025)
0c3827a  catch async generate error (May 8, 2025)
fbc2b54  fix infer inputs (May 8, 2025)
57445b4  fix async generate (May 8, 2025)
e2330f9  fix size (May 8, 2025)
37a06f9  remove vllm context (May 9, 2025)
66ad138  reward model prepare ds (May 9, 2025)
a1f1636  merge main (May 12, 2025)
f4a05d3  lint (May 12, 2025)
2b5198e  fix multi turn + TP (May 12, 2025)
a479465  external path image (May 12, 2025)
1fb25db  fix async generate and doc (May 12, 2025)
7394dc9  update doc (May 12, 2025)
4160ad3  remove async mode script (May 12, 2025)
47bb902  doc wip and deprecate patch (May 12, 2025)
37c68d2  lint (May 12, 2025)
f7700fa  doc and script wip (May 13, 2025)
6a572fa  doc update (May 13, 2025)
4afbdc3  doc (May 13, 2025)
df2ce3d  doc update (May 13, 2025)
b101e4b  doc update (May 13, 2025)
1939873  update doc and readme (May 13, 2025)
dae81c1  update grpo doc (May 13, 2025)
05054d0  update scripts (May 13, 2025)
11307be  rm script (May 13, 2025)
7bbed3f  update completion_length_limit_scope argument (May 13, 2025)
f2b4aac  update stable doc reference (May 13, 2025)
cb7ff52  remove lmdeploy (May 13, 2025)
5e9e3b5  set different seed between processes (May 13, 2025)
25ac346  fix seed (May 13, 2025)

Selected commit: 854f3571f93e8f6847a9815e58f348619ae4877b ("fix seed", hjh0119, May 7, 2025)

swift/trainers/rlhf_trainer/grpo_trainer.py: 10 changes (6 additions, 4 deletions)

@@ -36,7 +36,8 @@
 from swift.llm.template.template_inputs import StdTemplateInputs
 from swift.plugin import orms
 from swift.plugin.multi_turn import multi_turns
-from swift.utils import (JsonlWriter, gc_collect, get_logger, get_node_setting, is_lmdeploy_available, is_vllm_available, is_wandb_available)
+from swift.utils import (JsonlWriter, gc_collect, get_device, get_logger, get_node_setting, is_lmdeploy_available,
+                         is_vllm_available, is_wandb_available)
 from ..mixin import SwiftMixin
 from .rlhf_mixin import RLHFTrainerMixin
 from .utils import patch_lora_merge, patch_lora_unmerge, round_robin, unwrap_model_for_generation

@@ -410,7 +411,7 @@ def prepare_vllm(self, model):
         max_num_seqs = (
             self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size
             * self.args.gradient_accumulation_steps)
-
+        current_device = get_device()
         with Swift.grpo_context(model, self.template.processor):
             engine = cls(
                 model.model_dir,
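
For scale, a worked example of the max_num_seqs bound computed above, with assumed values (none of these numbers come from the PR):

    # Assumed example values: the bound admits one TP group's worth of
    # prompts accumulated over all gradient-accumulation steps.
    per_device_train_batch_size = 4
    vllm_tensor_parallel_size = 2
    gradient_accumulation_steps = 8

    max_num_seqs = (per_device_train_batch_size * vllm_tensor_parallel_size
                    * gradient_accumulation_steps)
    print(max_num_seqs)  # 4 * 2 * 8 = 64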

@@ -424,6 +425,7 @@ def prepare_vllm(self, model):
                 limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt,
                 enable_sleep_mode=self.args.sleep_level > 0,
                 use_async_engine=False,
+                device=current_device,
                 max_model_len=self.args.vllm_max_model_len,
                 engine_kwargs=engine_kwargs,
                 **vllm_kwargs)
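
The new device=current_device argument pins each rank's colocated engine to that rank's own accelerator. A hypothetical sketch of what a per-rank lookup like get_device() might return; the real swift.utils.get_device may differ, and the LOCAL_RANK convention here is an assumption:

    # Hypothetical sketch, not swift's actual implementation: derive a
    # per-process device string from the LOCAL_RANK env var that
    # torchrun/accelerate set for each process.
    import os

    def get_device_sketch() -> str:
        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
        return f'cuda:{local_rank}'

    print(get_device_sketch())  # e.g. 'cuda:3' on local rank 3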

@@ -550,7 +552,8 @@ def _get_first_turn_results(self, inputs: InputsType, request_config: RequestConfig
             gathered_inputs = [None for _ in range(self.vllm_tensor_parallel_size)]
             torch.distributed.all_gather_object(gathered_inputs, inputs, group=self.tp_group)
             inputs = [p for sublist in gathered_inputs for p in sublist]
-
+            # confirm that the seed is the same within the TP group
+            request_config.seed = self.accelerator.process_index // self.vllm_tensor_parallel_size
             results: List[ChatCompletionResponse] = self._engine_infer(
                 infer_requests=inputs, request_config=request_config, use_tqdm=False)
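
Within one tensor-parallel group every rank all-gathers the same flattened batch (the all_gather_object lines above), so the added seed line gives all ranks of a group an identical seed and distinct groups distinct seeds: TP peers then sample identical completions for the shared batch, while data-parallel groups keep their sampling diversity. A minimal sketch of the mapping, with assumed world size and TP size:

    # Assumed values, not from the PR: 8 processes, TP groups of 2.
    world_size = 8
    vllm_tensor_parallel_size = 2

    for process_index in range(world_size):
        # ranks 0-1 -> seed 0, ranks 2-3 -> seed 1, ...
        seed = process_index // vllm_tensor_parallel_size
        print(f'rank {process_index} -> seed {seed}')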

@@ -649,7 +652,6 @@ def _infer_single_or_multi_turn(self, inputs: InputsType,
                 _choices.append((_input['messages'], choice.finish_reason))
             outputs.append(_choices)
         assert len(outputs) == len(inputs)
-        assert all([len(o) == self.vllm_tensor_parallel_size for o in outputs])

         return outputs