[comm] TRT-LLM's Multi-Node NVLink All-Reduce Kernel #1213


Merged
merged 6 commits into flashinfer-ai:main on Jul 11, 2025

Conversation

nvmbreughe
Contributor

@nvmbreughe nvmbreughe commented Jul 4, 2025

📌 Description

This PR adds TRT-LLM's Multi-Node NVLink (MNNVL) All-Reduce kernel, along with the code needed to set up the multicast and unicast buffers on the GPU.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • [x] I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • [x] I have installed the hooks with pre-commit install.
  • [x] I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

This test requires multiple nodes and can be run as follows:
srun -N2 --container-image=<your_aarch64_container> --mpi=pmix --container-name=mb_flashinfer --container-mounts=<parentdir_to_flash_infer>:/home -- bash -c 'cd /home/flashinfer && python -m tests.test_trtllm_mnnvl_allreduce'
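
For readers without access to a multi-node NVLink system, the rough shape of such a test is sketched below. This is not the actual test file: MPI's own Allreduce stands in for the kernel under test, and only the per-rank input setup and the torch.sum reference check are kept.

```python
# Illustrative sketch only; the real test calls the new flashinfer MNNVL
# all-reduce kernel instead of MPI.Allreduce. Run with: mpirun -n 2 python sketch.py
import numpy as np
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, world_size = comm.Get_rank(), comm.Get_size()

# Every rank draws the same full input (same seed) and keeps only its own shard.
torch.manual_seed(0)
full = torch.randn(world_size, 1024, dtype=torch.float32)
local = full[rank].contiguous().numpy()

# Reference: an all-reduce result is simply the sum over every rank's shard.
reference = full.sum(dim=0)

# Stand-in for the kernel under test.
result = np.empty_like(local)
comm.Allreduce(local, result, op=MPI.SUM)

torch.testing.assert_close(torch.from_numpy(result), reference, rtol=1e-5, atol=1e-5)
```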

Reviewer Notes

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @nvmbreughe, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement to FlashInfer's communication capabilities by integrating a high-performance Multi-Node NVLink All-Reduce kernel. This enables efficient distributed computations across multiple GPUs and nodes, crucial for scaling large language models. The changes encompass the low-level CUDA kernel implementation, sophisticated Python-based memory management for multi-node GPU communication, and a robust testing framework to ensure correctness and reliability.

Highlights

  • New Multi-Node All-Reduce Kernel: I've added the core Multi-Node NVLink All-Reduce kernel, adapted from TRT-LLM, implemented in CUDA C++. This kernel (twoshot_allreduce_kernel) performs a two-shot all-reduce operation, handling data scattering, reduction, and broadcasting across multiple GPUs and nodes using a combination of unicast and multicast memory operations. It supports various floating-point types (float, half, bfloat16) and is optimized for specific world sizes (2, 4, 8, 16, 32, 64). (A simplified numerical illustration of the two-shot scheme follows this list.)
  • Multi-Node NVLink Memory Management: New Python classes, McastDeviceMemory and McastGPUBuffer, have been introduced to manage the complex memory setup required for Multi-Node NVLink. These classes leverage the CUDA Driver API (pynvml, cuda module) to allocate, export, import, map, and bind shared memory regions (both unicast and multicast) across different GPUs and nodes, including handling fabric handles and signal pads for synchronization.
  • PyTorch Integration and Workspace Management: The new all-reduce functionality is seamlessly integrated into the PyTorch ecosystem as a custom operation. A dedicated Python module (flashinfer/comm/trtllm_mnnvl_ar.py) handles JIT compilation of the CUDA kernel, registers the trtllm_mnnvl_all_reduce function as a PyTorch op, and manages the lifecycle of necessary workspace buffers (McastGPUBuffer, buffer_mnnvl, buffer_flags_mnnvl) using thread-local storage for efficient reuse.
  • Multi-Node Testing Framework: A new multi-node test (tests/test_trtllm_mnnvl_allreduce.py) has been added to validate the correctness of the TRT-LLM Multi-Node NVLink All-Reduce kernel. This test utilizes mpi4py to simulate a distributed environment, sets up the necessary communication mapping, and verifies the all-reduce results against a reference torch.sum operation, ensuring proper functionality in a multi-GPU, multi-node setup.
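
As promised above, here is a simplified, single-process numerical illustration of the two-shot scheme as described: shot one reduces each rank's assigned slice of every rank's input, shot two makes all reduced slices visible to every rank. The real kernel performs these steps with NVLink unicast writes and multicast memory operations; here plain torch tensors stand in for the per-rank buffers.

```python
# Toy simulation of a two-shot all-reduce (math only, no communication).
import torch

world_size = 4
num_elems = 16                    # divisible by world_size in this toy setup
inputs = [torch.randn(num_elems) for _ in range(world_size)]  # one tensor per simulated rank

# Shot 1 (reduce-scatter): rank r reduces slice r taken from every rank's input.
slices = [inp.chunk(world_size) for inp in inputs]
reduced = [sum(slices[src][r] for src in range(world_size)) for r in range(world_size)]

# Shot 2 (broadcast / all-gather): every rank ends up with all reduced slices.
outputs = [torch.cat(reduced) for _ in range(world_size)]

reference = torch.stack(inputs).sum(dim=0)
for out in outputs:
    torch.testing.assert_close(out, reference)
```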

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a multi-node NVLink All-Reduce kernel from TensorRT-LLM, along with the necessary Python-side infrastructure for memory management and kernel invocation. The changes are extensive, touching upon low-level CUDA memory management, MPI communication, and PyTorch C++ extensions.

Overall, the implementation of the CUDA kernel and its Python bindings seems solid and follows established patterns for high-performance computing. However, I've identified a critical resource leak in the CUDA memory management within flashinfer/comm/mnnvl.py, which must be addressed. There are also a few medium-severity issues related to maintainability and error handling that should be considered.

Once the critical and high-severity issues are resolved, this will be a valuable addition for multi-node performance.
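
Regarding the resource-leak remark above, the usual way to rule out such leaks is to give each driver-level allocation a single owner that releases it on every code path. The sketch below shows only that generic pattern; allocate()/release() are placeholders, not the actual CUDA driver calls used in mnnvl.py.

```python
# Generic ownership sketch: allocate()/release() stand in for a paired
# create/map vs. unmap/release sequence at the driver level.
def allocate(size: int):
    return object()            # placeholder handle

def release(handle) -> None:
    pass                       # placeholder release


class ManagedBuffer:
    """Owns one handle and guarantees it is released exactly once."""

    def __init__(self, size: int):
        self._handle = allocate(size)

    def close(self) -> None:
        if self._handle is not None:
            release(self._handle)
            self._handle = None

    def __enter__(self):
        return self

    def __exit__(self, *exc) -> None:
        self.close()           # runs even if the with-body raised

    def __del__(self):
        self.close()           # best-effort fallback at garbage collection
```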

@yzh119 yzh119 changed the title from "We add TRT-LLM's Multi-Node NVLink All-Reduce Kernel" to "[comm] TRT-LLM's Multi-Node NVLink All-Reduce Kernel" on Jul 5, 2025
@yyihuang yyihuang self-requested a review July 8, 2025 02:10
Collaborator

@yzh119 yzh119 left a comment


LGTM, the docker image has been updated in https://github.com/flashinfer-ai/flashinfer/actions/runs/16137362896
and I have appended an empty commit to this PR to trigger CI with the latest docker image.

@yzh119
Collaborator

yzh119 commented Jul 8, 2025

cc @yyihuang for another look at this PR before we merge it.

@@ -30,6 +30,12 @@
from .trtllm_ar import (
trtllm_moe_finalize_allreduce_fusion as trtllm_moe_finalize_allreduce_fusion,
)
from .trtllm_mnnvl_ar import (
gen_trtllm_mnnvl_comm_module,
Collaborator


Should we include mnnvl comm in aot build?

Contributor Author


In the past @yzh119 suggested not to do this yet. Maybe we can do it as part of another PR?

Collaborator


Not necessary at this moment.

Segfault was coming from incorrect usage of the unicast buffers during
initialization.

Reviewers advised against setting env variables and suggested another
way (monkeypatch). This is now implemented.
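
For context, the monkeypatch approach referred to here is pytest's fixture for scoping such changes to a single test, rather than exporting process-wide environment variables. A minimal, hypothetical example (the flag name is made up):

```python
import os

def test_uses_scoped_env(monkeypatch):
    # Visible only inside this test; pytest restores the environment afterwards,
    # unlike exporting the variable for the whole process or CI job.
    monkeypatch.setenv("HYPOTHETICAL_MNNVL_TEST_FLAG", "1")
    assert os.environ["HYPOTHETICAL_MNNVL_TEST_FLAG"] == "1"

    # The same fixture can patch module attributes instead of environment variables:
    #   monkeypatch.setattr(some_module, "some_setting", new_value)
```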
wait_for_results: bool,
launch_with_pdl: bool,
) -> None:
"""MNNVL all-reduce operation"""
Collaborator


I think this is the main public API, right?
This probably deserves a complete description, including the expected parameters and the invariants before/after the call.
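
For illustration, the kind of complete description being requested might look roughly like the sketch below. Only wait_for_results and launch_with_pdl appear in the diff above; every other parameter name, and the described semantics, are invented for this example.

```python
# Parameter names other than wait_for_results / launch_with_pdl are hypothetical.
def trtllm_mnnvl_all_reduce(
    inp,
    multicast_buffer_ptr,
    buffer_ptrs_dev,
    buffer_flags,
    wait_for_results: bool,
    launch_with_pdl: bool,
) -> None:
    """Perform a multi-node NVLink (MNNVL) all-reduce on ``inp``.

    Must be called collectively by all ranks with tensors of identical shape
    and dtype, after the MNNVL workspace (unicast/multicast buffers and the
    buffer-flags tensor) has been set up.

    Args:
        inp: Local input tensor to be reduced (hypothetical name).
        multicast_buffer_ptr: Device pointer to the multicast buffer (hypothetical).
        buffer_ptrs_dev: Per-rank unicast buffer pointers (hypothetical).
        buffer_flags: Synchronization flags for the two-shot protocol (hypothetical).
        wait_for_results: Presumably, block until the reduced result is available locally.
        launch_with_pdl: Presumably, launch using programmatic dependent launch (PDL).

    Invariant: after every rank returns with wait_for_results=True, each rank
    observes the elementwise sum of all ranks' inputs.
    """
    ...
```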

Contributor Author


Thanks for pointing this out! The documentation must have been removed during my cleanup; I've added an improved version now.

There is another main API: the workspace setup. It now has documentation as well.

* Documentation to the main APIs
* Removed __main__ from test file
* Added dtype check to lamport_initialize
@yzh119
Collaborator

yzh119 commented Jul 11, 2025

I think it's ready to be merged, thank you @nvmbreughe !

@yzh119 yzh119 merged commit a03c290 into flashinfer-ai:main Jul 11, 2025
2 checks passed