Add FlexAttention to V1 #16078
Conversation
Amazingly exciting!
I will flesh out the PR summary. One thing I have found is that we need a newer version of PyTorch, since I have fixed a number of dynamic-shape issues since 2.6.0. On nightly PT:
ERROR 04-15 15:43:02 [core.py:387] File "/home/drisspg/.conda/envs/vllm_main/lib/python3.12/site-packages/torch/cuda/__init__.py", line 364, in _lazy_init
ERROR 04-15 15:43:02 [core.py:387] raise RuntimeError(
ERROR 04-15 15:43:02 [core.py:387] RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
ERROR 04-15 15:43:02 [core.py:387]
CRITICAL 04-15 15:43:02 [core_client.py:359] Got fatal signal from worker processes, shutting down. See stack trace above for root cause issue.
zsh: killed python vllm/flex.py
@WoosukKwon So we fixed a few more dynamic-shape bugs. I have been testing this against the latest PyTorch nightly, which might be why we are seeing different results. I am getting:
Results:
--------------------------------------------------------------------------------
Prompt: Hello, my name is
Generated: Kristi and I am so excited to be here today. I have been a homeschool mom for 15 years and am married to my best friend. We have 4 kids and live in Arkansas. I currently work for the Arkansas Dept. of Education as a teacher of special education.
--------------------------------------------------------------------------------
Prompt: The president of the United States is
Generated: responsible for the government of the United States. The president serves a four-year term and is elected by popular vote. The United States has a parliamentary government, which means that the president and the members of the legislative branch work together to create a law.
The United States was established as a republic in 1776. It is an example of a democracy, which means that the citizens of the United States choose their government. The president is the head of state. He or she is elected for a term of four years. The vice president is the second highest-ranking official in the United States. He or she is elected for a term of four
--------------------------------------------------------------------------------
Prompt: The capital of France is
Generated: Paris, the country’s fashion capital, home to some of the world’s most famous fashion houses. Travel to the south of France and visit the poppy fields of Provence and the French Riviera. Savor the culinary delights of Paris, watch the world’s greatest artists at the Louvre and enjoy a luxurious cruise along the Rhone River.
Welcome to Paris, the capital of France and a fashion capital of the world. Take some time to explore the city’s many museums, including the Louvre, the Musée d’Orsay and the Orangerie. You’ll also enjoy a cruise along the Seine River. Head
--------------------------------------------------------------------------------
Prompt: The future of AI is
Generated: in the hands of students
The future of artificial intelligence is in the hands of college students, according to IBM’s latest report on the state of artificial intelligence.
In the report, the IBM Watson Group released a survey of 1,500 college students from around the world, including students in the U.S. and the United Kingdom.
Students were asked to rate their level of interest in AI on a scale of one to five, with five being the highest.
Students were also asked how much they believed AI was a threat to their jobs, and how much they believed AI was a tool to help them find a new job.
The report found
--------------------------------------------------------------------------------
This looks good to me. I think we should ship this and continue iterating on it (it's already gotten pretty large, and we have a sense of the cases where it doesn't work yet).
@drisspg Is torch nightly required? I'm now seeing reasonable outputs with torch v2.7.0.
It's possible to run into a few dynamic-shape issues with 2.7.0 around max-autotune that we have since resolved, specifically this line: https://github.com/vllm-project/vllm/pull/16078/files#diff-0310608cf47330020e617d94f28ce469e6e802e291f33ce4bce90e22e11cc7e5R35 Ideally we would compile with max-autotune, but as implemented it should be fine with 2.7.
@WoosukKwon I will update. Someone suggested this name - it does seem a little redundant 😂
Baseline

❯ lm_eval --model vllm \
    --model_args pretrained=Qwen/Qwen3-8B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,num_gpu_blocks_override=512 \
    --tasks lambada_openai \
    --batch_size auto
...
INFO:lm_eval.loggers.evaluation_tracker:Output path not provided, skipping saving results aggregated
vllm (pretrained=Qwen/Qwen3-8B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,num_gpu_blocks_override=512), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|     Tasks    |Version|Filter|n-shot|  Metric  |   |Value |   |Stderr|
|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
|lambada_openai|      1|none  |     0|acc       |↑  |0.6503|±  |0.0066|
|              |       |none  |     0|perplexity|↓  |4.5999|±  |0.1390|

Flex

❯ VLLM_ATTENTION_BACKEND=FLEX_ATTENTION lm_eval --model vllm \
    --model_args pretrained=Qwen/Qwen3-8B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,num_gpu_blocks_override=512 \
    --tasks lambada_openai \
    --batch_size auto
...
INFO:lm_eval.loggers.evaluation_tracker:Output path not provided, skipping saving results aggregated
vllm (pretrained=Qwen/Qwen3-8B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,num_gpu_blocks_override=512), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|     Tasks    |Version|Filter|n-shot|  Metric  |   |Value |   |Stderr|
|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
|lambada_openai|      1|none  |     0|acc       |↑  |0.6522|±  |0.0066|
|              |       |none  |     0|perplexity|↓  |4.6014|±  |0.1390|
While not blocking, I placed several comments above 😄
@drisspg I'm seeing a 10x slowdown in end-to-end performance on the ShareGPT throughput benchmark (I'm using torch 2.7.0). Is this expected? Is this because re-compilation happens at run time?
@WoosukKwon Yup, as implemented this is expected and is touched on in the summary, but it is not due to recompilation. The problem is that flex_attention and create_block_mask consist of a lot of small CPU-bound operations needed to prep the metadata. Since Flex is currently wrapped in a custom op, we can't use compile to hide this overhead. In traces I am seeing a 2x slowdown for the actual attention calculation, which aligns with what I expected compared against FAv3. In a follow-up PR, figuring out how to enable direct-call with proper metadata should greatly help this problem.
This pull request has merge conflicts that must be resolved before it can be merged.
The idea looks great to me; I'll leave it to @WoosukKwon for final sign-off.
LGTM overall. Sorry for the delayed review 🙏
Looking forward to extending this to more diverse attention algorithms!
Also, I hope we can get the fix for the recompilation issue.
@drisspg Could you please check the failed CI tests and rebase the PR? I will merge once the CI gets green. :)
Summary
This PR adds FlexAttention as a new unified_attention backend for the V1 engine.
This requires torch >= 2.7, since we fixed a number of dynamic-shape issues that show up by default here.
Design
FlexAttention is broken up into two distinct phases: block mask creation and the call to forward. For most transformers, the N attention layers share a common attention pattern, so we can amortize the cost of block mask creation over the N layers. This lends itself nicely to the metadata builder for the unified attention op.
The majority of the work here is building the correct BlockMask.
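For intuition, here is a minimal, self-contained sketch of that two-phase split (not the PR's actual code; shapes, dtypes, and the causal mask_mod are illustrative): the BlockMask is built once per batch and then reused by every layer's call to flex_attention.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # Pure logical mask: a query may attend to its own and all earlier positions.
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 1024, 64
device = "cuda"

# Phase 1: build the BlockMask once (amortized over all N attention layers).
block_mask = create_block_mask(causal, B=B, H=None, Q_LEN=S, KV_LEN=S, device=device)

q = torch.randn(B, H, S, D, device=device, dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device=device, dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device=device, dtype=torch.bfloat16)

# Phase 2: every layer reuses the same block_mask in its forward call.
out = flex_attention(q, k, v, block_mask=block_mask)
```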
The current BlockTable is of the form:

This block table maps logical KV pages to physical KV pages in the paged KV cache. It has a size of MAX_REQS x (MAX_SEQ_LEN // PAGE_SIZE).
FlexAttention has no notion of a page table, so we have to build the inverse mapping from physical pages (the full paged KV cache is the input to the kernel) to logical indices. We then use these logical indices to determine whether we should compute attention for a given query/KV pair.
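A rough sketch of that inversion, with hypothetical names and sizes (not the PR's helpers): the block table is scattered into an inverse table, which the mask_mod later uses to convert a physical kv_idx into a logical token index.

```python
import torch

MAX_REQS, PAGES_PER_REQ, TOTAL_PHYSICAL_PAGES, PAGE_SIZE = 4, 8, 64, 16

# block_table[req, logical_page] = physical_page (-1 means unallocated).
block_table = torch.full((MAX_REQS, PAGES_PER_REQ), -1, dtype=torch.int64)
block_table[0, :3] = torch.tensor([7, 42, 13])  # request 0 owns three pages

# inverse_table[req, physical_page] = logical_page (-1 means "not owned by req").
inverse_table = torch.full((MAX_REQS, TOTAL_PHYSICAL_PAGES), -1, dtype=torch.int64)
req_idx, logical_idx = torch.nonzero(block_table >= 0, as_tuple=True)
inverse_table[req_idx, block_table[req_idx, logical_idx]] = logical_idx

# Inside the mask_mod, a physical kv_idx maps to a logical token index as
#   inverse_table[req, kv_idx // PAGE_SIZE] * PAGE_SIZE + kv_idx % PAGE_SIZE
# and attention is only computed when that index is valid and < seq_len.
```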
Setting up a Generic Physical-to-Logical Rewriter
Once we have this physical-to-logical map, we can abstract it away from the different logical mask_mods. We do this with this function: https://github.com/vllm-project/vllm/pull/16078/files#diff-0310608cf47330020e617d94f28ce469e6e802e291f33ce4bce90e22e11cc7e5R196
By default, this type of attention can be seen as a document-packed (variable sequence length) transformation, where all separate queries are splatted into one super sequence. We have a small lookup from the physical q_idx to the req_id.
We use this req_id to get the inverse page table, and with this inverse page table we restrict attention to valid sequences (valid blocks and indices below the current seq_len).
We then adjust the q_idx by the offset, which is 0 during prefill.
Once all that is done, we are left with a pure "logical mask mod", which by default (and for most models) will be the standard causal mask.
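The shape of that rewriter, sketched with hypothetical metadata tensors (doc_ids, q_starts, q_offsets, inverse_table, and seq_lens are illustrative names, not the PR's exact helpers), is roughly:

```python
import torch

PAGE_SIZE = 16

def make_physical_mask_mod(logical_mask_mod, doc_ids, q_starts, q_offsets,
                           inverse_table, seq_lens):
    """Wrap a pure logical mask_mod so it can be evaluated on physical indices."""
    def physical_mask_mod(b, h, q_idx, kv_idx):
        req = doc_ids[q_idx]                          # packed query row -> request id
        logical_page = inverse_table[req, kv_idx // PAGE_SIZE]
        logical_kv = logical_page * PAGE_SIZE + kv_idx % PAGE_SIZE
        valid = (logical_page >= 0) & (logical_kv < seq_lens[req])
        # Translate the packed query index into the request's logical position;
        # the extra offset accounts for already-cached tokens and is 0 during prefill.
        logical_q = q_idx - q_starts[req] + q_offsets[req]
        return valid & logical_mask_mod(b, h, logical_q, logical_kv)
    return physical_mask_mod

def causal(b, h, q_idx, kv_idx):
    # The pure "logical mask mod" that most models use by default.
    return q_idx >= kv_idx
```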
Adding New Variants
The nice thing about the above setup is that it makes adding new variants simpler, since we have a generic paged+packed rewriter, and users only need to register a simple logical mask mod here: https://github.com/drisspg/vllm/blob/891345dd545ad86ca57163f34d4ea7696610dea3/vllm/v1/attention/backends/flex_attention.py#L177
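As a hypothetical illustration (not something this PR adds), a sliding-window causal variant would only need the logical rule; the generic rewriter takes care of the paged/packed translation:

```python
WINDOW = 512  # illustrative window size

def sliding_window_causal(b, h, q_idx, kv_idx):
    # Attend only to keys within the last WINDOW positions, causally.
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW)
```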
For score_mods I didn't put much effort in, and they are mostly disabled for now, since that should be an easy follow-up. For instance, if you wanted to support tanh softcapping, we would just need to pass in a tanh_score_mod here: https://github.com/drisspg/vllm/blob/891345dd545ad86ca57163f34d4ea7696610dea3/vllm/v1/attention/backends/flex_attention.py#L433
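For reference, a tanh-softcapping score_mod is just a few lines; the cap value below is illustrative, and wiring it into the backend is the follow-up described above.

```python
import torch

SOFTCAP = 30.0  # illustrative; models that use softcapping define their own cap

def tanh_softcap_score_mod(score, b, h, q_idx, kv_idx):
    # Squash the pre-softmax logit into (-SOFTCAP, SOFTCAP).
    return SOFTCAP * torch.tanh(score / SOFTCAP)

# Usage (sketch):
# flex_attention(q, k, v, score_mod=tanh_softcap_score_mod, block_mask=block_mask)
```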
Performance Gaps (for now)
Trace for a baseline run with enforce_eager, but with the flex components compiled, for Qwen 1.5B with a full KV cache (2.84 million tokens) on a single GPU.
Trace: https://fburl.com/1uo5h5uy
TLP: https://fburl.com/z5u609g8
In this case we see 2 frames and 1 recompile for each of them. The recompiles come from the difference in query length between prefill and decode. This is exactly what mark_dynamic should be able to solve, but since we have a dynamic integer "Q_LEN" we need to use
TORCH_COMPILE_DYNAMIC_SOURCES
and I couldn't get that to work, so I'm punting for now and accepting the two recompiles. In this case it takes around 3.7 seconds before we settle into a steady state (for the given requests).

Let's zoom in on the steady state:

Example