
Conversation

NickLucche
Collaborator

NickLucche commented Jun 27, 2025

This PR enables the use of FlashInfer in a heterogeneous TP setting when using NixlConnector. This is particularly important for Blackwell systems, since they will default to FlashInfer.

The main difference from FA is that the cache layout goes from [2, num_blocks, HND] to [num_blocks, 2, HND], where the 2 is the K/V dimension.
With homogeneous TP, this layout change has no downside; quite the contrary, we can read both K and V in a single message (of doubled size).


In heterogeneous TP, we need to read a portion of the heads (tp_ratio, e.g. half the heads). With FA we can do that efficiently by leveraging the HND layout: we can simply say e.g. "read cache[:2, :, :]" for worker 1 and "read cache[2:, :, :]" for worker 2, indexing on H.

With FlashInfer, K and V are not interleaved per-head in memory, since the "2" dim now sits right before HND.
Attempting to read e.g. "half" of a block's KV cache would read all of K, rather than half the heads of both K and V.
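
To make this concrete, here is a small illustrative sketch (toy shapes and names, not code from this PR) showing that grabbing the first half of a FlashInfer-layout block selects all of K rather than half the heads of both K and V:

import torch

num_blocks, num_heads, block_size, head_dim = 4, 8, 16, 64

# FA-style layout: [2, num_blocks, H, N, D]; each block region holds only K or V,
# so heterogeneous TP can slice on H directly:
fa_cache = torch.zeros(2, num_blocks, num_heads, block_size, head_dim)
fa_half_heads = fa_cache[:, :, : num_heads // 2]   # half the heads of both K and V

# FlashInfer-style layout: [num_blocks, 2, H, N, D]; a block's bytes are [K(HND), V(HND)],
# so taking the first half of the block region by address amounts to:
fi_cache = torch.zeros(num_blocks, 2, num_heads, block_size, head_dim)
fi_first_half = fi_cache[:, 0]                      # all of K, none of V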

To address that, this PR adds a virtual split: when FlashInfer is detected, K/V addresses are alternated just like with FA when creating descriptors. This allows us to reuse the same logic as before for getting block_ids, while retaining the memory-registration advantage, effectively registering just num_layers regions in NIXL, down from 2*num_layers for FA.

TL;DR: K/V must be alternated when reading to maintain consistency with FA. The number of regions actually registered is half of that of FA, but the number of descs is the same, so at the logical level you won't notice a difference.
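
To sketch the resulting per-layer descriptor layout (illustrative variable names, not the exact connector code):

# One NIXL region per layer covers the whole [num_blocks, 2, HND] tensor.
# Descriptors are emitted K/V-interleaved so the ordering matches FA:
half_block_len = block_len_bytes // 2  # bytes of K (or V) within one block
blocks_data = []
for block_id in range(num_blocks):
    k_addr = base_addr + block_id * block_len_bytes
    v_addr = k_addr + half_block_len   # V sits right after K inside the block
    blocks_data.append((k_addr, half_block_len, tp_rank))  # K_i
    blocks_data.append((v_addr, half_block_len, tp_rank))  # V_i
# Resulting descs: [K0, V0, K1, V1, ...] -- same logical ordering as FA,
# but built from a single registered region per layer instead of two.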

Test with

# Flashinfer run
VLLM_ATTENTION_BACKEND=FLASHINFER NUM_DECODE_INSTANCES=1 DECODER_TP_SIZE=2 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
./script.sh  1429.72s user 106.44s system 580% cpu 4:24.47 total

# FlashAttention run
NUM_DECODE_INSTANCES=1 DECODER_TP_SIZE=2 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
 ./script.sh  1450.54s user 107.23s system 584% cpu 4:26.68 total

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

gemini-code-assist bot left a comment

Summary of Changes

Hello @NickLucche, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for FlashInfer in a heterogeneous Tensor Parallel (TP) environment when utilizing the NixlConnector. The primary goal is to reconcile the differing KV cache memory layouts between FlashInfer and FlashAttention, enabling efficient and consistent block-level data transfer for distributed inference by introducing a virtual splitting mechanism for FlashInfer's KV cache.

Highlights

  • FlashInfer Heterogeneous TP Support: Enabled FlashInfer to work efficiently in heterogeneous Tensor Parallel (TP) settings by addressing its distinct KV cache memory layout compared to FlashAttention. This ensures consistent block-level data transfer for distributed inference.
  • Virtual KV Cache Splitting: Implemented a 'virtual split' mechanism for FlashInfer's KV cache. This involves doubling the logical number of regions and interleaving K and V block addresses during descriptor registration, allowing for separate indexing of K and V components while retaining FlashInfer's memory registration advantages.
  • Dynamic Block Length Calculation: Introduced a new helper method, get_backend_aware_block_len, to dynamically determine the effective block length based on the attention backend. For FlashInfer, this function returns half the block_len to facilitate separate K and V indexing (see the sketch after this list).
  • Updated Descriptor Registration Logic: Modified the register_kv_caches and add_remote_agent methods to correctly register and retrieve block descriptors, accounting for the new virtual K/V splitting and adjusted block lengths specific to FlashInfer.
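
For reference, a minimal sketch of what such a helper could look like (reconstructed from the review discussion below; the actual implementation may differ):

def get_backend_aware_block_len(self) -> int:
    # FlashInfer packs K and V back-to-back within a block ([num_blocks, 2, HND]),
    # so K and V are addressed separately at half the block length.
    if self._use_flashinfer:
        return self.block_len // 2
    return self.block_len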

Contributor

gemini-code-assist bot left a comment

Code Review

This PR aims to add support for FlashInfer in a heterogeneous Tensor Parallelism setting within the NixlConnector. The core idea is to create a "virtual" split for K/V caches to align with FlashAttention's behavior, which is a clever approach.

The changes look mostly correct and well-reasoned. I've identified a couple of critical issues related to descriptor creation logic that could lead to incorrect data transfers, and a minor issue with a data type change.

  • Descriptor Interleaving: The logic for creating blocks_data for FlashInfer in both register_kv_caches and add_remote_agent seems to produce a non-interleaved list of K and V blocks. This contradicts the logic in _get_block_descs_ids, which expects an interleaved layout. This needs to be fixed to ensure correct KV cache transfers.
  • Data Type: There's a change from integer division to float division for remote_block_size, which could cause issues. This variable also appears to be unused.

Once these points are addressed, the PR should be in good shape. Great work on tackling this complex integration!

Comment on lines 639 to 646
if self._use_flashinfer:
    # To maintain the same descs ordering, K/V must be interleaved.
    for block_id in range(self.num_blocks):
        block_offset = block_id * self.block_len
        addr = base_addr + block_offset
        # Register addresses for V cache.
        v_addr = addr + block_len
        blocks_data.append((v_addr, block_len, self.tp_rank))
Contributor

critical

The current implementation for creating blocks_data for FlashInfer seems to be incorrect. The preceding loop (outside this diff) adds all K blocks, and this new block adds all V blocks. This results in a [K0..Kn, V0..Vn] layout for each layer's descriptors.

However, the existing logic in _get_block_descs_ids for FlashInfer expects an interleaved layout of [K0, V0, K1, V1, ...]. This mismatch will likely cause incorrect data to be transferred.

To fix this, the descriptor creation for K and V should be interleaved within a single loop. You'll need to modify the loop at line 631 to handle both FlashInfer and other backends correctly.

For example (conceptual):

if self._use_flashinfer:
    # Interleave K and V block registrations: [K0, V0, K1, V1, ...]
    for block_id in range(self.num_blocks):
        block_offset = block_id * self.block_len
        k_addr = base_addr + block_offset
        v_addr = k_addr + block_len  # V sits right after K within the block
        blocks_data.append((k_addr, block_len, self.tp_rank))  # K
        blocks_data.append((v_addr, block_len, self.tp_rank))  # V
else:
    # Original logic for other backends
    for block_id in range(self.num_blocks):
        block_offset = block_id * self.block_len
        addr = base_addr + block_offset
        blocks_data.append((addr, self.block_len, self.tp_rank))

@NickLucche
Collaborator Author

@wseaton already tested this on B200, thanks!

@NickLucche
Collaborator Author

@tlrmchlsmth

Member

tlrmchlsmth left a comment

Took me a minute to figure out what was going on; this could use a high-level comment explaining that we're halving the blocks in order to separate the K and V blocks into separate descs.

# of 'virtual' regions here and double the descs below.
self.num_regions *= 2

block_len = self.get_backend_aware_block_len()
Member

Rare instance where I think the helper method actually makes it a little harder to understand what's going on.

Suggested change
block_len = self.get_backend_aware_block_len()
block_len = self.block_len // 2

Also there's a landmine here that should be managed -- someone may try to use self.block_len instead of block_len and run into issues here. Does it make sense to halve self.block_len in the flashinfer case instead?

@NickLucche
Collaborator Author

@tlrmchlsmth I've made minor modifications to try to improve clarity.

NickLucche requested a review from tlrmchlsmth July 11, 2025 14:46
NickLucche force-pushed the heterogenous-tp-flashinfer2 branch from 9f2bd21 to 3631794 July 25, 2025 16:36
@NickLucche
Collaborator Author

I've also tested this PR with the upcoming num_blocks <--> 2 (KV) dim swap in #21607. The new default layout is going to be FlashInfer's.

Collaborator

LucasWilkinson left a comment

LGTM; thanks! We should think about whether to omit the separate V-indexing in the non-hetero-TP case.

LucasWilkinson enabled auto-merge (squash) August 18, 2025 16:29
github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Aug 18, 2025

mergify bot commented Aug 22, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Aug 22, 2025
auto-merge was automatically disabled August 22, 2025 15:59

Head branch was pushed to by a user without write access

NickLucche force-pushed the heterogenous-tp-flashinfer2 branch from 3631794 to 0a8de96 August 22, 2025 15:59
mergify bot removed the needs-rebase label Aug 22, 2025
NickLucche force-pushed the heterogenous-tp-flashinfer2 branch from 0a8de96 to c3e2c88 August 25, 2025 13:05
@NickLucche
Collaborator Author

@LucasWilkinson is CI still broken? :(

[2025-08-25T15:19:29Z] FAILED distributed/test_comm_ops.py::test_multi_process_pipeline_parallel[test_target0-2] - ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment.

DarkLight1337 merged commit f0c503f into vllm-project:main Sep 3, 2025
43 checks passed
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025

Labels

ready (ONLY add when PR is ready to merge/full CI is needed)

4 participants