
Conversation

@drisspg (Contributor) commented Apr 4, 2025

Summary

This PR adds FlexAttention as a new unified_attention backend for the V1 engine.

This requires torch 2.7 or newer, since we fixed a number of dynamic-shape issues that otherwise show up by default here.

Design

FlexAttention is broken up into two distinct phases: block mask creation and the call to forward. For most Transformers, the N attention layers share a common attention pattern, so we can amortize the cost of block mask creation across those N layers. This lends itself nicely to the Metadata Builder for the unified attention op.
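
As a stand-alone toy sketch of this amortization (illustrative only, not vLLM code), the BlockMask is built once from the shared mask_mod and then reused by every layer's flex_attention call:

# Toy example: the BlockMask depends only on the shared attention pattern,
# so it is built once (phase 1) and reused by all N layers (phase 2).
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D, N_LAYERS = 1, 8, 1024, 64, 4
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S)  # built once
out = q
for _ in range(N_LAYERS):  # reused by every layer
    out = flex_attention(out, k, v, block_mask=block_mask)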

The majority of the work here is building the correct BlockMask.

The current BlockTable (see the first diagram below) is a map from logical KV pages to physical KV pages in the paged KV cache. It has a size of MAX_REQS x (MAX_SEQ_LEN // PAGE_SIZE).

FlexAttention has no notion of a page table, so we have to build the inverse mapping from physical pages (the full paged KV cache is the input to the kernel) to logical indices. We then use these logical indices to determine whether we should compute attention for a given query x KV pair. A sketch of this inversion follows the memory notes below.

    Logical to Physical (Original block_table):
    ┌───────────────────────────────────────────┐
    │ Request 0:                                │
    │                                           │
    │ Logical Blocks:  0  1  2  3  4  5  6  7   │
    │                  │  │  │  │  │  │  │  │   │
    │                  v  v  v  v  v  v  v  v   │
    │ Physical Blocks: 3  5  1  7  4  2  0  6   │
    └───────────────────────────────────────────┘

    This function creates the inverse mapping:

    Physical to Logical (Inverse mapping):
    ┌───────────────────────────────────────────┐
    │ Request 0:                                │
    │                                           │
    │ Physical Blocks: 0  1  2  3  4  5  6  7   │
    │                  │  │  │  │  │  │  │  │   │
    │                  v  v  v  v  v  v  v  v   │
    │ Logical Blocks:  6  2  5  0  4  1  7  3   │
    └───────────────────────────────────────────┘
  • Uses more memory than the page table
  • Required memory: MAX_REQS × NUM_PAGES
  • For smaller models the number of pages can be much larger than any one request needs:
    • Number of pages can be up to 178,375: (total tokens) / default_page_size = 2,854,000 / 16
    • A typical max_seq_len of 2048 corresponds to only 2048 / 16 = 128 pages per request
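
A rough sketch of this inversion (a hypothetical helper, not the PR's exact code; shapes assumed as described above) could use a single scatter per request:

import torch

def invert_block_table(block_table: torch.Tensor,
                       num_allocated: torch.Tensor,
                       num_physical_blocks: int) -> torch.Tensor:
    """Return a [MAX_REQS, NUM_PAGES] map: physical block -> logical block (-1 = unused).

    block_table:   [MAX_REQS, MAX_SEQ_LEN // PAGE_SIZE], logical -> physical
    num_allocated: [MAX_REQS], number of valid block-table entries per request
    """
    max_reqs, max_logical = block_table.shape
    device = block_table.device
    # Extra trailing column acts as a trash slot for unallocated entries.
    inverse = torch.full((max_reqs, num_physical_blocks + 1), -1,
                         dtype=torch.long, device=device)
    logical = torch.arange(max_logical, device=device).expand(max_reqs, -1)
    valid = logical < num_allocated.unsqueeze(1)
    index = block_table.long().masked_fill(~valid, num_physical_blocks)
    inverse.scatter_(1, index, logical)  # inverse[req, physical] = logical
    return inverse[:, :num_physical_blocks]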

Setting up a Generic Physical-to-Logical Rewriter

Once we have this physical-to-logical map, we can abstract it away from the different logical mask_mods. We do this with this function: https://github.com/vllm-project/vllm/pull/16078/files#diff-0310608cf47330020e617d94f28ce469e6e802e291f33ce4bce90e22e11cc7e5R196

By default this type of attention can be seen as a document-packed (variable-sequence-length) transformation, where all of the separate queries are packed into one super sequence. We keep a small lookup from physical q_idx to req_id.

We use this req_id to get the inverse page table, and with it we restrict attention to valid positions (valid blocks whose logical index is < the current seq_len).

We then adjust the q_idx by the request's offset, which is 0 during prefill.

Once all that's done, we can have a pure "logical mask_mod", which by default (and for most models) is:

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx
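
Conceptually, the generic paged+packed rewriter wraps such a logical mask_mod roughly like this (a rough sketch with hypothetical tensor names; the PR's actual implementation is the linked function above):

def make_physical_mask_mod(logical_mask_mod, q_to_req, q_offsets,
                           inverse_block_table, seq_lens, page_size):
    # Rough sketch, hypothetical names:
    #   q_to_req:            physical q_idx -> request id
    #   q_offsets:           per-request adjustment mapping a packed physical
    #                        q_idx to the request's logical position (0 in prefill)
    #   inverse_block_table: [MAX_REQS, NUM_PAGES] physical block -> logical block
    #   seq_lens:            per-request current sequence length
    def physical_mask_mod(b, h, q_idx, kv_idx):
        req = q_to_req[q_idx]
        logical_block = inverse_block_table[req, kv_idx // page_size]
        is_valid_block = logical_block >= 0
        logical_kv = logical_block * page_size + kv_idx % page_size
        within_seq = logical_kv < seq_lens[req]
        logical_q = q_idx - q_offsets[req]
        return is_valid_block & within_seq & logical_mask_mod(b, h, logical_q, logical_kv)
    return physical_mask_mod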

Adding New Variants

The nice thing about the above setup is that it makes adding new variants simpler, since we have a generic paged+packed rewriter; users only need to register a simple logical mask_mod here: https://github.com/drisspg/vllm/blob/891345dd545ad86ca57163f34d4ea7696610dea3/vllm/v1/attention/backends/flex_attention.py#L177
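
For example, a hypothetical sliding-window causal variant is just another small logical mask_mod (the window size below is illustrative):

WINDOW_SIZE = 1024  # illustrative value

def sliding_window_causal(b, h, q_idx, kv_idx):
    # Causal, but only attend to the most recent WINDOW_SIZE tokens.
    return (q_idx >= kv_idx) & (q_idx - kv_idx < WINDOW_SIZE)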

For score_mods I didn't put in much effort and they are mostly disabled for now, since that should be an easy follow-up. For instance, to support tanh softcapping we would just need to pass in a tanh score_mod here: https://github.com/drisspg/vllm/blob/891345dd545ad86ca57163f34d4ea7696610dea3/vllm/v1/attention/backends/flex_attention.py#L433
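
Such a tanh score_mod would look roughly like this (the cap value is illustrative, not taken from the PR):

import torch

def make_tanh_softcap(cap: float = 30.0):
    def tanh_softcap(score, b, h, q_idx, kv_idx):
        # Squash raw attention scores into (-cap, cap) before the softmax.
        return cap * torch.tanh(score / cap)
    return tanh_softcap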

Performance Gaps (for now)

Trace for a baseline with enforce_eager but with the flex components compiled, for Qwen 1.5B with the full KV cache (2.84 million tokens) on a single GPU:

Trace: https://fburl.com/1uo5h5uy
TLP: https://fburl.com/z5u609g8

In this case we see 2 frames, each with 1 recompile. The recompiles come from the difference in query length between prefill and decode. This is exactly what mark_dynamic should be able to solve, but since we have a dynamic integer "Q_LEN" we need to use TORCH_COMPILE_DYNAMIC_SOURCES; I couldn't get this to work, so I'm punting for now and accepting the two recompiles.

It takes around 3.7 seconds before we settle into a steady state (for the given requests):
[screenshot: trace, 2025-04-22]

Let's zoom in on the steady state:
[screenshot: steady-state trace, 2025-04-23]

Example

# SPDX-License-Identifier: Apache-2.0

import os
import argparse
import random

import numpy as np
import torch

from vllm import LLM, SamplingParams


def set_seed(seed):
    """Set seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main():
    parser = argparse.ArgumentParser(description="Run vLLM with FlexAttention backend")
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B",
                        help="Model name or path")
    parser.add_argument("--eager", action="store_true",
                        help="Run in eager mode (no compilation)")
    parser.add_argument("--num-blocks", type=int, default=128,
                        help="Number of GPU blocks")
    parser.add_argument("--tp-size", type=int, default=1,
                        help="Tensor parallel size")
    parser.add_argument("--max-tokens", type=int, default=128,
                        help="Maximum tokens to generate")

    args = parser.parse_args()

    # Set environment variable for FlexAttention
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"

    # Set seed
    set_seed(42)

    # Sample prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create sampling params
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        seed=42,
        max_tokens=args.max_tokens
    )

    # Create LLM
    print(f"Loading model: {args.model}")
    print(f"Mode: {'eager' if args.eager else 'compile'}")

    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tp_size,
        enforce_eager=args.eager,
        num_gpu_blocks_override=args.num_blocks,
    )

    # Generate
    print("\nGenerating responses...")
    outputs = llm.generate(prompts, sampling_params)

    # Print outputs
    print("\nResults:")
    print("-" * 80)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt}")
        print(f"Generated: {generated_text}")
        print("-" * 80)


if __name__ == "__main__":
    main()

@drisspg changed the title from "it errors" to "[WIP] Add Flex to V1" on Apr 4, 2025

github-actions bot commented Apr 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify bot added the "documentation" (Improvements or additions to documentation) and "v1" labels on Apr 4, 2025
@drisspg force-pushed the flex-attention branch 2 times, most recently from d173560 to 0420216 on April 4, 2025 23:27
@robertgshaw2-redhat (Collaborator) commented:

Amazingly exciting!

@drisspg force-pushed the flex-attention branch 11 times, most recently from 8123983 to 0460b1b on April 10, 2025 22:47
@drisspg (Contributor, Author) commented Apr 15, 2025

I will flesh out the PR summary. One thing I have found is that we need a newer version of PyTorch, since I have fixed a number of dynamic-shape issues since 2.6.0. On nightly PyTorch, VLLM_ENABLE_V1_MULTIPROCESSING=1 seems to be failing for both the flex and default backends:

ERROR 04-15 15:43:02 [core.py:387]   File "/home/drisspg/.conda/envs/vllm_main/lib/python3.12/site-packages/torch/cuda/__init__.py", line 364, in _lazy_init
ERROR 04-15 15:43:02 [core.py:387]     raise RuntimeError(
ERROR 04-15 15:43:02 [core.py:387] RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
ERROR 04-15 15:43:02 [core.py:387] 
CRITICAL 04-15 15:43:02 [core_client.py:359] Got fatal signal from worker processes, shutting down. See stack trace above for root cause issue.
zsh: killed     python vllm/flex.py

@drisspg (Contributor, Author) commented May 29, 2025

@WoosukKwon So we fixed a few more dynamic-shape bugs. I have been testing this against the latest PyTorch nightly, which might be why we are seeing different results. I am getting:

Results:
--------------------------------------------------------------------------------
Prompt: Hello, my name is
Generated:  Kristi and I am so excited to be here today. I have been a homeschool mom for 15 years and am married to my best friend. We have 4 kids and live in Arkansas. I currently work for the Arkansas Dept. of Education as a teacher of special education.
--------------------------------------------------------------------------------
Prompt: The president of the United States is
Generated:  responsible for the government of the United States. The president serves a four-year term and is elected by popular vote. The United States has a parliamentary government, which means that the president and the members of the legislative branch work together to create a law.
The United States was established as a republic in 1776. It is an example of a democracy, which means that the citizens of the United States choose their government. The president is the head of state. He or she is elected for a term of four years. The vice president is the second highest-ranking official in the United States. He or she is elected for a term of four
--------------------------------------------------------------------------------
Prompt: The capital of France is
Generated:  Paris, the country’s fashion capital, home to some of the world’s most famous fashion houses. Travel to the south of France and visit the poppy fields of Provence and the French Riviera. Savor the culinary delights of Paris, watch the world’s greatest artists at the Louvre and enjoy a luxurious cruise along the Rhone River.
Welcome to Paris, the capital of France and a fashion capital of the world. Take some time to explore the city’s many museums, including the Louvre, the Musée d’Orsay and the Orangerie. You’ll also enjoy a cruise along the Seine River. Head
--------------------------------------------------------------------------------
Prompt: The future of AI is
Generated:  in the hands of students
The future of artificial intelligence is in the hands of college students, according to IBM’s latest report on the state of artificial intelligence.
In the report, the IBM Watson Group released a survey of 1,500 college students from around the world, including students in the U.S. and the United Kingdom.
Students were asked to rate their level of interest in AI on a scale of one to five, with five being the highest.
Students were also asked how much they believed AI was a threat to their jobs, and how much they believed AI was a tool to help them find a new job.
The report found
--------------------------------------------------------------------------------

@zou3519 (Collaborator) left a comment:

This looks good to me. I think we should ship this and continue iterating on it (it's already gotten pretty large, and we have a sense of the cases where it doesn't work yet).

@WoosukKwon (Collaborator) commented:

> I have been testing this against latest pytorch nightly

@drisspg Is torch nightly required? I'm now seeing reasonable outputs with torch v2.7.0.

@WoosukKwon added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on May 30, 2025
@drisspg (Contributor, Author) commented May 30, 2025

It's possible to run into a few dynamic-shape issues with 2.7.0 around max-autotune that we have since resolved, specifically this line: https://github.com/vllm-project/vllm/pull/16078/files#diff-0310608cf47330020e617d94f28ce469e6e802e291f33ce4bce90e22e11cc7e5R35. Ideally we would compile with max-autotune, but as implemented it should be fine with 2.7.

@drisspg (Contributor, Author) commented Jun 1, 2025

@WoosukKwon I will update; someone suggested this name, and it does seem a little redundant 😂

@drisspg (Contributor, Author) commented Jun 2, 2025

Baseline

❯ lm_eval --model vllm \
--model_args pretrained=Qwen/Qwen3-8B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,num_gpu_blocks_override=512 \
--tasks lambada_openai \
--batch_size auto

...

INFO:lm_eval.loggers.evaluation_tracker:Output path not provided, skipping saving results aggregated
vllm (pretrained=Qwen/Qwen3-8B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,num_gpu_blocks_override=512), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|    Tasks     |Version|Filter|n-shot|  Metric  |   |Value |   |Stderr|
|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
|lambada_openai|      1|none  |     0|acc       ||0.6503|±  |0.0066|
|              |       |none  |     0|perplexity||4.5999|±  |0.1390|

Flex

 VLLM_ATTENTION_BACKEND=FLEX_ATTENTION lm_eval --model vllm \
--model_args pretrained=Qwen/Qwen3-8B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,num_gpu_blocks_override=512 \
--tasks lambada_openai \
--batch_size auto

...
INFO:lm_eval.loggers.evaluation_tracker:Output path not provided, skipping saving results aggregated
vllm (pretrained=Qwen/Qwen3-8B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,num_gpu_blocks_override=512), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|    Tasks     |Version|Filter|n-shot|  Metric  |   |Value |   |Stderr|
|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
|lambada_openai|      1|none  |     0|acc       ||0.6522|±  |0.0066|
|              |       |none  |     0|perplexity||4.6014|±  |0.1390|

@youkaichao (Member) left a comment:

While not blocking, I placed several comments above 😄

@WoosukKwon (Collaborator) left a comment:

@drisspg I'm seeing a 10x slowdown in e2e performance on the ShareGPT throughput benchmark (I'm using torch 2.7.0). Is this expected? Is it because re-compilation happens at run time?

@drisspg (Contributor, Author) commented Jun 3, 2025

@WoosukKwon Yup, as implemented this is expected and touched upon in the summary, but it is not due to recompilation. The problem is that flex_attention and create_block_mask consist of a lot of small CPU-bound operations to prep the metadata. Since flex is currently wrapped in a custom op, we can't use compile to hide this overhead. In traces I am seeing a 2x slowdown for the actual attention calculation, which aligns with what I expected when comparing against FAv3.

In a follow-up PR, figuring out how to enable direct calls with proper metadata should help greatly with this problem.


mergify bot commented Jun 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @drisspg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 4, 2025
@youkaichao (Member) left a comment:

The idea looks great to me; leaving it to @WoosukKwon for final sign-off.

@mergify mergify bot removed the needs-rebase label Jun 4, 2025
@WoosukKwon (Collaborator) left a comment:

LGTM overall. Sorry for the delayed review 🙏
Looking forward to extending this to more diverse attention algorithms!
Also, I hope we can get a fix for the recompilation issue.

@WoosukKwon (Collaborator) commented:

@drisspg Could you please check the failed CI tests and rebase the PR? Will merge once the CI gets green. :)

@drisspg force-pushed the flex-attention branch 2 times, most recently from b5b2f14 to 5d2556d on June 7, 2025 00:39
@WoosukKwon WoosukKwon merged commit cf02f9b into vllm-project:main Jun 7, 2025
68 of 70 checks passed