Decouple vLLM engine and GRPOTrainer. #3911
Merged · 28 commits · Apr 22, 2025
52 changes: 45 additions & 7 deletions docs/source/Instruction/GRPO.md
@@ -16,13 +16,34 @@ pip install -U trl

![](../../resources/grpo.png)

In SWIFT's GRPO training, the training model preferentially uses the front portion of the visible GPUs, while rollout uses the rear portion. This means:
The GRPO training framework supports integrating a high-performance inference engine (such as vLLM) to accelerate sampling, offering the following two deployment modes:

- If NPROC_PER_NODE and num_infer_workers in the command are equal, both being the number of visible GPUs, training and inference are placed on the same GPUs; in this case sleep_level needs to be configured.
- If NPROC_PER_NODE plus num_infer_workers equals the number of visible GPUs, training uses the front GPUs and rollout uses the rear GPUs; in this case async_generate can be configured.
### 1. Internal Integration Mode (Internal)

- The inference service is launched directly inside the Trainer.
- Two resource allocation strategies are provided:
  - **Colocate Mode**: training and inference share GPU resources
  - **Async Mode**: training and inference use separate GPU resources

### GRPO Training Resource Allocation Scheme
| Configuration Scenario | NPROC_PER_NODE | num_infer_workers | Resource Allocation Description |
|------------------------|-----------------|-------------------|-------------------------------------------|
| **Colocate**           | = total GPUs    | = total GPUs      | Training and inference share all GPU resources |
| **Async**              | = training GPUs | = inference GPUs  | Must satisfy: training GPUs + inference GPUs = total GPUs |

**Note:**
1. In Colocate mode it is recommended to set `sleep_level=1` to release the GPU memory occupied by vLLM during model training.
2. Total GPUs refers to the total number of visible GPU devices.
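A minimal sketch of the two strategies, assuming 8 visible GPUs; all other training flags are omitted, and `sleep_level`/`async_generate` are the parameters described in the notes above and in the parameter list below:

```bash
# Colocate: training and rollout share all 8 GPUs; sleep_level=1 releases vLLM memory while training
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
    --rlhf_type grpo \
    --use_vllm true \
    --num_infer_workers 8 \
    --sleep_level 1 \
    ...  # remaining training arguments

# Async: 6 training GPUs + 2 rollout GPUs = 8 visible GPUs
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=6 \
swift rlhf \
    --rlhf_type grpo \
    --use_vllm true \
    --num_infer_workers 2 \
    --async_generate true \
    ...  # remaining training arguments
```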

### 2. External Service Mode (External)
Connect to an external vLLM inference server.
When using this mode, pass the following arguments to configure the connection to the external vLLM server:
```bash
--vllm_server_host <server IP> \
--vllm_server_port <server port> \
--vllm_server_timeout <timeout> \
```
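The server side can be launched with `swift deploy`; for example (taken from the multi-node script added in this PR, with an illustrative model):

```bash
# Deploy the vLLM server on two GPUs with tensor parallelism
CUDA_VISIBLE_DEVICES=0,1 \
swift deploy \
    --model Qwen/Qwen2.5-32B-Instruct \
    --infer_backend vllm \
    --use_async_engine false \
    --tensor_parallel_size 2
```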

> async_generate actually uses the policy model from step-1, so the `clip` operation effectively has no effect. If training is unstable or fails to converge, try turning this parameter off.
> In our experiments, instability or non-convergence rarely occurred even with async_generate enabled.

## Reward Functions
### Custom Reward Functions
@@ -122,8 +143,15 @@ A conversation between User and Assistant. The user asks a question, and the Ass
- Tip: if `--report_to wandb` is not set, a `completions.jsonl` is created in the checkpoint to store the generated content
- use_vllm: whether to use vLLM as the generation backend for sampling; default is False, and enabling it is recommended to speed up training
- vllm_device: device on which to deploy vLLM; default is `auto`, i.e. the first unused GPU, and `cuda:x` can be used to select a specific card
- vllm_gpu_memory_utilization: vLLM passthrough parameter
- vllm_max_model_len: vLLM passthrough parameter
- vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9
- vllm_max_model_len: vLLM passthrough parameter, default is None
- vllm_max_num_seqs: vLLM passthrough parameter, default is 256
- vllm_enforce_eager: vLLM passthrough parameter, default is False
- vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None
- vllm_enable_prefix_caching: vLLM passthrough parameter, default is True
- vllm_server_host: host address of the vLLM server, default is None; used when connecting to an external vLLM server
- vllm_server_port: service port of the vLLM server, default is 8000
- vllm_server_timeout: connection timeout for the vLLM server, default is 120 seconds
- reward_model: same as model; use a reward model as a reward function. At least one of reward_funcs and reward_model must be specified
- num_iterations: number of update iterations per batch, default is 1.
- epsilon: clip coefficient, default is 0.2.
@@ -140,6 +168,10 @@ A conversation between User and Assistant. The user asks a question, and the Ass
- dynamic_sample: filter out data whose reward standard deviation within a group is 0 and sample additional new data; default is False.
- max_resample_times: limit on the number of resampling attempts when dynamic_sample is set; default is 3.
- overlong_filter: skip samples truncated for exceeding the maximum length so they do not contribute to the loss; default is False.

For reward function parameters, see [built-in reward functions](#内置奖励函数)

@@ -299,6 +331,12 @@ swift rlhf \
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
```
For multi-node training, see [here](../../../examples/train/grpo/multi_node/)

Note: in the internal integration mode, the GPU configuration and training parameters must be identical across nodes
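A sketch of the per-node environment, following examples/train/grpo/multi_node/multi_node1.sh (node 1 differs only in NODE_RANK):

```bash
# Node 0; on node 1, set NODE_RANK=1 and keep everything else identical
export CUDA_VISIBLE_DEVICES=0,1,2,3
export NNODES=2
export NODE_RANK=0
# ...then run the same swift rlhf command with identical arguments on every node
```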




## DAPO
[Decoupled Clip and Dynamic sAmpling Policy Optimization (DAPO)](https://arxiv.org/abs/2503.14476) builds on GRPO with several tricks, namely:
5 changes: 4 additions & 1 deletion docs/source/Instruction/命令行参数.md
@@ -401,6 +401,9 @@ The reward model parameters are used in PPO and GRPO.
- vllm_enforce_eager: vLLM passthrough parameter, default is False
- vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None
- vllm_enable_prefix_caching: vLLM passthrough parameter, default is True
- vllm_server_host: host address of the vLLM server, default is None; used when connecting to an external vLLM server
- vllm_server_port: service port of the vLLM server, default is 8000
- vllm_server_timeout: connection timeout for the vLLM server, default is 120 seconds
- top_k: default is 50
- top_p: default is 0.9
- repetition_penalty: repetition penalty term, default is 1.
@@ -468,7 +471,7 @@ soft overlong reward parameters
- Note: in `swift app` or `swift eval`, the default is False
- log_interval: interval for printing tokens/s statistics, default is 20 seconds; set to -1 to disable printing
- max_logprobs: maximum number of logprobs returned to the client, default is 20

- use_async_engine: whether to use the async engine under the vLLM backend, default is True
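For example, the multi-node GRPO script added in this PR launches the external server with the async engine disabled (the model name is illustrative):

```bash
swift deploy \
    --model Qwen/Qwen2.5-32B-Instruct \
    --infer_backend vllm \
    --use_async_engine false \
    --tensor_parallel_size 2
```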

### Web-UI Arguments
- server_name: host of the web-ui, default is '0.0.0.0'
5 changes: 5 additions & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -412,6 +412,9 @@ The meanings of the following parameters can be referenced [here](https://huggin
- vllm_enforce_eager: vLLM passthrough parameter, default is False.
- vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None.
- vllm_enable_prefix_caching: vLLM passthrough parameter, default is True.
- vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
- vllm_server_port: The service port of the vLLM server. Default is 8000.
- vllm_server_timeout: The connection timeout for the vLLM server. Default is 120 seconds.
- top_k: Default is 50.
- top_p: Default is 0.9.
- repetition_penalty: Repetition penalty term. Default is 1.
@@ -482,6 +485,8 @@ Deployment Arguments inherit from the [inference arguments](#inference-arguments
- Note: In `swift app` or `swift eval`, the default is False.
- log_interval: Interval for printing tokens/s statistics, default is 20 seconds. If set to -1, it will not be printed.
- max_logprobs: Maximum number of logprobs returned to the client, with a default value of 20.
- use_async_engine: Whether to use the async engine under the vLLM backend. Default is True.
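For example, the multi-node GRPO script added in this PR launches the external server with the async engine disabled (the model name is illustrative):

```bash
swift deploy \
    --model Qwen/Qwen2.5-32B-Instruct \
    --infer_backend vllm \
    --use_async_engine false \
    --tensor_parallel_size 2
```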


### Web-UI Arguments
- server_name: Host for the web UI, default is '0.0.0.0'.
49 changes: 41 additions & 8 deletions docs/source_en/Instruction/GRPO.md
@@ -17,14 +17,36 @@ pip install -U trl

![](../../resources/grpo.png)

In SWIFT's GRPO training, the training model preferentially uses the front portion of the available GPUs, while the rollout process uses the rear portion. This means:
The GRPO training framework supports the integration of high-performance inference engines (such as vLLM) to accelerate the sampling process, offering the following two deployment modes:

- **If both `NPROC_PER_NODE` and `num_infer_workers` in the command are equal to the number of available GPUs**, training and inference are assigned to the same GPUs. In this case, you need to configure `sleep_level`.
- **If the sum of `NPROC_PER_NODE` and `num_infer_workers` equals the total number of available GPUs**, training will use the front GPUs and rollout will use the rear GPUs. In this scenario, you can configure `async_generate`.
### 1. Internal Integration Mode

> Note: async_generate uses the policy model and responses from step-1, so the `clip` operation is effectively ignored.
> If you encounter instability during training, turn off this argument.
> In our experiments, instability rarely occurred when async_generate was enabled.
- Launch the inference service directly within the Trainer.
- Provides two resource allocation strategies:
- **Colocate Mode**: Training and inference share GPU resources.
- **Async Mode**: Training and inference use separate GPU resources.

### GRPO Training Resource Allocation Scheme

| Configuration Scenario | NPROC_PER_NODE | num_infer_workers | Resource Allocation Description |
|-------------------------|----------------|-------------------|---------------------------------------|
| **Colocate** | = Total GPUs | = Total GPUs | Training and inference share all GPU resources. |
| **Async** | = Training GPUs| = Inference GPUs | Must satisfy: Training GPUs + Inference GPUs = Total GPUs. |

**Note:**
1. In Colocate mode, it is recommended to set `sleep_level=1` to release the GPU memory occupied by vLLM during model training.
2. Total GPUs refers to the total number of visible GPU devices.
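A minimal sketch of the two strategies, assuming 8 visible GPUs; all other training flags are omitted, and `sleep_level`/`async_generate` are the parameters described in the notes above and in the parameter list below:

```bash
# Colocate: training and rollout share all 8 GPUs; sleep_level=1 releases vLLM memory while training
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
    --rlhf_type grpo \
    --use_vllm true \
    --num_infer_workers 8 \
    --sleep_level 1 \
    ...  # remaining training arguments

# Async: 6 training GPUs + 2 rollout GPUs = 8 visible GPUs
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=6 \
swift rlhf \
    --rlhf_type grpo \
    --use_vllm true \
    --num_infer_workers 2 \
    --async_generate true \
    ...  # remaining training arguments
```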

### 2. External Service Mode

Connect to an external vLLM inference server.
When using this mode, pass the following arguments to configure the connection to the external vLLM server:

```bash
--vllm_server_host <Server IP> \
--vllm_server_port <Server Port> \
--vllm_server_timeout <Timeout> \
```
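The server side can be launched with `swift deploy`; for example (taken from the multi-node script added in this PR, with an illustrative model):

```bash
# Deploy the vLLM server on two GPUs with tensor parallelism
CUDA_VISIBLE_DEVICES=0,1 \
swift deploy \
    --model Qwen/Qwen2.5-32B-Instruct \
    --infer_backend vllm \
    --use_async_engine false \
    --tensor_parallel_size 2
```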

## Reward Functions
### Custom Reward Functions
@@ -125,8 +147,15 @@ Arguments
- Note: If `--report_to wandb` is not set, a `completions.jsonl` will be created in the checkpoint to store the generated content.
- use_vllm: Whether to use vLLM as the back-end for sampling generation; default is False, using it is recommended to speed up training.
- vllm_device: Device for deploying vLLM, default is auto, meaning the first unused GPU. Use cuda:x to specify a particular card.
- vllm_gpu_memory_utilization: vLLM pass-through parameter.
- vllm_max_model_len: vLLM pass-through parameter.
- vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.
- vllm_max_model_len: vLLM passthrough parameter, default is None.
- vllm_max_num_seqs: vLLM passthrough parameter, default is 256.
- vllm_enforce_eager: vLLM passthrough parameter, default is False.
- vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None.
- vllm_enable_prefix_caching: vLLM passthrough parameter, default is True.
- vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
- vllm_server_port: The service port of the vLLM server. Default is 8000.
- vllm_server_timeout: The connection timeout for the vLLM server. Default is 120 seconds.
- reward_model: Same as the model, using a reward model as a reward function. At least one of reward_funcs and reward_model needs to be specified.
- num_iterations: number of iterations per batch. Default is 1.
- epsilon: epsilon value for clipping. Default is 0.2.
@@ -302,6 +331,10 @@ swift rlhf \
--log_completions true
```

For multi-node training, refer to [here](../../../examples/train/grpo/multi_node/).

Note: in the internal integration mode, the GPU configuration and training parameters must be identical across nodes.
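A sketch of the per-node environment, following examples/train/grpo/multi_node/multi_node1.sh (node 1 differs only in NODE_RANK):

```bash
# Node 0; on node 1, set NODE_RANK=1 and keep everything else identical
export CUDA_VISIBLE_DEVICES=0,1,2,3
export NNODES=2
export NODE_RANK=0
# ...then run the same swift rlhf command with identical arguments on every node
```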

## DAPO
Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) introduces several tricks based on GRPO, which are:
- Clip Higher
50 changes: 50 additions & 0 deletions examples/train/grpo/multi_node/Qwen2_5_32B_full.sh
@@ -0,0 +1,50 @@
# External vLLM

# Assume we have two nodes: one with 8 GPUs of 80GB each (8*80G) and another with 2 GPUs of 80GB each (2*80G).
# NODE1. The node with 2*80G will be used to deploy the vLLM server.
# NODE2. The node with 8*80G will be used for full-parameter fine-tuning of the 32B model.

# Note: Use beta=0 to disable the reference model; otherwise, it may lead to out-of-memory (OOM) errors.

# NODE1 for vLLM Server
CUDA_VISIBLE_DEVICES=0,1 \
swift deploy \
--model Qwen/Qwen2.5-32B-Instruct \
--infer_backend vllm \
--use_async_engine false \
--tensor_parallel_size 2

# NODE2 for Training
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-32B-Instruct \
--reward_funcs accuracy \
--use_vllm true \
--vllm_server_host xxx \
--vllm_server_port 8000 \
--train_type full \
--torch_dtype bfloat16 \
--dataset AI-MO/NuminaMath-TIR#1000 \
--max_completion_length 2048 \
--num_train_epochs 3 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-6 \
--gradient_accumulation_steps 1 \
--save_total_limit 2 \
--logging_steps 1 \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--num_generations 8 \
--temperature 1.0 \
--top_p 0.9 \
--top_k 50 \
--deepspeed zero3 \
--log_completions true \
--num_iterations 1 \
--num_infer_workers 1 \
--report_to tensorboard wandb \
--beta 0.0
3 changes: 3 additions & 0 deletions examples/train/grpo/multi_node/multi_node1.sh
@@ -1,5 +1,8 @@
# Internal vLLM

# pip install math_verify # reward function
# pip install -U trl
# Note: the parameters of each node need to be consistent.
export CUDA_VISIBLE_DEVICES=0,1,2,3
export NNODES=2
export NODE_RANK=0
Expand Down
2 changes: 2 additions & 0 deletions examples/train/grpo/multi_node/train_dlc.sh
@@ -1,3 +1,5 @@
# This script is used in DLC (Deep Learning Containers)
# For more information, visit: https://www.aliyun.com/activity/bigdata/pai-dlc
NNODES=$WORLD_SIZE \
NODE_RANK=$RANK \
PYTHONPATH=. \
2 changes: 2 additions & 0 deletions swift/llm/argument/infer_args.py
@@ -127,6 +127,8 @@ class InferArguments(MergeArguments, VllmArguments, LmdeployArguments, BaseArgum
# only for inference
val_dataset_sample: Optional[int] = None

use_async_engine: bool = True

def _get_result_path(self, folder_name: str) -> str:
result_dir = self.ckpt_dir or f'result/{self.model_suffix}'
os.makedirs(result_dir, exist_ok=True)
34 changes: 32 additions & 2 deletions swift/llm/argument/rlhf_args.py
@@ -5,7 +5,7 @@

from swift.llm import MODEL_MAPPING
from swift.trainers.arguments import GRPOArgumentsMixin
from swift.utils import get_logger, set_default_ddp_config
from swift.utils import get_logger, is_master, set_default_ddp_config
from .train_args import TrainArguments

logger = get_logger()
@@ -109,9 +109,11 @@ def __post_init__(self):
self._init_simpo()
self._init_ppo()
self._set_default()
self._init_external_vllm()
super().__post_init__()
self._check_rlhf()
self._check_grpo()
self._external_vllm_warning()

if self.loss_scale is None:
if self.rlhf_type == 'orpo' and not self.model_meta.is_multimodal:
@@ -185,6 +187,14 @@ def _init_rm(self):
self.task_type = 'seq_cls'
self.num_labels = 1

def _init_external_vllm(self):
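# Only GRPO runs with vllm_server_host set use an external server; the client is created on the master process only.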
if self.rlhf_type != 'grpo' or self.vllm_server_host is None:
return
from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
if is_master():
self.vllm_client = VLLMClient(
self.vllm_server_host, self.vllm_server_port, connection_timeout=self.vllm_server_timeout)

def _set_default(self):
if self.beta is None:
self.beta = 0.1
@@ -206,7 +216,7 @@ def _check_grpo(self):
_, _, _, local_world_size = get_dist_setting()
num_infer_workers = self.num_infer_workers
fast_infer = self.use_vllm or self.use_lmdeploy
if fast_infer:
if fast_infer and self.vllm_server_host is None:
is_colocate_mode = (device_count == num_infer_workers)

if is_colocate_mode:
@@ -241,3 +251,23 @@ def _check_grpo(self):
if self.mini_batch_size:
assert self.per_device_train_batch_size % self.mini_batch_size == 0,\
'per_device_train_batch_size needs be divisible by mini_batch_size'

def _external_vllm_warning(self):
if self.rlhf_type != 'grpo' or not self.vllm_server_host:
return

if self.vllm_device != 'auto':
logger.warning("Configuration conflict: External vLLM engine detected, but 'vllm_device' is set to '%s'. ",
self.vllm_device)

if self.num_infer_workers != 1:
logger.warning(
"Auto-adjustment: Changing 'num_infer_workers' from %s to 1 because external vLLM engine is detected",
self.num_infer_workers)
self.num_infer_workers = 1

if self.vllm_max_model_len is not None:
logger.warning(
"Configuration conflict: 'vllm_max_model_len=%s' is ignored for external vLLM. "
'Please specify it when launching the inference service: '
'`swift deploy --max_model_len <value>`', self.vllm_max_model_len)