Conversation

SageMoore
Contributor

@SageMoore SageMoore commented Aug 26, 2025

Purpose

This PR adds support for Dual-Batch Overlap (DBO) in vLLM. In its current state it is only enabled when a user provides the --enable-microbatching flag, and it is only used when all DP groups are running full-decode batches. The PR supports running DBO with full CUDA graphs, which is essential for minimizing CPU overhead and getting performance out of this feature.

At a high level, Dual-Batch Overlap (DBO) splits the batch into two microbatches and then uses two threads and two CUDA streams, one for communication and one for computation, to overlap the dispatch and combine all-to-all kernels of one microbatch with the compute kernels of the other microbatch.

When microbatching is enabled and supported, the GPUModelRunner splits the batch into two token_slices. These token_slices are then passed into the attention metadata builders during _prepare_inputs to generate one attention metadata object per microbatch. When actually running the model, the model runner spawns two microbatching threads that communicate with each other through a UBatchContext. Each of these threads then runs self.model with the appropriate attention metadata.
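
As a rough sketch of the splitting step (the helper names here are illustrative only and do not match vLLM's actual internals):

```python
# Hypothetical sketch: split a full-decode batch into two microbatch slices
# and build per-microbatch attention metadata. UbatchSlice, split_decode_batch,
# and build_attn_metadata are made-up names for illustration.
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class UbatchSlice:
    token_start: int  # first token index owned by this microbatch
    token_end: int    # one past the last token index


def split_decode_batch(num_tokens: int) -> List[UbatchSlice]:
    """Split the batch roughly in half, one slice per microbatch."""
    mid = num_tokens // 2
    return [UbatchSlice(0, mid), UbatchSlice(mid, num_tokens)]


def prepare_ubatch_metadata(num_tokens: int,
                            build_attn_metadata: Callable[[int, int], object]):
    """Produce one attention metadata object per microbatch slice."""
    return [build_attn_metadata(s.token_start, s.token_end)
            for s in split_decode_batch(num_tokens)]
```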

Without any additional modifications to the code, this would just result in one microbatch running to completion before the other microbatch starts. In order to get overlap, we've added a "yield" call that can be inserted into the all-to-all kernels to interleave the two microbatches. The yield_and_switch_from_compute_to_comm function yields the CPU from this thread (thread A) to the other microbatching thread (thread B). Once thread A has resumed execution, either because thread B yielded the CPU or finished its execution, it swaps over to the communication stream and starts dispatching kernels there. yield_and_switch_from_comm_to_compute behaves similarly but in the opposite direction: it swaps from the communication stream to the compute stream.

Both GPU and CPU events are used to synchronize all of this. That said, it is absolutely critical that only one microbatching thread runs at a time (the other must be waiting on an event), and that both microbatches execute exactly the same number of yields.
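
To make the handoff concrete, here is a CPU-only toy sketch where threading.Events stand in for the CPU/GPU synchronization events; the yield_to_other method only mirrors (and is not) the real yield_and_switch_* helpers:

```python
# Toy model of the strict ping-pong between the two microbatching threads.
# Exactly one thread is runnable at any time; the other waits on an event.
import threading


class ToyUBatchContext:
    def __init__(self):
        self.turn = [threading.Event(), threading.Event()]
        self.turn[0].set()  # microbatch 0 starts

    def yield_to_other(self, my_id: int):
        """Hand the CPU to the peer thread and block until it hands it back."""
        self.turn[my_id].clear()
        self.turn[1 - my_id].set()   # wake the other microbatch thread
        self.turn[my_id].wait()      # sleep until it yields back (or finishes)


def run_microbatch(ctx: ToyUBatchContext, my_id: int):
    ctx.turn[my_id].wait()           # wait for our turn before launching anything
    print(f"ubatch{my_id}: compute on the compute stream")
    ctx.yield_to_other(my_id)        # ~ yield_and_switch_from_compute_to_comm
    print(f"ubatch{my_id}: dispatch/combine all-to-all on the comm stream")
    ctx.yield_to_other(my_id)        # ~ yield_and_switch_from_comm_to_compute
    print(f"ubatch{my_id}: remaining compute")
    ctx.turn[1 - my_id].set()        # let the peer finish once we are done


ctx = ToyUBatchContext()
threads = [threading.Thread(target=run_microbatch, args=(ctx, i)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```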

Test Plan

In general, my test plan was to run lm_eval with deepseek-ai/DeepSeek-V2-Lite. We've also run numerous times with R1 in a multi-node setup and verified that lm_eval produces reasonable output.

Non-DBO Runs

Eager

Command

VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --data-parallel-size 2 --enable-expert-parallel --enforce-eager

Result
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3567|±  |0.0277|
|     |       |strict-match    |     5|exact_match|↑  |0.3533|±  |0.0276|

Default

Command

VLLM_ALL2ALL_BACKEND=deepep_low_latency g2 vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --data-parallel-size 2 --enable-expert-parallel

Result
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3700|±  |0.0279|
|     |       |strict-match    |     5|exact_match|↑  |0.3667|±  |0.0279|

DBO Runs

Eager

Command

VLLM_ALL2ALL_BACKEND=deepep_low_latency g2 vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --data-parallel-size 2 --enable-expert-parallel --enforce-eager --enable-microbatching --microbatching-token-threshold 4

Result
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3800|±  |0.0281|
|     |       |strict-match    |     5|exact_match|↑  |0.3767|±  |0.0280|

Full cudagraphs

Command

VLLM_ALL2ALL_BACKEND=deepep_low_latency g2 vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --data-parallel-size 2 --enable-expert-parallel --compilation_config '{"cudagraph_mode": "full_decode_only"}' --enable-microbatching --microbatching-token-threshold 4

Result
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3733|±  |0.0280|
|     |       |strict-match    |     5|exact_match|↑  |0.3700|±  |0.0279|

LucasWilkinson and others added 30 commits May 22, 2025 20:51
@mergify mergify bot removed the needs-rebase label Sep 15, 2025

mergify bot commented Sep 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @SageMoore.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 16, 2025
@tlrmchlsmth
Member

tlrmchlsmth commented Sep 16, 2025

I thought the kernels-moe-test failures were due to VLLM_USE_PRECOMPILED=1 not picking up the changes from #24054, but that was from 3 days ago

Could it be a real problem? @elvircrn, @dougbtv

AttributeError: '_OpNamespace' '_C' object has no attribute 'silu_mul_fp8_quant_deep_gemm_cuda'

Edit: Confirmed it's picking up old binaries.

@mergify mergify bot removed the needs-rebase label Sep 16, 2025
@tlrmchlsmth tlrmchlsmth merged commit 5679399 into vllm-project:main Sep 16, 2025
54 checks passed
@NihalPotdar

Hey! Quick question - do you have any performance numbers for this change?

Mainly wondering about the efficiency of the communication-computation overlap strategy in the PR.

Comment on lines +274 to +286
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
compute_stream=compute_stream,
dp_metadata=dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE)

return self._capture_ubatches(ubatch_metadata, self.model)

Contributor

I think we only have self.runnable in this wrapper, no self.model, right? Or am I missing something? Both L286 and L303

Yikun pushed a commit to vllm-project/vllm-ascend that referenced this pull request Sep 20, 2025
…2907)

### What this PR does / why we need it?
1. This pr bump vllm commit to
vllm-project/vllm@6d8246a
2. fix upstream changes vllm-project/vllm#24548
abort multi-modal kwargs, make vllm main and `v0.10.2` both adaptable
3. fix metadata_builder changes introduced by
vllm-project/vllm#23693
4. fix `structured_outputs_config` changes introduced by
vllm-project/vllm#22772
5. fix `moe_config` changes introduced by
vllm-project/vllm#22537

Co-authored-by:  MengqingCao <[email protected]>
Co-authored-by:  Yikun Jiang <[email protected]>


- vLLM version: v0.10.2
- vLLM main:
vllm-project/vllm@c60e613

---------

Signed-off-by: wangli <[email protected]>
Signed-off-by: MengqingCao <[email protected]>
Co-authored-by: MengqingCao <[email protected]>
weijinqian0 pushed a commit to weijinqian0/vllm-ascend that referenced this pull request Sep 22, 2025

Mercykid-bash pushed two commits to Mercykid-bash/vllm-ascend that referenced this pull request Sep 22, 2025
@lhtin
Contributor

lhtin commented Sep 23, 2025

@SageMoore @LucasWilkinson Could you provide some performance improvement data? I tested DeepSeek V2 Lite locally and observed a negative performance gain, with the per-step latency increasing from 38ms to 49ms. The process of launching vLLM and the test results are shown below.

According to the Nsys profile data, after enabling DBO, the execution time of both the batched_triton_kernel and vllm::act_and_mul_kernel kernels has increased significantly.

config.yaml:

model: /path/to/DeepSeek-V2-Lite
tensor-parallel-size: 1
data-parallel-size: 2
enable-expert-parallel: true
served-model-name: vllm_infer_1
enable-dbo: true
dbo-decode-token-threshold: 4

launch vllm:

export VLLM_ALL2ALL_BACKEND=deepep_low_latency
vllm serve --config config.yaml

launch bench:

vllm bench serve \
    --model /path/to/DeepSeek-V2-Lite/ \
    --served-model-name vllm_infer_1 \
    --random-input-len 1 \
    --random-output-len 1024 \
    --num-prompts 1000 \
    --max-concurrency 100 \
    --ignore-eos

timeline with dbo:
[Nsys timeline screenshots]

timeline without dbo:
[Nsys timeline screenshots]

@LucasWilkinson
Collaborator

> @SageMoore @LucasWilkinson Could you provide some performance improvement data? I tested DeepSeek V2 Lite locally and observed a negative performance gain, with the per-step latency increasing from 38ms to 49ms. The process of launching vLLM and the test results are shown below.
>
> According to the Nsys profile data, after enabling DBO, the execution time of both the batched_triton_kernel and vllm::act_and_mul_kernel kernels has increased significantly.

Yes, this is expected; DBO will increase the GEMM time when running a memory-bound workload, since the full model weights have to be loaded twice (once for each microbatch). So DBO is only really beneficial when the communication time is more than 1x the GEMM time, which means it is only intended for multi-node EP setups where communication costs are much higher. It is not expected to provide a speed-up in a single-node environment.
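
As a rough back-of-the-envelope illustration (a simplified toy model, not numbers from this PR):

```python
# Simplified model: memory-bound GEMMs re-load the full weights for each
# microbatch, so compute time roughly doubles under DBO, while the all-to-all
# time can hide behind the other microbatch's compute.
def step_time_no_dbo(t_gemm_ms: float, t_comm_ms: float) -> float:
    return t_gemm_ms + t_comm_ms


def step_time_dbo_ideal(t_gemm_ms: float, t_comm_ms: float) -> float:
    # With perfect overlap, the longer of (doubled compute, total comm) dominates.
    return max(2 * t_gemm_ms, t_comm_ms)


for t_gemm, t_comm in [(30.0, 5.0), (30.0, 30.0), (10.0, 40.0)]:
    print(f"gemm={t_gemm}ms comm={t_comm}ms -> "
          f"no-DBO {step_time_no_dbo(t_gemm, t_comm)}ms, "
          f"ideal DBO {step_time_dbo_ideal(t_gemm, t_comm)}ms")
# Under this model DBO only wins when t_comm > t_gemm: single-node runs
# (comm << gemm) regress, multi-node EP runs (comm >> gemm) can benefit.
```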

@lhtin
Contributor

lhtin commented Sep 24, 2025

> Yes, this is expected; DBO will increase the GEMM time when running a memory-bound workload, since the full model weights have to be loaded twice (once for each microbatch). So DBO is only really beneficial when the communication time is more than 1x the GEMM time, which means it is only intended for multi-node EP setups where communication costs are much higher. It is not expected to provide a speed-up in a single-node environment.

Thank you for the explanation. In my tests on the H20, the proportion of communication time is indeed very small, less than 10%.

FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…t#23693)

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Sage Moore <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: yewentao256 <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Cwndmiao added a commit to Cwndmiao/vllm that referenced this pull request Sep 26, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025

choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Labels

documentation, ready, speculative-decoding, v1
