
Commit 526f8a4

Decouple vLLM engine and GRPOTrainer (#3911)
* init
* argument
* externel engine warning
* fix response model
* update
* test
* adjust device warning
* async engine arg
* pass use_async_engine arg'
* wip test
* fix collective_rpc
* extension cls
* init vllm engine before broadcast
* pass client
* fix
* tolist and dict
* wip
* fix
* fix
* fix
* fix
* fix
* img process
* rm pydantic available
* doc
* 32b full script
* update
* update
1 parent 225c483 commit 526f8a4

File tree: 19 files changed, +586 -40 lines

docs/source/Instruction/GRPO.md

Lines changed: 45 additions & 7 deletions
@@ -20,13 +20,34 @@ pip install -U trl
 
 ![](../../resources/grpo.png)
 
-In SWIFT's GRPO training, the training model tries to occupy the front portion of the visible GPUs, while rollout tries to use the rear portion. This means:
+The GRPO training framework supports integrating a high-performance inference engine (such as vLLM) to accelerate sampling, offering the following two deployment modes:
 
-- If NPROC_PER_NODE and num_infer_workers are equal (both equal to the number of visible GPUs), training and inference run on the same GPUs; in this case sleep_level needs to be configured.
-- If NPROC_PER_NODE plus num_infer_workers equals the number of visible GPUs, training uses the front GPUs and rollout uses the rear GPUs; in this case async_generate can be configured.
+### 1. Internal Integration Mode
+
+- The inference service is launched directly inside the Trainer.
+- Two resource allocation strategies are provided:
+  - **Colocate mode**: training and inference share GPU resources.
+  - **Async mode**: training and inference use separate GPU resources.
+
+### GRPO Training Resource Allocation
+| Configuration scenario | NPROC_PER_NODE | num_infer_workers | Resource allocation |
+|------------------------|----------------|-------------------|---------------------|
+| **Colocate** | = total GPUs | = total GPUs | Training and inference share all GPU resources |
+| **Async** | = training GPUs | = inference GPUs | Must satisfy: training GPUs + inference GPUs = total GPUs |
+
+**Notes:**
+1. In Colocate mode it is recommended to set `sleep_level=1` to release the GPU memory held by vLLM while the model is training.
+2. "Total GPUs" means the total number of visible GPU devices.
+
+### 2. External Service Mode
+Connect to an external vLLM inference server.
+When using this mode, configure the external vLLM server with the following arguments:
+```bash
+--vllm_server_host <server IP> \
+--vllm_server_port <server port> \
+--vllm_server_timeout <timeout> \
+```
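
For illustration only (this example is not part of the commit), a minimal launch using the internal async strategy described above might look like the following sketch. The model, dataset, and 6+2 GPU split are assumed values, and `async_generate` is assumed to be passed as a CLI flag:

```bash
# Hypothetical async-mode sketch on one 8-GPU node:
# 6 GPUs train, 2 GPUs run rollout, so training GPUs + inference GPUs = total GPUs.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=6 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B-Instruct \
    --reward_funcs accuracy \
    --use_vllm true \
    --num_infer_workers 2 \
    --async_generate true \
    --train_type full \
    --dataset AI-MO/NuminaMath-TIR#1000
```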
 
-> async_generate actually uses the policy model from step-1, so the `clip` operation does not actually take effect. If training is unstable or fails to converge, try turning this argument off.
-> In our experiments, instability or non-convergence rarely occurred even with async_generate enabled.
 
 ## Reward Functions
 ### Custom Reward Functions
@@ -126,8 +147,15 @@ A conversation between User and Assistant. The user asks a question, and the Ass
 - Tip: if `--report_to wandb` is not set, a `completions.jsonl` will be created in the checkpoint to store the generated content.
 - use_vllm: whether to use vLLM as the generation backend for sampling; default False; recommended to speed up training.
 - vllm_device: device on which vLLM is deployed; default `auto`, i.e. the first unused GPU; use `cuda:x` to specify a particular card.
-- vllm_gpu_memory_utilization: vLLM passthrough parameter.
-- vllm_max_model_len: vLLM passthrough parameter.
+- vllm_gpu_memory_utilization: vLLM passthrough parameter, default 0.9.
+- vllm_max_model_len: vLLM passthrough parameter, default None.
+- vllm_max_num_seqs: vLLM passthrough parameter, default 256.
+- vllm_enforce_eager: vLLM passthrough parameter, default False.
+- vllm_limit_mm_per_prompt: vLLM passthrough parameter, default None.
+- vllm_enable_prefix_caching: vLLM passthrough parameter, default True.
+- vllm_server_host: host address of the vLLM server, default None; used when connecting to an external vLLM server.
+- vllm_server_port: service port of the vLLM server, default 8000.
+- vllm_server_timeout: connection timeout for the vLLM server, default 120 seconds.
 - reward_model: same as model; use a reward model as a reward function; at least one of reward_funcs and reward_model must be specified.
 - num_iterations: number of update iterations per batch, default 1.
 - epsilon: clip coefficient, default 0.2.
@@ -144,6 +172,10 @@ A conversation between User and Assistant. The user asks a question, and the Ass
 - dynamic_sample: filter out data whose within-group reward standard deviation is 0 and sample additional data to replace it; default False.
 - max_resample_times: limit on the number of resampling attempts when dynamic_sample is enabled; default 3.
 - overlong_filter: skip samples that were truncated for exceeding the maximum length so they do not take part in the loss computation; default False.
+- vllm_server_host: host address of the vLLM server, default None; used when connecting to an external vLLM server.
+- vllm_server_port: service port of the vLLM server, default 8000.
+- vllm_server_timeout: connection timeout for the vLLM server, default 120 seconds.
+
 
 Reward function arguments: see [built-in reward functions](#内置奖励函数).
 
@@ -303,6 +335,12 @@ swift rlhf \
     --system 'examples/train/grpo/prompt.txt' \
     --log_completions true
 ```
+For multi-node training, see [here](../../../examples/train/grpo/multi_node/).
+
+Note: in the internal integration mode, the GPU configuration and training arguments must be identical across nodes.
+
 
 ## DAPO
 [Decoupled Clip and Dynamic sAmpling Policy Optimization (DAPO)](https://arxiv.org/abs/2503.14476) adds several tricks on top of GRPO, namely:

docs/source/Instruction/命令行参数.md

Lines changed: 4 additions & 1 deletion
@@ -401,6 +401,9 @@ The reward model arguments are used in PPO and GRPO.
 - vllm_enforce_eager: vLLM passthrough parameter, default False.
 - vllm_limit_mm_per_prompt: vLLM passthrough parameter, default None.
 - vllm_enable_prefix_caching: vLLM passthrough parameter, default True.
+- vllm_server_host: host address of the vLLM server, default None; used when connecting to an external vLLM server.
+- vllm_server_port: service port of the vLLM server, default 8000.
+- vllm_server_timeout: connection timeout for the vLLM server, default 120 seconds.
 - top_k: default 50.
 - top_p: default 0.9.
 - repetition_penalty: repetition penalty term, default 1.
@@ -468,7 +471,7 @@ soft overlong reward arguments
 - Note: in `swift app` or `swift eval`, the default is False.
 - log_interval: interval for printing the tokens/s statistics, default 20 seconds; set to -1 to disable printing.
 - max_logprobs: maximum number of logprobs returned to the client, default 20.
-
+- use_async_engine: whether to use the async engine under the vLLM backend, default True.
 
 ### Web-UI Arguments
 - server_name: host of the web UI, default '0.0.0.0'.

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 5 additions & 0 deletions
@@ -412,6 +412,9 @@ The meanings of the following parameters can be referenced [here](https://huggin
 - vllm_enforce_eager: vLLM passthrough parameter, default is False.
 - vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None.
 - vllm_enable_prefix_caching: vLLM passthrough parameter, default is True.
+- vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
+- vllm_server_port: The service port of the vLLM server. Default is 8000.
+- vllm_server_timeout: The connection timeout for the vLLM server. Default is 120 seconds.
 - top_k: Default is 50.
 - top_p: Default is 0.9.
 - repetition_penalty: Repetition penalty term. Default is 1.
@@ -482,6 +485,8 @@ Deployment Arguments inherit from the [inference arguments](#inference-arguments
 - Note: In `swift app` or `swift eval`, the default is False.
 - log_interval: Interval for printing tokens/s statistics, default is 20 seconds. If set to -1, it will not be printed.
 - max_logprobs: Maximum number of logprobs returned to the client, with a default value of 20.
+- use_async_engine: Whether to use the async engine under the vLLM backend. Default is True.
+
 
 ### Web-UI Arguments
 - server_name: Host for the web UI, default is '0.0.0.0'.

docs/source_en/Instruction/GRPO.md

Lines changed: 41 additions & 8 deletions
@@ -22,14 +22,36 @@ pip install -U trl
 
 ![](../../resources/grpo.png)
 
-In SWIFT's GRPO training, the training model preferentially uses the front portion of the available GPUs, while the rollout process utilizes the rear portion of the available GPUs. This means:
+The GRPO training framework supports the integration of high-performance inference engines (such as vLLM) to accelerate the sampling process, offering the following two deployment modes:
 
-- **If both `NPROC_PER_NODE` and `num_infer_workers` in the command are equal to the number of available GPUs**, training and inference are assigned to the same GPUs. In this case, you need to configure `sleep_level`.
-- **If the sum of `NPROC_PER_NODE` and `num_infer_workers` equals the total number of available GPUs**, training will use the front GPUs and rollout will use the rear GPUs. In this scenario, you can configure `async_generate`.
+### 1. Internal Integration Mode
 
-> Note: async_generate uses the policy model and responses of current_step-1, so in fact the `clip` method will be ignored.
-> If you encounter instability during training, turn this argument off.
-> In our experiments, instability did not occur frequently when async_generate was enabled.
+- Launch the inference service directly within the Trainer.
+- Provides two resource allocation strategies:
+  - **Colocate Mode**: Training and inference share GPU resources.
+  - **Async Mode**: Training and inference use separate GPU resources.
+
+### GRPO Training Resource Allocation Scheme
+
+| Configuration Scenario | NPROC_PER_NODE | num_infer_workers | Resource Allocation Description |
+|------------------------|----------------|-------------------|---------------------------------|
+| **Colocate** | = Total GPUs | = Total GPUs | Training and inference share all GPU resources. |
+| **Async** | = Training GPUs | = Inference GPUs | Must satisfy: Training GPUs + Inference GPUs = Total GPUs. |
+
+**Note:**
+1. In Colocate mode, it is recommended to set `sleep_level=1` to release the GPU memory occupied by vLLM during model training.
+2. Total GPUs refers to the total number of visible GPU devices.
+
+### 2. External Service Mode
+
+Connect to an external vLLM inference server.
+When using this mode, configure the external vLLM server with the following parameters:
+
+```bash
+--vllm_server_host <Server IP> \
+--vllm_server_port <Server Port> \
+--vllm_server_timeout <Timeout> \
+```
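
As a companion illustration (also not part of this diff), a colocate-mode launch could look like the sketch below, assuming all eight visible GPUs are shared and that `sleep_level` is passed as a CLI flag; model and dataset are placeholders:

```bash
# Hypothetical colocate-mode sketch: NPROC_PER_NODE == num_infer_workers == total GPUs,
# with sleep_level=1 so vLLM releases its GPU memory during the training phase.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B-Instruct \
    --reward_funcs accuracy \
    --use_vllm true \
    --num_infer_workers 8 \
    --sleep_level 1 \
    --train_type full \
    --dataset AI-MO/NuminaMath-TIR#1000
```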
 
 ## Reward Functions
 ### Custom Reward Functions
@@ -130,8 +152,15 @@ Arguments
 - Note: If `--report_to wandb` is not set, a `completions.jsonl` will be created in the checkpoint to store the generated content.
 - use_vllm: Whether to use vLLM as the back-end for sampling generation; default is False, using it is recommended to speed up training.
 - vllm_device: Device for deploying vLLM, default is auto, meaning the first unused GPU. Use cuda:x to specify a particular card.
-- vllm_gpu_memory_utilization: vLLM pass-through parameter.
-- vllm_max_model_len: vLLM pass-through parameter.
+- vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.
+- vllm_max_model_len: vLLM passthrough parameter, default is None.
+- vllm_max_num_seqs: vLLM passthrough parameter, default is 256.
+- vllm_enforce_eager: vLLM passthrough parameter, default is False.
+- vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None.
+- vllm_enable_prefix_caching: vLLM passthrough parameter, default is True.
+- vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
+- vllm_server_port: The service port of the vLLM server. Default is 8000.
+- vllm_server_timeout: The connection timeout for the vLLM server. Default is 120 seconds.
 - reward_model: Same as the model, using a reward model as a reward function. At least one of reward_funcs and reward_model needs to be specified.
 - num_iterations: number of iterations per batch. Default is 1.
 - epsilon: epsilon value for clipping. Default is 0.2.
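
To make the vLLM passthrough parameters listed above concrete, here is a hedged sketch of how they might appear in a launch command; the numeric values are arbitrary illustrations, not recommendations from this commit:

```bash
# Illustrative values only: trim vLLM's memory share and cap sequence length/count.
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B-Instruct \
    --use_vllm true \
    --vllm_gpu_memory_utilization 0.7 \
    --vllm_max_model_len 8192 \
    --vllm_max_num_seqs 128 \
    --vllm_enable_prefix_caching true \
    --reward_funcs accuracy \
    --dataset AI-MO/NuminaMath-TIR#1000
```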
@@ -307,6 +336,10 @@ swift rlhf \
     --log_completions true
 ```
 
+For multi-node training, refer to [here](../../../examples/train/grpo/multi_node/).
+
+Note: In the internal integration mode, the GPU configurations and training parameters must be identical across different nodes.
+
 ## DAPO
 Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) introduces several tricks based on GRPO, which are:
 - Clip Higher
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# External vLLM
+
+# Assume we have two nodes: one with 8 GPUs of 80 GB each (8*80G) and another with 2 GPUs of 80 GB each (2*80G).
+# NODE1. The node with 2*80G will be used to deploy the vLLM server.
+# NODE2. The node with 8*80G will be used for full-parameter fine-tuning of the 32B model.
+
+# Note: Use beta=0 to disable the reference model; otherwise, it may lead to Out-of-Memory (OOM) errors.
+
+# NODE1 for vLLM Server
+CUDA_VISIBLE_DEVICES=0,1 \
+swift deploy \
+    --model Qwen/Qwen2.5-32B-Instruct \
+    --infer_backend vllm \
+    --use_async_engine false \
+    --tensor_parallel_size 2
+
+# NODE2 for Training
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+NPROC_PER_NODE=8 \
+swift rlhf \
+    --rlhf_type grpo \
+    --model Qwen/Qwen2.5-32B-Instruct \
+    --reward_funcs accuracy \
+    --use_vllm true \
+    --vllm_server_host xxx \
+    --vllm_server_port 8000 \
+    --train_type full \
+    --torch_dtype bfloat16 \
+    --dataset AI-MO/NuminaMath-TIR#1000 \
+    --max_completion_length 2048 \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --learning_rate 1e-6 \
+    --gradient_accumulation_steps 1 \
+    --save_total_limit 2 \
+    --logging_steps 1 \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --dataset_num_proc 4 \
+    --num_generations 8 \
+    --temperature 1.0 \
+    --top_p 0.9 \
+    --top_k 50 \
+    --deepspeed zero3 \
+    --log_completions true \
+    --num_iterations 1 \
+    --num_infer_workers 1 \
+    --report_to tensorboard wandb \
+    --beta 0.0

examples/train/grpo/multi_node/multi_node1.sh

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
+# Internal vLLM
+
 # pip install math_verify # reward function
 # pip install -U trl
+# Note: the parameters of each node need to be consistent.
 export CUDA_VISIBLE_DEVICES=0,1,2,3
 export NNODES=2
 export NODE_RANK=0
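
For reference (not part of this diff), the second node would typically run the same script with only the rank changed; a minimal sketch, assuming the master address is supplied through the standard `MASTER_ADDR` variable:

```bash
# Hypothetical second-node counterpart of multi_node1.sh: identical parameters,
# only NODE_RANK differs, and both nodes must point at the same master address.
export CUDA_VISIBLE_DEVICES=0,1,2,3
export NNODES=2
export NODE_RANK=1
export MASTER_ADDR=<node-0-ip>   # placeholder: replace with the rank-0 node's address
```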

examples/train/grpo/multi_node/train_dlc.sh

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+# This script is used in DLC (Deep Learning Containers)
+# For more information, visit: https://www.aliyun.com/activity/bigdata/pai-dlc
 NNODES=$WORLD_SIZE \
 NODE_RANK=$RANK \
 PYTHONPATH=. \

swift/llm/argument/infer_args.py

Lines changed: 2 additions & 0 deletions
@@ -127,6 +127,8 @@ class InferArguments(MergeArguments, VllmArguments, LmdeployArguments, BaseArgum
     # only for inference
     val_dataset_sample: Optional[int] = None
 
+    use_async_engine: bool = True
+
     def _get_result_path(self, folder_name: str) -> str:
         result_dir = self.ckpt_dir or f'result/{self.model_suffix}'
         os.makedirs(result_dir, exist_ok=True)
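
The new `use_async_engine` field surfaces as the `--use_async_engine` flag that the external-vLLM example earlier in this commit passes to `swift deploy`; a minimal sketch reusing that model name:

```bash
# Default async engine for ordinary serving (use_async_engine defaults to True):
swift deploy --model Qwen/Qwen2.5-32B-Instruct --infer_backend vllm

# Sync engine, as in the external GRPO rollout server example above:
swift deploy --model Qwen/Qwen2.5-32B-Instruct --infer_backend vllm --use_async_engine false
```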

swift/llm/argument/rlhf_args.py

Lines changed: 32 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 from swift.llm import MODEL_MAPPING
 from swift.trainers.arguments import GRPOArgumentsMixin
-from swift.utils import get_logger, set_default_ddp_config
+from swift.utils import get_logger, is_master, set_default_ddp_config
 from .train_args import TrainArguments
 
 logger = get_logger()
@@ -111,9 +111,11 @@ def __post_init__(self):
         self._init_simpo()
         self._init_ppo()
         self._set_default()
+        self._init_external_vllm()
         super().__post_init__()
         self._check_rlhf()
         self._check_grpo()
+        self._external_vllm_warning()
 
         if self.loss_scale is None:
             if self.rlhf_type == 'orpo' and not self.model_meta.is_multimodal:
@@ -187,6 +189,14 @@ def _init_rm(self):
         self.task_type = 'seq_cls'
         self.num_labels = 1
 
+    def _init_external_vllm(self):
+        if self.rlhf_type != 'grpo' or self.vllm_server_host is None:
+            return
+        from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
+        if is_master():
+            self.vllm_client = VLLMClient(
+                self.vllm_server_host, self.vllm_server_port, connection_timeout=self.vllm_server_timeout)
+
     def _set_default(self):
         if self.beta is None:
             self.beta = 0.1
@@ -208,7 +218,7 @@ def _check_grpo(self):
         _, _, _, local_world_size = get_dist_setting()
         num_infer_workers = self.num_infer_workers
         fast_infer = self.use_vllm or self.use_lmdeploy
-        if fast_infer:
+        if fast_infer and self.vllm_server_host is None:
             is_colocate_mode = (device_count == num_infer_workers)
 
             if is_colocate_mode:
@@ -243,3 +253,23 @@ def _check_grpo(self):
         if self.mini_batch_size:
             assert self.per_device_train_batch_size % self.mini_batch_size == 0,\
                 'per_device_train_batch_size needs be divisible by mini_batch_size'
+
+    def _external_vllm_warning(self):
+        if self.rlhf_type != 'grpo' or not self.vllm_server_host:
+            return
+
+        if self.vllm_device != 'auto':
+            logger.warning("Configuration conflict: External vLLM engine detected, but 'vllm_device' is set to '%s'. ",
+                           self.vllm_device)
+
+        if self.num_infer_workers != 1:
+            logger.warning(
+                "Auto-adjustment: Changing 'num_infer_workers' from %s to 1 because external vLLM engine is detected",
+                self.num_infer_workers)
+            self.num_infer_workers = 1
+
+        if self.vllm_max_model_len is not None:
+            logger.warning(
+                "Configuration conflict: 'vllm_max_model_len=%s' is ignored for external vLLM. "
+                'Please specify it when launching the inference service: '
+                '`swift deploy --max_model_len <value>`', self.vllm_max_model_len)
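
Read together, these warnings imply that when an external vLLM server is used, the engine-side limit belongs to the `swift deploy` command rather than to the trainer; a sketch of that split, with a placeholder host and an assumed context length:

```bash
# Server side: set the context limit where the engine actually runs (value is an assumption).
swift deploy \
    --model Qwen/Qwen2.5-32B-Instruct \
    --infer_backend vllm \
    --use_async_engine false \
    --max_model_len 8192

# Trainer side: point GRPO at the server; vllm_device stays 'auto', vllm_max_model_len
# stays unset, and num_infer_workers is auto-adjusted to 1 by the code above.
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-32B-Instruct \
    --use_vllm true \
    --vllm_server_host <server-ip> \
    --vllm_server_port 8000 \
    --reward_funcs accuracy \
    --dataset AI-MO/NuminaMath-TIR#1000
```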
