
Conversation

vllmellm (Contributor) commented May 1, 2025

AITER MLA Support for V1 Engine

This PR implements AITER MLA attention backend support for the V1 engine. The implementation mirrors the V0 engine's established approach from PR #15893.

This PR also introduces a new environment variable, VLLM_ROCM_EXECUTE_MODEL_TIMEOUT, which sets the model execution timeout in seconds. Making the timeout adjustable is helpful because a timeout error was encountered during graph building when AITER MLA ops are enabled on the V1 engine.
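
For illustration only, here is a minimal sketch of how an env-var-backed setting like this is typically declared, following the lambda/os.getenv pattern visible in the vllm/envs.py hunk quoted later in this thread. The 250-second default mirrors the figure discussed in the review below, and the exact wiring is an assumption (the variable was ultimately removed before merge).

# Minimal sketch, assuming the os.getenv pattern from vllm/envs.py shown
# later in this thread; the 250 s default mirrors the number discussed in
# the review. Illustrative only; the env var was removed before merge.
import os

environment_variables = {
    # Time in seconds allowed for model execution on ROCm platforms.
    "VLLM_ROCM_EXECUTE_MODEL_TIMEOUT":
    lambda: int(os.getenv("VLLM_ROCM_EXECUTE_MODEL_TIMEOUT", "250")),
}

EXECUTE_MODEL_TIMEOUT_S = environment_variables[
    "VLLM_ROCM_EXECUTE_MODEL_TIMEOUT"]()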

Accuracy Validation

Using the command below:

VLLM_ATTENTION_BACKEND=ROCM_AITER_MLA VLLM_USE_V1=1 lm_eval \
  --model vllm \
  --model_args pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,trust_remote_code=True,max_model_len=32768,block_size=1,enforce_eager=False \
  --tasks gsm8k --num_fewshot 5 --batch_size auto

Results:

Tasks  Version  Filter            n-shot  Metric       Value   Stderr
gsm8k  3        flexible-extract  5       exact_match  0.9492  ± 0.0060
                strict-match      5       exact_match  0.9477  ± 0.0061

Performance:

Results of benchmarks/benchmark_serving.py, using the commands below:

V0 engine: VLLM_ATTENTION_BACKEND=ROCM_AITER_MLA VLLM_USE_V1=0 python benchmarks/benchmark_serving.py --model deepseek-ai/DeepSeek-V3 --trust-remote-code --dataset-name random

V1 engine: VLLM_ATTENTION_BACKEND=ROCM_AITER_MLA VLLM_USE_V1=1 python benchmarks/benchmark_serving.py --model deepseek-ai/DeepSeek-V3 --trust-remote-code --dataset-name random

Metric                            ROCm AITER MLA V1    ROCm AITER MLA V0
Successful requests               1000                 1000
Benchmark duration (s)            179.78               180.92
Total input tokens                1024000              1024000
Total generated tokens            39667                38739
Request throughput (req/s)        5.56                 5.53
Output token throughput (tok/s)   220.64               214.12
Total Token throughput (tok/s)    5916.41              5873.93
--------------------------------  -------------------  -------------------
Mean TTFT (ms)                    85618.95             87984.79
Median TTFT (ms)                  85528.56             91664.85
P99 TTFT (ms)                     166979.65            96858.61
--------------------------------  -------------------  -------------------
Mean TPOT (ms)                    1048.37              4727.25
Median TPOT (ms)                  1242.70              1189.05
P99 TPOT (ms)                     1593.63              38742.17
--------------------------------  -------------------  -------------------
Mean ITL (ms)                     698.66               857.94
Median ITL (ms)                   1226.67              110.7
P99 ITL (ms)                      1664.92              10027.72

vllmellm and others added 29 commits March 28, 2025 08:19
@vllmellm vllmellm requested a review from tlrmchlsmth as a code owner May 1, 2025 08:33
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000

EXECUTE_MODEL_TIMEOUT_S = 40
EXECUTE_MODEL_TIMEOUT_S = (envs.VLLM_ROCM_EXECUTE_MODEL_TIMEOUT

Collaborator:

Wondering why ROCm needs a much larger timeout here?

vllmellm (Contributor Author):

On the first run, when the graph is being created, it can take between 100 and 250 seconds depending on how many AITER kernels are enabled. Thus we kept the default timeout at 250 s.
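
To make the failure mode concrete, here is a generic sketch of a bounded wait on model execution (plain concurrent.futures, not vLLM's executor code): a first run that spends 100-250 seconds JIT-compiling AITER kernels and building the graph would overshoot a 40-second deadline even though the worker is still making progress.

# Generic illustration of a bounded wait on model execution; this is not
# the vLLM executor. A first-run build longer than the deadline would take
# the except branch below even though the worker is still working.
import concurrent.futures
import time

EXECUTE_MODEL_TIMEOUT_S = 40  # the short default under discussion

def execute_model_stub() -> str:
    time.sleep(2)  # stand-in for a (much longer) first-run AITER JIT/graph build
    return "ok"

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(execute_model_stub)
    try:
        print(future.result(timeout=EXECUTE_MODEL_TIMEOUT_S))
    except concurrent.futures.TimeoutError:
        print("execute_model timed out; the worker may still be compiling")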

Contributor:

I'm not crazy about requiring another environment variable when running AITER. Can you just set the timeout to 250 here instead of asking the user to increase the timeout? Feel free to give a "safe" timeout.

tlrmchlsmth (Member) commented May 7, 2025:

Will these very long runs only happen during the profile and graph capture runs? Or can they happen while processing real requests?

houseroad (Collaborator):

Could you check the failed tests?

hongxiayang (Collaborator) left a comment:

Approve with comment.

To make DeepSeek performant on V1, additional work is needed.

Based on my test, TTFT improves if additional AITER environment variables are used. Otherwise, TTFT is not as good as on V0. Throughput is also not yet as good as on V0.

compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_forward,

hongxiayang (Collaborator) commented May 2, 2025:

nit: this change from fwd -> forward seems unnecessary; keeping the original name would minimize the number of files changed in this PR.

vllmellm (Contributor Author):

it has been addressed in the latest commit.



def aiter_mla_decode_fwd(
def aiter_mla_decode_forward(

Collaborator:

We can keep the name as _fwd (see below line 37 decode_fwd)

hongxiayang (Collaborator):

@houseroad I found that using the environment variables below fixes the huge TTFT issue described in the PR.

VLLM_ATTENTION_BACKEND=ROCM_AITER_MLA VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MOE=1  VLLM_ROCM_USE_AITER_MLA=1 

My command is below:

VLLM_ATTENTION_BACKEND=ROCM_AITER_MLA VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MOE=1  VLLM_ROCM_USE_AITER_MLA=1  vllm serve /amdhome/models/DeepSeek-R1 --trust-remote-code -tp 8 --max-model-len 32768 --block-size 1 --no-enable-prefix-caching --max-num-batched-tokens 32768 --max-num-seqs 1024

For input-len/output-len/concurrency/prompts of 1000/1000/1/2, the TTFT drops from 57419.91 to 154.8.

SageMoore (Contributor) left a comment:

Looks reasonable. Just a few nits and questions

from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func

def _flash_attn_varlen_diff_headdims(self,

Contributor:

Where is this used?

vllmellm (Contributor Author):

@SageMoore you may want to check common.py: the method _flash_attn_varlen_diff_headdims is defined there and overridden in this class.
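
For readers following along, a hypothetical sketch of the pattern being described: a base implementation in common.py provides _flash_attn_varlen_diff_headdims, and the AITER backend overrides it to dispatch to aiter's flash_attn_varlen_func. Class names and signatures here are simplified illustrations, not the actual vLLM definitions.

# Illustrative only: a base-class helper (as in common.py) overridden by the
# ROCm AITER backend. Signatures are simplified, not vLLM's actual API.
class MLACommonImplSketch:
    def _flash_attn_varlen_diff_headdims(self, q, k, v, **kwargs):
        # Default path, e.g. a bundled flash-attention kernel.
        raise NotImplementedError

class AiterMLAImplSketch(MLACommonImplSketch):
    def __init__(self):
        # The AITER kernel imported in the hunk quoted above.
        from aiter import flash_attn_varlen_func
        self.flash_attn_varlen_func = flash_attn_varlen_func

    def _flash_attn_varlen_diff_headdims(self, q, k, v, **kwargs):
        # Dispatch to the AITER varlen kernel instead of the default path.
        return self.flash_attn_varlen_func(q, k, v, **kwargs)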

Contributor:

I see now. I must have mistyped the string when I searched for it :).

assert max_model_len == 32768, \
    "AITER MLA requires max_model_len=32768"
assert self.runner.block_size == 1, "AITER MLA " \
    "requires only block size 1."

Contributor:

Nit: "only supports block size 1."

vllm/envs.py Outdated
lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")),

# Time in seconds for the model execution in ROCm platforms.
"VLLM_ROCM_EXECUTE_MODEL_TIMEOUT":

Contributor:

Let's remove this. See below comment.

vllmellm (Contributor Author) commented May 7, 2025:

@SageMoore At the moment we can't give a "safe" timeout: it depends on how many AITER kernels are enabled, since tracing the AITER JIT files can be time-consuming during execution, and it may also change as AITER ops evolve across future versions. So rather than a hardcoded timeout buried somewhere in the code that users would struggle to find, the environment variable lets them control this value directly.

vllmellm (Contributor Author) commented May 8, 2025:

@SageMoore The environment variable has been removed.

Collaborator:

@SageMoore : the env change is removed as we discussed. Please merge this asap if there are no other blockers.

# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk,
self.page_size)
if self.aot_schedule:

Contributor:

Could you explain this a bit? Why was this change necessary?

vllmellm (Contributor Author) commented May 7, 2025:

@SageMoore self.page_size is only defined in __init__ under the self.aot_schedule condition; on ROCm that condition is false, so the code hits an error that self.page_size is not defined.

self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
# Dont try to access the runner on AMD
if self.aot_schedule:
    self.page_size = self.runner.block_size

You may want to ask the author about this as these line changes were added in this PR.

In any case, if self.page_size is defined without the self.aot_schedule condition, it has no effect on ROCm, at least for AITER MLA, which is currently the only MLA backend on V1.
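
As a self-contained illustration of the two options (guarding the round_down call with aot_schedule, as this PR does, versus assigning page_size unconditionally), here is a toy sketch; the class and method names are made up for the example and are not the vLLM classes.

# Toy sketch of the two options discussed above; names are illustrative.
def round_down(x: int, multiple: int) -> int:
    return (x // multiple) * multiple

class MetadataBuilderSketch:
    def __init__(self, block_size: int, aot_schedule: bool):
        self.aot_schedule = aot_schedule
        # Assigning unconditionally is also harmless: when aot_schedule is
        # False (the ROCm/AITER path) the value is simply never used.
        self.page_size = block_size

    def clamp_chunk(self, max_context_chunk: int) -> int:
        if self.aot_schedule:
            # Only the FA3 AOT scheduler needs chunk starts aligned to pages.
            max_context_chunk = round_down(max_context_chunk, self.page_size)
        return max_context_chunk

print(MetadataBuilderSketch(16, aot_schedule=False).clamp_chunk(1000))  # 1000
print(MetadataBuilderSketch(16, aot_schedule=True).clamp_chunk(1000))   # 992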

mergify bot commented May 7, 2025:

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 7, 2025
@mergify mergify bot removed the needs-rebase label May 8, 2025

SageMoore (Contributor) left a comment:

Looks reasonable. Thanks for taking out the timeout changes!

@DarkLight1337 DarkLight1337 merged commit 3c9396a into vllm-project:main May 9, 2025
57 of 58 checks passed

chaunceyjiang (Collaborator):

It seems that this PR introduced a static check error.

https://github.com/vllm-project/vllm/actions/runs/14923384538/job/41922811885?pr=17845

Error: vllm/v1/attention/backends/mla/rocm_aiter_mla.py:98: error: Signature of "_build_decode" incompatible with supertype "MLACommonMetadataBuilder"  [override]
vllm/v1/attention/backends/mla/rocm_aiter_mla.py:98: note:      Superclass:
vllm/v1/attention/backends/mla/rocm_aiter_mla.py:98: note:          def _build_decode(self, block_table: Any, seq_lens: Any) -> Any
vllm/v1/attention/backends/mla/rocm_aiter_mla.py:98: note:      Subclass:
vllm/v1/attention/backends/mla/rocm_aiter_mla.py:98: note:          def _build_decode(self, input_positions: Any, block_table: Any, seq_lens: Any) -> AiterMLADecodeMetadata
Error: vllm/v1/attention/backends/mla/rocm_aiter_mla.py:108: error: Unexpected keyword argument "input_positions" for "AiterMLADecodeMetadata"  [call-arg]
Found 2 errors in 1 file (checked 83 source files)
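
For context, this is the standard mypy override rule: a subclass method that adds a required positional parameter is not substitutable for the base method, so it is flagged. A minimal reproduction with made-up class names (not the actual vLLM classes) is below; typical remedies are to match the parent's parameter list or to widen the base signature so both agree.

# Minimal reproduction of the [override] error above; names are illustrative.
from typing import Any

class BuilderBase:
    def _build_decode(self, block_table: Any, seq_lens: Any) -> Any:
        ...

class BuilderSub(BuilderBase):
    # mypy: Signature of "_build_decode" incompatible with supertype
    def _build_decode(self, input_positions: Any, block_table: Any,
                      seq_lens: Any) -> Any:
        ...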

princepride pushed a commit to princepride/vllm that referenced this pull request May 10, 2025
Signed-off-by: vllmellm <[email protected]>
Co-authored-by: qli88 <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Signed-off-by: 汪志鹏 <[email protected]>

tjtanaa (Contributor) commented May 10, 2025:

EXECUTE_MODEL_TIMEOUT_S

fixed by #17880

RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
tjtanaavllm pushed a commit to ROCm/vllm that referenced this pull request May 16, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025

Labels: ci/build, ready, rocm, v1
