[comm] TRT-LLM's Multi-Node NVLink All-Reduce Kernel #1213
Conversation
Summary of Changes
Hello @nvmbreughe, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant enhancement to FlashInfer's communication capabilities by integrating a high-performance Multi-Node NVLink All-Reduce kernel. This enables efficient distributed computations across multiple GPUs and nodes, crucial for scaling large language models. The changes encompass the low-level CUDA kernel implementation, sophisticated Python-based memory management for multi-node GPU communication, and a robust testing framework to ensure correctness and reliability.
Highlights
- New Multi-Node All-Reduce Kernel: I've added the core Multi-Node NVLink All-Reduce kernel, adapted from TRT-LLM and implemented in CUDA C++. This kernel (`twoshot_allreduce_kernel`) performs a two-shot all-reduce operation, handling data scattering, reduction, and broadcasting across multiple GPUs and nodes using a combination of unicast and multicast memory operations. It supports various floating-point types (float, half, bfloat16) and is optimized for specific world sizes (2, 4, 8, 16, 32, 64); a conceptual sketch of the two-shot pattern follows this list.
- Multi-Node NVLink Memory Management: New Python classes, `McastDeviceMemory` and `McastGPUBuffer`, have been introduced to manage the complex memory setup required for Multi-Node NVLink. These classes leverage the CUDA Driver API (via the `pynvml` and `cuda` modules) to allocate, export, import, map, and bind shared memory regions (both unicast and multicast) across different GPUs and nodes, including handling fabric handles and signal pads for synchronization.
- PyTorch Integration and Workspace Management: The new all-reduce functionality is integrated into the PyTorch ecosystem as a custom operation. A dedicated Python module (`flashinfer/comm/trtllm_mnnvl_ar.py`) handles JIT compilation of the CUDA kernel, registers the `trtllm_mnnvl_all_reduce` function as a PyTorch op, and manages the lifecycle of the necessary workspace buffers (`McastGPUBuffer`, `buffer_mnnvl`, `buffer_flags_mnnvl`) using thread-local storage for efficient reuse.
- Multi-Node Testing Framework: A new multi-node test (`tests/test_trtllm_mnnvl_allreduce.py`) has been added to validate the correctness of the TRT-LLM Multi-Node NVLink All-Reduce kernel. The test uses `mpi4py` to simulate a distributed environment, sets up the necessary communication mapping, and verifies the all-reduce results against a reference `torch.sum` operation, ensuring proper functionality in a multi-GPU, multi-node setup.
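For readers new to the two-shot pattern, here is a minimal conceptual sketch of its semantics on CPU tensors. The function name and structure are purely illustrative and do not reflect the actual CUDA kernel or Python API in this PR, which operates on multicast/unicast GPU buffers across ranks.

```python
import torch

def two_shot_allreduce(local_tensors: list[torch.Tensor]) -> list[torch.Tensor]:
    """Conceptual single-process reference for the two-shot pattern."""
    world_size = len(local_tensors)
    # Each rank's tensor is split into world_size equal slices.
    slices = [t.chunk(world_size) for t in local_tensors]

    # Shot 1 (reduce-scatter): rank r reduces slice r gathered from all peers.
    # In the real kernel this happens via unicast writes into peer buffers.
    reduced = [sum(slices[peer][r] for peer in range(world_size))
               for r in range(world_size)]

    # Shot 2 (broadcast): every rank receives all reduced slices.
    # In the real kernel this uses the NVLink multicast mapping.
    full = torch.cat(reduced)
    return [full.clone() for _ in range(world_size)]

# Mirrors the PR's test strategy: compare against a plain torch.sum reference.
inputs = [torch.randn(8) for _ in range(4)]
outputs = two_shot_allreduce(inputs)
assert torch.allclose(outputs[0], torch.stack(inputs).sum(dim=0), atol=1e-5)
```

The two-shot split is what lets each rank reduce only its own slice rather than the whole tensor, which is typically why this variant is preferred over a one-shot all-reduce for larger messages and world sizes.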
Code Review
This pull request introduces a multi-node NVLink All-Reduce kernel from TensorRT-LLM, along with the necessary Python-side infrastructure for memory management and kernel invocation. The changes are extensive, touching upon low-level CUDA memory management, MPI communication, and PyTorch C++ extensions.
Overall, the implementation of the CUDA kernel and its Python bindings seems solid and follows established patterns for high-performance computing. However, I've identified a critical resource leak in the CUDA memory management within `flashinfer/comm/mnnvl.py`, which must be addressed. There are also a few medium-severity issues related to maintainability and error handling that should be considered.
Once the critical and high-severity issues are resolved, this will be a valuable addition for multi-node performance.
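As a general illustration of the kind of lifetime management that avoids such leaks (not the actual fix applied in this PR), driver-level handles can be tied to a `weakref.finalize` callback so they are released even when `__del__` is unreliable. All class and handle names below are hypothetical.

```python
import weakref

class McastBufferSketch:
    """Hypothetical sketch only: tie driver-handle cleanup to object lifetime."""

    def __init__(self):
        # Stand-ins for handles that would come from driver-level allocation calls.
        self._handles = ["unicast_handle", "multicast_handle"]
        # finalize() runs at garbage collection or interpreter shutdown,
        # even in cases where __del__ would be skipped (e.g. reference cycles).
        self._finalizer = weakref.finalize(self, self._release, list(self._handles))

    @staticmethod
    def _release(handles):
        # Real code would call the matching unmap/release driver APIs here.
        for handle in handles:
            print(f"released {handle}")

    def close(self):
        """Deterministic, explicit cleanup; safe to call more than once."""
        self._finalizer()

buf = McastBufferSketch()
buf.close()  # prints the two placeholder release messages exactly once
```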
LGTM, the docker image has been updated in https://github.com/flashinfer-ai/flashinfer/actions/runs/16137362896
and I have appended an empty commit to this PR to trigger CI with the latest docker image.
cc @yyihuang for another look at this PR before we merge it.
@@ -30,6 +30,12 @@
from .trtllm_ar import (
    trtllm_moe_finalize_allreduce_fusion as trtllm_moe_finalize_allreduce_fusion,
)
from .trtllm_mnnvl_ar import (
    gen_trtllm_mnnvl_comm_module,
Should we include mnnvl comm in aot build?
In the past @yzh119 suggested not to do this yet. Maybe we can do it as part of another PR?
Not necessary at this moment.
The segfault was coming from incorrect usage of the unicast buffers during initialization. Reviewers advised against setting environment variables and suggested another approach (monkeypatching), which is now implemented.
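For reference, this is roughly what the monkeypatch approach looks like with pytest's built-in `monkeypatch` fixture; the attribute name below is made up purely to illustrate the pattern and is not necessarily what the actual test patches.

```python
import flashinfer.comm.mnnvl as mnnvl  # module touched in this PR

def test_allreduce_with_patched_setting(monkeypatch):
    # Patch the value the code would otherwise read from an environment
    # variable; the patch is automatically undone when the test ends.
    monkeypatch.setattr(mnnvl, "HYPOTHETICAL_FLAG", True, raising=False)
    # ... run the all-reduce test body here ...
```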
`flashinfer/comm/trtllm_mnnvl_ar.py` (Outdated)
    wait_for_results: bool,
    launch_with_pdl: bool,
) -> None:
    """MNNVL all-reduce operation"""
I think this is the main public API, right?
It probably deserves a complete description, including the expected parameters and the invariants that hold before and after the call.
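To make the request concrete, here is a hypothetical sketch of what a fuller docstring could look like. Only `wait_for_results` and `launch_with_pdl` appear in the diff above; the other parameter names are placeholders, and the flag descriptions are my reading rather than the authoritative semantics.

```python
def trtllm_mnnvl_all_reduce(
    inp,                # placeholder: this rank's input tensor
    out,                # placeholder: tensor receiving the reduced result
    workspace,          # placeholder: buffers from the MNNVL workspace setup
    wait_for_results: bool,
    launch_with_pdl: bool,
) -> None:
    """MNNVL all-reduce operation.

    wait_for_results: when True, do not return until the reduced result has
        been written to the output and is visible to the caller.
    launch_with_pdl: launch using Programmatic Dependent Launch so the kernel
        can overlap with the previously launched kernel.

    Expected invariants: all ranks call this collectively with matching shape
    and dtype, and the shared multicast/unicast workspace and flag buffers
    have been initialized on every rank beforehand.
    """
    raise NotImplementedError  # illustrative sketch only
```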
Thanks for pointing this out! The documentation must have been removed during my cleanup; I've added an improved version now.
There is another main API, the workspace setup, which now has documentation as well.
* Added documentation to the main APIs
* Removed `__main__` from the test file
* Added a dtype check to `lamport_initialize`
📌 Description
This PR adds TRT-LLM's Multi-Node NVLink All-Reduce kernel, including the necessary code to set up the multicast and unicast buffers on the GPU.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

This test requires multiple nodes and can be run as follows:

srun -N2 --container-image=<your_aarch64_container> --mpi=pmix --container-name=mb_flashinfer --container-mounts=<parentdir_to_flash_infer>:/home -- bash -c 'cd /home/flashinfer && python -m tests.test_trtllm_mnnvl_allreduce'
Reviewer Notes