
grpo liger loss #3781

Merged
merged 80 commits into from May 22, 2025
Changes from 1 commit
Commits
80 commits
c3f859d
liger grpo loss
hjh0119 Apr 7, 2025
5224a4a
merge main
hjh0119 Apr 14, 2025
bbce4b2
update
hjh0119 Apr 14, 2025
63fdcea
fix
hjh0119 Apr 14, 2025
5915901
move args
hjh0119 Apr 14, 2025
d0c290c
fix
hjh0119 Apr 14, 2025
0a3794f
fix
hjh0119 Apr 14, 2025
3b9ee6d
fix
hjh0119 Apr 14, 2025
d643ab9
fix
hjh0119 Apr 14, 2025
93fdb71
require
hjh0119 Apr 15, 2025
f87b042
compatible with zero3
hjh0119 Apr 15, 2025
b82cbf4
fix
hjh0119 Apr 15, 2025
9c20051
merge main
hjh0119 May 1, 2025
fc7fabe
wip
hjh0119 May 1, 2025
8f67b13
update liger loss
hjh0119 May 1, 2025
8b4e346
liger&peft
hjh0119 May 1, 2025
edc1fd1
init
hjh0119 May 6, 2025
07a1040
fix default
hjh0119 May 6, 2025
0303461
fix
hjh0119 May 7, 2025
854f357
fix seed
hjh0119 May 7, 2025
7df2b5d
fix
hjh0119 May 7, 2025
fda82ee
wip
hjh0119 May 7, 2025
5d8d4a2
wip multi turn
hjh0119 May 7, 2025
ac52340
multi turn
hjh0119 May 7, 2025
578a365
fix comment
hjh0119 May 7, 2025
9a49fb5
fix peft model inspect and labels
hjh0119 May 7, 2025
5579c3e
fix multi turn
hjh0119 May 7, 2025
7de8aab
update multi turn
hjh0119 May 7, 2025
438f1f7
multi turn not remove response
hjh0119 May 8, 2025
d69a9ae
fix
hjh0119 May 8, 2025
451fd02
fix multi turn concate response
hjh0119 May 8, 2025
c3a1aa9
fix multi turn message check
hjh0119 May 8, 2025
300610e
fix infer
hjh0119 May 8, 2025
fd08ccd
external async generate
hjh0119 May 8, 2025
9da6242
clean argument check
hjh0119 May 8, 2025
8a22c9b
fix async generate
hjh0119 May 8, 2025
8ba0330
fix server infer to list
hjh0119 May 8, 2025
0926a3c
fix server infer
hjh0119 May 8, 2025
0c3827a
catch async generate error
hjh0119 May 8, 2025
fbc2b54
fix infer inputs
hjh0119 May 8, 2025
57445b4
fix async generate
hjh0119 May 8, 2025
e2330f9
fix size
hjh0119 May 8, 2025
37a06f9
remove vllm context
hjh0119 May 9, 2025
66ad138
reward model prepare ds
hjh0119 May 9, 2025
a1f1636
merge main
hjh0119 May 12, 2025
f4a05d3
lint
hjh0119 May 12, 2025
2b5198e
fix multi turn + TP
hjh0119 May 12, 2025
a479465
external path image
hjh0119 May 12, 2025
1fb25db
fix async generate and doc
hjh0119 May 12, 2025
7394dc9
update doc
hjh0119 May 12, 2025
4160ad3
remove async mode script
hjh0119 May 12, 2025
47bb902
doc wip and deprecate patch
hjh0119 May 12, 2025
37c68d2
lint
hjh0119 May 12, 2025
f7700fa
doc and scipt wip
hjh0119 May 13, 2025
6a572fa
doc update
hjh0119 May 13, 2025
4afbdc3
doc
hjh0119 May 13, 2025
df2ce3d
doc update
hjh0119 May 13, 2025
b101e4b
doc update
hjh0119 May 13, 2025
1939873
update doc and readme
hjh0119 May 13, 2025
dae81c1
update grpo doc
hjh0119 May 13, 2025
05054d0
update scripts
hjh0119 May 13, 2025
11307be
rm script
hjh0119 May 13, 2025
7bbed3f
update completion_length_limit_scope argument
hjh0119 May 13, 2025
53a08d0
merge refactor
hjh0119 May 13, 2025
829a7ea
fix epsilon
hjh0119 May 13, 2025
f2b4aac
update stable doc reference
hjh0119 May 13, 2025
cb7ff52
remove lmdeploy
hjh0119 May 13, 2025
5e9e3b5
set different seed bewteen processes
hjh0119 May 13, 2025
25ac346
fix seed
hjh0119 May 13, 2025
427a32f
merge refactor
hjh0119 May 13, 2025
c4dc72e
merge main
hjh0119 May 13, 2025
346396f
remove liger check
hjh0119 May 13, 2025
3045802
fix epsilon
hjh0119 May 13, 2025
4bf7996
remvoe unused import
hjh0119 May 14, 2025
f7080f5
Merge remote-tracking branch 'origin' into liger
hjh0119 May 22, 2025
83b3845
use_liger_kernel
hjh0119 May 22, 2025
8a10681
update
hjh0119 May 22, 2025
79834e6
Merge remote-tracking branch 'origin' into liger
hjh0119 May 22, 2025
169882f
remove require
hjh0119 May 22, 2025
3c7e763
lint
hjh0119 May 22, 2025
fix async generate and doc
hjh0119 committed May 12, 2025
commit 1fb25db0235c9c9a790e1d0ac94420930fb018bf
64 changes: 44 additions & 20 deletions docs/source/Instruction/GRPO.md
@@ -11,7 +11,7 @@ pip install -U trl
```

**Changelog**

- **2025-05-13** — Refactored the Internal mode; vLLM >= 0.8 is now supported.
- **2025-05-11** — Generative reward models are now supported; customize the reward-model logic via reward_model_plugin. For more details, see the [Custom Reward Model](#自定义奖励模型) section.
- **2025-04-30** — The launch command for the external vLLM server has changed to `swift rollout`.

@@ -27,38 +27,62 @@ pip install -U trl

The GRPO training framework supports integrating a high-performance inference engine (such as vLLM) to accelerate the sampling process, offering the following two deployment modes:

### 1. Internal Integration Mode
### 1. Colocate Mode

- Training and inference share GPU resources; the inference service is launched inside the Trainer.

- Launch the inference service directly inside the Trainer
- Provides two resource allocation strategies:
  - **Colocate mode**: training and inference share GPU resources
  - **Async mode**: training and inference use separate GPU resources
Launch parameters
```bash
--vllm_mode colocate
```

### GRPO Training Resource Allocation Scheme
| Configuration | NPROC_PER_NODE | num_infer_workers | Resource allocation |
|--------------------------|----------------|------------------|------------------------|
| **Colocate** | = total GPUs | = total GPUs | Training and inference share all GPU resources |
| **Async** | = training GPUs | = inference GPUs | Must satisfy: training GPUs + inference GPUs = total GPUs |
#### Memory Optimization in Colocate Mode
When running in colocate mode, out-of-memory (OOM) errors occur easily. The following memory optimization techniques and parameter settings are effective; a combined example follows the list:

**Note:**
1. In colocate mode it is recommended to set `sleep_level=1` to release the GPU memory occupied by vLLM during training.
2. Total GPUs refers to the total number of visible GPU devices.
1. During the training phase, release the GPU memory occupied by vLLM:

### 2. External Service Mode
Connect to an external vLLM inference server.
When using this mode, configure the external vLLM server with the following parameters:
```bash
--vllm_server_host <Server IP> \
--vllm_server_port <Server Port> \
--vllm_server_timeout <Timeout> \
--sleep_level 1
```

2. During the vLLM inference phase, release the GPU memory occupied by the training model and optimizer:

```bash
--offload_optimizer true \
--offload_model true \
--gc_collect_after_offload true \
```

3. Use tensor parallelism in vLLM:

```bash
--tensor_parallel_size [tp_size]
```

4. Gather model weights in batches (when synchronizing vLLM weights under ZeRO-3):
```bash
--move_model_batches [number_of_batches]
```
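
Taken together, the flags above can be combined in a single launch. The sketch below is illustrative rather than taken from this PR: the model, dataset, reward function, GPU count, and the `swift rlhf --rlhf_type grpo` entry point are assumptions about a typical single-node colocate run.

```bash
# Illustrative 8-GPU colocate launch combining the memory optimizations above.
# Model, dataset, and reward function are placeholders; adjust them to your setup.
NPROC_PER_NODE=8 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B-Instruct \
    --dataset AI-MO/NuminaMath-TIR \
    --reward_funcs accuracy \
    --vllm_mode colocate \
    --sleep_level 1 \
    --offload_optimizer true \
    --offload_model true \
    --gc_collect_after_offload true \
    --tensor_parallel_size 4 \
    --move_model_batches 16
```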

### 2. Async Mode

- Training and inference use separate resources; a standalone inference server is launched externally.

Deploy the vLLM server with the `swift rollout` command; currently only the vLLM backend is supported:
```bash
CUDA_VISIBLE_DEVICES=2 \
swift rollout \
--model Qwen/Qwen2.5-VL-7B-Instruct \
--tensor_parallel_size 2 \
```

During training, use the following parameters to connect to the external vLLM server:
```bash
--vllm_mode server \
--vllm_server_host <Server IP> \
--vllm_server_port <Server Port> \
--vllm_server_timeout <Timeout> \
```
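
As a rough end-to-end sketch (not part of this PR), the async workflow has two steps: start the rollout server first, then launch training pointed at it. The address, port, timeout, dataset, reward function, and the `swift rlhf --rlhf_type grpo` entry point below are placeholder assumptions.

```bash
# Step 1: start the rollout server on a dedicated GPU (values are placeholders).
CUDA_VISIBLE_DEVICES=7 \
swift rollout \
    --model Qwen/Qwen2.5-VL-7B-Instruct

# Step 2: launch GRPO training on the remaining GPUs and connect to the server.
CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=4 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-VL-7B-Instruct \
    --dataset <dataset_id_or_path> \
    --reward_funcs accuracy \
    --vllm_mode server \
    --vllm_server_host 127.0.0.1 \
    --vllm_server_port 8000 \
    --vllm_server_timeout 240
```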

The complete script can be found [here](../../../examples/train/grpo/multi_node/Qwen2_5_32B_full.sh).


65 changes: 45 additions & 20 deletions docs/source_en/Instruction/GRPO.md
@@ -29,42 +29,67 @@ pip install -U trl

The GRPO training framework supports the integration of high-performance inference engines (such as vLLM) to accelerate the sampling process, offering the following two deployment modes:

### 1. Internal Integration Mode
### 1. Colocate Mode

- Launch the inference service directly within the Trainer.
- Provides two resource allocation strategies:
- **Colocate Mode**: Training and inference share GPU resources.
- **Async Mode**: Training and inference use separate GPU resources.
Training and inference share GPU resources; the inference service is started inside the Trainer.

### GRPO Training Resource Allocation Scheme
Launch Parameters
```bash
--vllm_mode colocate
```

| Configuration Scenario | NPROC_PER_NODE | num_infer_workers | Resource Allocation Description |
|-------------------------|----------------|-------------------|---------------------------------------|
| **Colocate** | = Total GPUs | = Total GPUs | Training and inference share all GPU resources. |
| **Async** | = Training GPUs| = Inference GPUs | Must satisfy: Training GPUs + Inference GPUs = Total GPUs. |
#### Memory Optimization Strategies in Colocate Mode
When running in Colocate Mode, out-of-memory (OOM) errors are common because training and inference workloads run on the same GPUs. Below are effective memory optimization strategies and configuration parameters; a combined example follows the list:

**Note:**
1. In Colocate mode, it is recommended to set `sleep_level=1` to release the GPU memory occupied by vLLM during model training.
2. Total GPUs refers to the total number of visible GPU devices.
1. Release vLLM memory during training:

### 2. External Service Mode
```bash
--sleep_level 1
```

Connect to an external vLLM inference server.
When using this mode, configure the external vLLM server with the following parameters:
2. Offload training model and optimizer memory during vLLM inference:

```bash
--vllm_server_host <Server IP> \
--vllm_server_port <Server Port> \
--vllm_server_timeout <Timeout> \
--offload_optimizer true \
--offload_model true \
--gc_collect_after_offload true \
```

Deploy the vLLM server using the `swift rollout` command. Currently, only the vLLM backend is supported.
3. Use Tensor Parallelism in vLLM:

```bash
--tensor_parallel_size [tp_size]
```

4. Batched gathering of model weights (when synchronizing vLLM weights under ZeRO-3):

```bash
--move_model_batches [number_of_batches]
```
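
For reference, a minimal sketch of how these options might be combined in one colocate launch is shown below; it is not taken from this PR. The model, dataset, reward function, GPU count, and the `swift rlhf --rlhf_type grpo` entry point are illustrative assumptions.

```bash
# Illustrative 4-GPU colocate launch applying the optimizations above.
# Model, dataset, and reward function are placeholders; adjust them to your setup.
CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=4 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B-Instruct \
    --dataset AI-MO/NuminaMath-TIR \
    --reward_funcs accuracy \
    --vllm_mode colocate \
    --sleep_level 1 \
    --offload_model true \
    --offload_optimizer true \
    --gc_collect_after_offload true \
    --tensor_parallel_size 2 \
    --move_model_batches 8
```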


### 2. Async Mode

Training and inference use separate resources; a dedicated inference server is launched externally.

Deploy the vLLM server using the `swift rollout` command. Currently, only the vLLM backend is supported:

```bash
CUDA_VISIBLE_DEVICES=2 \
swift rollout \
--model Qwen/Qwen2.5-VL-7B-Instruct \
--tensor_parallel_size 2 \
```

Use the following parameters in training to connect to an external vLLM server:

```bash
--vllm_mode server \
--vllm_server_host <Server IP> \
--vllm_server_port <Server Port> \
--vllm_server_timeout <Timeout> \
```
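
Putting the two pieces together, a training-side launch might look like the sketch below (GPU 2 is assumed to be occupied by the rollout server started above). The dataset, reward function, address, and the `swift rlhf --rlhf_type grpo` entry point are placeholders, not values from this PR.

```bash
# Illustrative training command connecting to the external rollout server.
CUDA_VISIBLE_DEVICES=0,1 \
NPROC_PER_NODE=2 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-VL-7B-Instruct \
    --dataset <dataset_id_or_path> \
    --reward_funcs accuracy \
    --vllm_mode server \
    --vllm_server_host 127.0.0.1 \
    --vllm_server_port 8000 \
    --vllm_server_timeout 240
```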

The complete script can be found [here](../../../examples/train/grpo/multi_node/Qwen2_5_32B_full.sh).

## Reward Functions
4 changes: 2 additions & 2 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -559,7 +559,7 @@ def _infer(self, inputs: InputsType, request_config: RequestConfig, is_global_in
# keys from InferRequest
per_device_size = len(inputs)
if is_global_inputs:
per_device_size /= self.accelerator.num_processes
per_device_size //= self.accelerator.num_processes
infer_inputs = [{
k: v
for k, v in inp.items() if k in ['messages', 'images', 'audios', 'videos', 'tools', 'objects']
@@ -722,7 +722,7 @@ def infer_task():
def done(future):
try:
result = future.result()
current_queue.put(DataCache(inputs, result))
current_queue.put(DataCache(all_inputs, result))
except Exception as e:
logger.error('Error in async_infer callback: %s', str(e))
