
Conversation

@fsx950223 (Contributor) commented Aug 12, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as the test command to run.
  • The test results, such as a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Test Plan

Test Result

(Optional) Documentation Update

Signed-off-by: fsx950223 <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Aug 12, 2025
@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request refactors the CUDA graph support flag for the AiterFlashAttention backend. While the change to use the AttentionCGSupport enum is a good step towards standardization, I've identified a critical issue in the underlying implementation. The backend is marked as always supporting CUDA graphs, but the forward pass for prefill scenarios appears to be buggy, calling two attention kernels sequentially and overwriting the output. This makes the prefill path incorrect. I've recommended adjusting the support level to PURE_DECODE_ONLY to accurately reflect the backend's current state and prevent capturing faulty logic in CUDA graphs until the implementation is fixed.

Comment on lines 233 to 234:

    attn_cudagraph_support: ClassVar[
        AttentionCGSupport] = AttentionCGSupport.ALWAYS
Severity: critical

Setting attn_cudagraph_support to ALWAYS seems incorrect given the current implementation of AiterFlashAttentionImpl.forward.

In the forward method, for prefill scenarios (where max_seqlen_q > 1), both torch.ops.vllm.flash_attn_varlen_func and torch.ops.aiter.paged_attention_v1 are executed sequentially. The second call will overwrite the results of the first, which appears to be a bug that makes the prefill path produce incorrect results.

Given that the decode path (max_seqlen_q <= 1) only calls paged_attention_v1 and might be correct, attn_cudagraph_support should be set to AttentionCGSupport.PURE_DECODE_ONLY. This would accurately reflect that only decode operations can be safely captured in a CUDA graph. Setting it to ALWAYS is misleading and could lead to capturing incorrect logic in CUDA graphs for prefill/mixed batches.

I suggest changing this to PURE_DECODE_ONLY until the prefill logic is fixed.

Suggested change:

-    attn_cudagraph_support: ClassVar[
-        AttentionCGSupport] = AttentionCGSupport.ALWAYS
+    attn_cudagraph_support: ClassVar[
+        AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
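
For context, here is a minimal runnable sketch of the control flow described in this comment, using dummy kernels in place of the real ops; all names are illustrative and none of this is the actual AiterFlashAttentionImpl code:

# Illustrative sketch only: paraphrases the forward-path behavior described
# above with dummy kernels so the overwrite effect is visible.
import torch


def fake_prefill_kernel(out: torch.Tensor) -> None:
    out.fill_(1.0)  # stands in for torch.ops.vllm.flash_attn_varlen_func writing `output`


def fake_decode_kernel(out: torch.Tensor) -> None:
    out.fill_(2.0)  # stands in for torch.ops.aiter.paged_attention_v1 writing the same buffer


def forward_sketch(output: torch.Tensor, max_seqlen_q: int) -> torch.Tensor:
    if max_seqlen_q > 1:
        # Prefill / mixed batch: the varlen flash-attention kernel fills `output` ...
        fake_prefill_kernel(output)
    # ... but the paged-attention call runs afterwards regardless, so for
    # prefill batches it overwrites the result above. Only the decode path
    # (max_seqlen_q <= 1) reaches it without a prior write.
    fake_decode_kernel(output)
    return output


out = torch.empty(4)
print(forward_sketch(out, max_seqlen_q=8))  # prints 2.0s: the prefill result was overwritten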

@tjtanaa (Contributor) commented Aug 19, 2025

@fsx950223
The attn_cudagraph_support attribute has been renamed to cudagraph_support in PR #20059 (merged into main 4 days ago).

The full_cuda_graph flag is deprecated.

Full CUDA graph capture is now controlled through --compilation-config '{"cudagraph_mode": "FULL"}'.

PR #20059 introduces multiple modes:

cudagraph_mode: Optional[CUDAGraphMode] = None
"""
The mode of the cudagraph.
- NONE, no cudagraph capture.
- PIECEWISE. (v1 default)
- FULL.
- FULL_DECODE_ONLY.
- FULL_AND_PIECEWISE.
PIECEWISE mode builds a piecewise cudagraph only, keeping the
cudagraph-incompatible ops (i.e. some attention ops) outside the cudagraph
for general flexibility.
This is the default mode.
FULL mode: Capture full cudagraph for all batches. Can be good for small
models or workloads with small prompts; not supported by many backends.
Generally for performance FULL_AND_PIECEWISE is better.
FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
Mixed prefill-decode batches are run without cudagraphs. Can be good for
decode instances in a P/D setup where prefill is not as important so we
can save some memory.
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
piecewise cudagraph for prefill and mixed prefill-decode batches.
This is likely the most performant mode for most models.
Currently, the cudagraph mode is only used for the v1 engine.
Note that the cudagraph logic is generally orthogonal to the
compilation logic. While piecewise cudagraphs require piecewise
compilation (level=PIECEWISE and non-empty splitting_ops), full
cudagraphs are supported with and without compilation.
Warning: This flag is new and subject to change; in addition,
more modes may be added.
"""

@tjtanaa (Contributor) commented Aug 20, 2025

Additional Tests:

Server command:

#!/bin/bash

rm -rf /root/.cache/vllm

MODEL=Qwen/Qwen3-8B

VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 \
vllm serve $MODEL \
--tensor-parallel-size 8 \
--max-num-seqs 1024 \
--kv-cache-dtype fp8 \
--max-num-batched-tokens 32768 \
--disable-log-requests \
--compilation-config '{"cudagraph_mode": "FULL"}' \
--trust-remote-code

server log:

INFO 08-20 03:32:27 [__init__.py:241] Automatically detected platform rocm.
WARNING 08-20 03:32:29 [__init__.py:1726] argument '--disable-log-requests' is deprecated and replaced with '--enable-log-requests'. This will be removed in v0.12.0.
(APIServer pid=148339) INFO 08-20 03:32:29 [api_server.py:1805] vLLM API server version 0.9.2rc2.dev1221+ga4eaba3e8
(APIServer pid=148339) INFO 08-20 03:32:29 [utils.py:326] non-default args: {'model_tag': 'Qwen/Qwen3-8B', 'model': 'Qwen/Qwen3-8B', 'trust_remote_code': True, 'tensor_parallel_size': 8, 'kv_cache_dtype': 'fp8', 'max_num_batched_tokens': 32768, 'max_num_seqs': 1024, 'compilation_config': {"level":null,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":null,"use_inductor":true,"compile_sizes":null,"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":2,"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":null,"cudagraph_copy_inputs":false,"full_cuda_graph":false,"pass_config":{},"max_capture_size":null,"local_cache_dir":null}}
(APIServer pid=148339) The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
(APIServer pid=148339) INFO 08-20 03:32:50 [__init__.py:711] Resolved architecture: Qwen3ForCausalLM
(APIServer pid=148339) INFO 08-20 03:32:50 [__init__.py:1750] Using max model len 40960
(APIServer pid=148339) INFO 08-20 03:32:51 [cache.py:176] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor.
(APIServer pid=148339) INFO 08-20 03:32:51 [scheduler.py:222] Chunked prefill is enabled with max_num_batched_tokens=32768.
(APIServer pid=148339) INFO 08-20 03:32:51 [__init__.py:3632] CUDAGraphMode.FULL is not supported with cascade attention currently. Disabling cascade attention.

...

Capturing CUDA graphs (mixed prefill-decode, FULL):   0%|          | 0/67 [00:00<?, ?it/s]
Capturing CUDA graphs (mixed prefill-decode, FULL):   1%|▏         | 1/67 [00:00<00:23,  2.81it/s]
Capturing CUDA graphs (mixed prefill-decode, FULL):   3%|▎         | 2/67 [00:00<00:21,  3.08it/s]
Capturing CUDA graphs (mixed prefill-decode, FULL):   4%|▍         | 3/67 [00:00<00:19,  3.29it/s]
Capturing CUDA graphs (mixed prefill-decode, FULL):   6%|▌         | 4/67 [00:01<00:19,  3.29it/s]
Capturing CUDA graphs (mixed prefill-decode, FULL):   7%|▋         | 5/67 [00:01<00:18,  3.29it/s]
Capturing CUDA graphs (mixed prefill-decode, FULL):   9%|▉         | 6/67 [00:01<00:18,  3.30it/s]

Client command:

#!/bin/bash

lm_eval \
--model local-completions \
--tasks gsm8k \
--model_args model=Qwen/Qwen3-8B,base_url=http://127.0.0.1:8000/v1/completions \
--batch_size 100 \
> lmeval_server-Qwen_Qwen3-8B-aiter-v1-mha-fullgraphmode_FULL_PR.log 2>&1
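
Not part of the test plan above, just a quick way to confirm the endpoint is up before running lm_eval. A minimal sketch, assuming the server started with the command above is reachable at the same base_url:

# Quick sanity check against the OpenAI-compatible /v1/completions endpoint
# used by lm_eval above. Assumes the server from the command above is running.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/completions",
    json={"model": "Qwen/Qwen3-8B", "prompt": "2 + 2 =", "max_tokens": 8},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])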

lm_eval score of Qwen/Qwen3-8B:

local-completions (model=Qwen/Qwen3-8B,base_url=http://127.0.0.1:8000/v1/completions), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 100
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8719|±  |0.0092|
|     |       |strict-match    |     5|exact_match|↑  |0.8635|±  |0.0095|

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) August 20, 2025 03:45
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Aug 20, 2025
@DarkLight1337 DarkLight1337 merged commit d983769 into vllm-project:main Aug 20, 2025
47 checks passed
divakar-amd pushed a commit to divakar-amd/vllm_upstream that referenced this pull request Aug 20, 2025
cyang49 pushed a commit to cyang49/vllm that referenced this pull request Aug 20, 2025
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
Signed-off-by: fsx950223 <[email protected]>
Signed-off-by: Duncan Moss <[email protected]>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: fsx950223 <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
mengxingkongzhouhan pushed a commit to mengxingkongzhouhan/vllm that referenced this pull request Aug 30, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025

Labels: ready, v1
