Port AllGather/ReduceScatter from TensorRT-LLM #1145
base: main
Conversation
Summary of Changes
Hello @wenscarl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request ports essential collective communication primitives, specifically AllGather and ReduceScatter, along with their underlying MPI and NCCL utility infrastructure, from TensorRT-LLM. This work lays the foundation for enabling various forms of distributed parallelism (such as Tensor Parallelism) within FlashInfer, allowing operations to scale across multiple GPUs and nodes. A new Python `Mapping` class is introduced to manage these parallelism configurations.
Highlights
- Ported Collective Operations: Introduced `AllGather` and `ReduceScatter` collective communication operations, ported from TensorRT-LLM, enabling multi-GPU/multi-node capabilities for relevant parts of the codebase.
- MPI/NCCL Utilities: Ported core MPI and NCCL utility functions and classes (such as `MpiComm`, `MpiRequest`, and NCCL communicator management) necessary for distributed operations, guarded by the `ENABLE_MULTI_DEVICE` flag.
- Parallelism Mapping: Added a new `Mapping` class in Python to represent and manage various parallelism configurations (Tensor Parallelism, Pipeline Parallelism, Context Parallelism, MoE Parallelism, etc.) and provide rank/group information.
- PyTorch Custom Operations: Implemented PyTorch custom CUDA operations (`trtllm_allgather`, `trtllm_reducescatter`) that wrap the underlying NCCL calls, making these collectives available in the Python frontend.
- Python API Wrappers: Added user-friendly Python functions (`all_gather`, `reduce_scatter`) in `flashinfer.comm` that use the new custom operations and the `Mapping` object to handle distributed tensors (a hedged usage sketch follows this list).
- Integration and Testing: Integrated the new C++/CUDA files into the JIT compilation process and added multi-GPU tests using `mpi4py` to verify the correctness of the ported collective operations.
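For orientation, here is a hypothetical end-to-end sketch. The import paths, argument names (`world_size`, `rank`, `tp_size`), and the exact signatures of `Mapping` and `all_gather` are assumptions for illustration only, not the confirmed API of this PR:

```python
# Hypothetical usage sketch -- run under mpirun with one process per GPU.
# Import locations and signatures below are assumed, not confirmed by this PR.
import torch
from mpi4py import MPI

from flashinfer.comm import all_gather       # assumed import location
from flashinfer.mapping import Mapping       # assumed import location

comm = MPI.COMM_WORLD
rank, world_size = comm.Get_rank(), comm.Get_size()
torch.cuda.set_device(rank)

# Pure tensor parallelism across all ranks (argument names are assumed).
mapping = Mapping(world_size=world_size, rank=rank, tp_size=world_size, pp_size=1)

x = torch.randn(16, 64, device="cuda", dtype=torch.float16)
# Collective across the TP group; the output layout depends on the real API.
gathered = all_gather(x, mapping)
```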
Code Review
This pull request ports AllGather and ReduceScatter functionality from TensorRT-LLM, introducing new C++/CUDA custom operations for PyTorch along with MPI/NCCL utilities and Python bindings. A new `Mapping` class is added to manage distributed configurations, and tests for these collective operations are included. Key areas for improvement include addressing file duplication, making the build configuration more portable (MPI/NCCL path discovery), fixing potential runtime errors, and ensuring correctness in the new `Mapping` class and C++ utilities.
```cpp
MPICHECK(MPI_Type_size(dtype, &elementSize));
elementSize = std::min<int>(elementSize, alignment);

// We cap at max alignment-bytes chunks that can be sent at once.
auto const step = maxP1 - (alignment / elementSize);
```
There's a potential division by zero if `elementSize` is zero. `MPI_Type_size` could return 0 for certain types (e.g., `MPI_DATATYPE_NULL`, or if an error occurs, though `MPICHECK` should catch MPI errors). If `elementSize` becomes 0, the expression `alignment / elementSize` in `step = maxP1 - (alignment / elementSize)` would lead to a runtime error. Consider adding a check for `elementSize <= 0` before this line.
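A minimal sketch of such a guard, reusing the `TLLM_CHECK` macro that already appears elsewhere in this PR (the exact error-handling style is up to the authors):

```cpp
MPICHECK(MPI_Type_size(dtype, &elementSize));
// Guard against a zero (or negative) element size before it is used as a divisor.
TLLM_CHECK(elementSize > 0);
elementSize = std::min<int>(elementSize, alignment);

// We cap at max alignment-bytes chunks that can be sent at once.
auto const step = maxP1 - (alignment / elementSize);
```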
flashinfer/comm.py
Outdated
```python
            jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
            mpi_include_path,
        ],
        extra_ldflags=[f"-L{mpi_lib_path}", "-lmpi", "-L/usr/lib/aarch64-linux-gnu/ -lnccl"],
```
The NCCL library path `-L/usr/lib/aarch64-linux-gnu/` is hardcoded. This will cause build failures on systems with a different architecture (e.g., x86_64) or where NCCL is installed in a non-standard location. The NCCL path should be discovered dynamically (e.g., from an `NCCL_HOME` environment variable or by searching standard library paths) or be configurable by the user:

```python
extra_ldflags=[f"-L{mpi_lib_path}", "-lmpi"],  # NCCL path needs to be discovered or made configurable
# Example: add the discovered NCCL lib path and -lnccl here
```
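A hedged sketch of how such discovery could look; `find_nccl_lib_dir` is a hypothetical helper (not part of this PR), and the candidate paths are only common defaults:

```python
import os
from pathlib import Path


def find_nccl_lib_dir() -> str:
    """Best-effort NCCL library directory lookup (illustrative sketch only)."""
    nccl_home = os.environ.get("NCCL_HOME")
    if nccl_home:
        return str(Path(nccl_home) / "lib")
    for candidate in (
        "/usr/lib/x86_64-linux-gnu",
        "/usr/lib/aarch64-linux-gnu",
        "/usr/local/lib",
    ):
        if Path(candidate).is_dir() and any(Path(candidate).glob("libnccl.so*")):
            return candidate
    raise RuntimeError("Could not locate libnccl; set NCCL_HOME or install NCCL in a standard path.")


# Hypothetical use inside gen_comm_module():
# extra_ldflags=[f"-L{mpi_lib_path}", "-lmpi", f"-L{find_nccl_lib_dir()}", "-lnccl"],
```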
```cpp
#if ENABLE_MULTI_DEVICE

std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap()
{
    static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap = {
        {nvinfer1::DataType::kFLOAT, ncclFloat32},
        {nvinfer1::DataType::kHALF, ncclFloat16},
        {nvinfer1::DataType::kBF16, ncclBfloat16},
        {nvinfer1::DataType::kFP8, ncclInt8},
        {nvinfer1::DataType::kBOOL, ncclInt8},
        {nvinfer1::DataType::kINT32, ncclInt32},
        {nvinfer1::DataType::kINT64, ncclInt64},
        {nvinfer1::DataType::kUINT8, ncclUint8},
        {nvinfer1::DataType::kINT8, ncclInt8},
    };
    return &dtypeMap;
}

namespace
{

// Get NCCL unique ID for a group of ranks.
ncclUniqueId getUniqueId(std::set<int> const& group)
{
    auto const rank = COMM_SESSION.getRank();
    TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
    ncclUniqueId id;
    if (rank == *group.begin())
    {
        NCCLCHECK_THROW(ncclGetUniqueId(&id));
        for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
        {
            COMM_SESSION.sendValue(id, *it, tensorrt_llm::mpi::MpiTag::kDefault);
        }
    }
    else
    {
        COMM_SESSION.recvValue(id, *group.begin(), tensorrt_llm::mpi::MpiTag::kDefault);
    }
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return id;
}
} // namespace

std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
{
    auto const rank = COMM_SESSION.getRank();
    TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
    static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    std::ostringstream oss;
    int index = 0;
    for (auto const& rank : group)
    {
        if (index != 0)
        {
            oss << ",";
        }
        oss << rank;
        index++;
    }
    auto groupStr = oss.str();
    auto it = commMap.find(group);
    if (it != commMap.end())
    {
        auto ncclComm = it->second;
        TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
        return ncclComm;
    }

    TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
    ncclUniqueId id = getUniqueId(group);
    int groupRank = 0;
    for (auto const& currentRank : group)
    {
        if (rank == currentRank)
            break;
        ++groupRank;
    }
    TLLM_CHECK(static_cast<size_t>(groupRank) < group.size());
    std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t,
        [](ncclComm_t* comm)
        {
            ncclCommDestroy(*comm);
            delete comm;
        });
    // Need static connection initialization for accurate KV cache size estimation
#if defined(_WIN32)
    if (getenv("NCCL_RUNTIME_CONNECT") == nullptr)
        _putenv_s("NCCL_RUNTIME_CONNECT", "0");
#else
    setenv("NCCL_RUNTIME_CONNECT", "0", 0);
#endif // _WIN32
    NCCLCHECK_THROW(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
    commMap[group] = ncclComm;
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE
```
This file appears to be identical to `csrc/nv_internal/tensorrt_llm/runtime/opUtils.cpp`. Code duplication can lead to maintenance issues; please consolidate them into a single file. Based on the include paths used in `flashinfer/comm.py` (`nv_internal/cpp/common/opUtils.cpp`), this location seems more appropriate if `tensorrt_llm/common/opUtils.h` is the intended header.
@@ -16,17 +16,19 @@

```python
import ctypes
import functools
import math
from dataclasses import dataclass
```
```python
def gen_comm_module() -> JitSpec:
    mpi_include_path, mpi_lib_path = get_mpi_include_lib_path()
    mpi_lib_path = str(mpi_lib_path[0])
    print(mpi_include_path, mpi_lib_path)
```
```cpp
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

} // namespace tensorrt_llm::mpi
```
flashinfer/mapping.py
Outdated
```python
        if tp_size != 1 or pp_size != 1 or tp_size != 1:
            raise ValueError(
                f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}."
            )
```
The condition `if tp_size != 1 or pp_size != 1 or tp_size != 1:` has a duplicated check for `tp_size != 1`. It seems the intention was to check `cp_size != 1` as well.
Suggested change:

```python
        if tp_size != 1 or pp_size != 1 or cp_size != 1:  # assuming cp_size was intended for the last check
            raise ValueError(
                f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}."
            )
```
```cpp
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE
```
tests/test_trtllm_collectives.py
Outdated
print(f"Running test for world_size={world_size}") | ||
_run_reduce_scatter_worker(rank, world_size, dtype) | ||
|
||
print(f"reduce_scatter with tp = {world_size}: OK") |
The CI error is because MPI is missing from our Dockerfile. We can add an MPI installation to https://github.com/flashinfer-ai/flashinfer/blob/main/docker/Dockerfile.ci_gpu via `conda install mpi4py`. After the updated Dockerfile is merged, we can manually trigger an update of the Docker image on Docker Hub.
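A hedged sketch of what that Dockerfile addition could look like (the exact base image and channel options in `docker/Dockerfile.ci_gpu` may differ; the follow-up PR below installs these packages):

```dockerfile
# Install the Python packages needed by the new comm ops and their tests.
# Channel options (e.g. -c conda-forge) may be required depending on the base image.
RUN conda install -y mpi4py pynvml
```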
<!-- .github/pull_request_template.md --> ## 📌 Description Install the python packages for CI docker: mpi4py, pynvml. They will be used for the comm ops. ## 🔍 Related Issues #1145, #1134 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
<!-- .github/pull_request_template.md --> ## 📌 Description Install the python packages for CI docker: mpi4py, pynvml. They will be used for the comm ops. ## 🔍 Related Issues flashinfer-ai#1145, flashinfer-ai#1134 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
📌 Port AllGather/ReduceScatter from TensorRT-LLM
🔍 Related Issues
This PR introduces a dependency on mpi4py.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes