
Port AllGather/ReduceScatter from TensorRT-LLM #1145

Open

wants to merge 5 commits into main

Conversation

wenscarl
Collaborator

@wenscarl wenscarl commented Jun 15, 2025

📌 Port AllGather/ReduceScatter from TensorRT-LLM

🔍 Related Issues

This PR introduces a dependency on mpi4py.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @wenscarl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request ports essential collective communication primitives, specifically AllGather and ReduceScatter, along with their underlying MPI and NCCL utility infrastructure from TensorRT-LLM. This work lays the foundation for enabling various forms of distributed parallelism (like Tensor Parallelism) within FlashInfer, allowing operations to scale across multiple GPUs and nodes. A new Python Mapping class is introduced to manage these complex parallelism configurations.

Highlights

  • Ported Collective Operations: Introduced AllGather and ReduceScatter collective communication operations, ported from TensorRT-LLM, enabling multi-GPU/multi-node capabilities for relevant parts of the codebase.
  • MPI/NCCL Utilities: Ported core MPI and NCCL utility functions and classes (like MpiComm, MpiRequest, and NCCL communicator management) necessary for distributed operations, guarded by the ENABLE_MULTI_DEVICE flag.
  • Parallelism Mapping: Added a new Mapping class in Python to represent and manage various parallelism configurations (Tensor Parallelism, Pipeline Parallelism, Context Parallelism, MoE Parallelism, etc.) and provide rank/group information.
  • PyTorch Custom Operations: Implemented PyTorch custom CUDA operations (trtllm_allgather, trtllm_reducescatter) that wrap the underlying NCCL calls, making these collectives available in the Python frontend.
  • Python API Wrappers: Added user-friendly Python functions (all_gather, reduce_scatter) in flashinfer.comm that utilize the new custom operations and the Mapping object to handle distributed tensors (see the usage sketch after this list).
  • Integration and Testing: Integrated the new C++/CUDA files into the JIT compilation process and added multi-GPU tests using mpi4py to verify the correctness of the ported collective operations.
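For orientation, here is a minimal, hedged usage sketch of the Python surface described above. The exact signatures of Mapping, all_gather, and reduce_scatter in this PR may differ; the constructor arguments below are assumptions modeled on TensorRT-LLM's Mapping, and the script is meant to be launched with mpirun, one rank per GPU.

```python
# Hedged sketch only: names follow the highlights above, but argument names
# and return shapes are assumptions, not the exact API introduced by this PR.
import torch
from mpi4py import MPI

from flashinfer.comm import Mapping, all_gather, reduce_scatter

comm = MPI.COMM_WORLD
rank, world_size = comm.Get_rank(), comm.Get_size()
torch.cuda.set_device(rank)

# Assumed constructor shape, mirroring TensorRT-LLM's Mapping; here the whole
# world forms a single tensor-parallel group.
mapping = Mapping(world_size=world_size, rank=rank,
                  gpus_per_node=world_size, tp_size=world_size)

x = torch.randn(128, 64, device="cuda", dtype=torch.float16)

gathered = all_gather(x, mapping)          # expected: shards from all tp ranks concatenated
shard = reduce_scatter(gathered, mapping)  # expected: elementwise sum, then this rank's slice
```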

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request ports AllGather and ReduceScatter functionalities from TensorRT-LLM, introducing new C++/CUDA custom operations for PyTorch, along with MPI/NCCL utilities and Python bindings. A new Mapping class is added to manage distributed configurations, and tests for these collective operations are included. Key areas for improvement include addressing file duplication, enhancing portability of build configurations (MPI/NCCL path discovery), fixing potential runtime errors, and ensuring code correctness in the new Mapping class and C++ utilities.

Comment on lines 246 to 250
MPICHECK(MPI_Type_size(dtype, &elementSize));
elementSize = std::min<int>(elementSize, alignment);

// We cap at max alignment-bytes chunks that can be sent at once.
auto const step = maxP1 - (alignment / elementSize);

critical

There's a potential division by zero if elementSize is zero. MPI_Type_size could return 0 for certain types (e.g., MPI_DATATYPE_NULL or if an error occurs, though MPICHECK should catch MPI errors). If elementSize becomes 0, the calculation alignment / elementSize in step = maxP1 - (alignment / elementSize) would lead to a runtime error. Consider adding a check for elementSize <= 0 before this line.

jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
mpi_include_path,
],
extra_ldflags=[f"-L{mpi_lib_path}", "-lmpi", "-L/usr/lib/aarch64-linux-gnu/ -lnccl"],

critical

The NCCL library path -L/usr/lib/aarch64-linux-gnu/ is hardcoded. This will cause build failures on systems with different architectures (e.g., x86_64) or where NCCL is installed in a non-standard location. NCCL paths should be discovered dynamically (e.g., from NCCL_HOME environment variable or by searching standard library paths) or configurable by the user.

        extra_ldflags=[f"-L{mpi_lib_path}", "-lmpi"], # NCCL path needs to be discovered or configurable
        # Example: Add discovered NCCL lib path and -lnccl here
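One hedged way to avoid the hardcoded path is to probe a small set of candidate directories. The helper below is a sketch (find_nccl_lib_dirs is a hypothetical name, not code from this PR); it checks NCCL_HOME first, then common multi-arch locations.

```python
import glob
import os


def find_nccl_lib_dirs():
    # Hypothetical helper: prefer an explicit NCCL_HOME, then fall back to
    # common system library directories instead of a single hardcoded path.
    candidates = []
    nccl_home = os.environ.get("NCCL_HOME")
    if nccl_home:
        candidates.append(os.path.join(nccl_home, "lib"))
    candidates += [
        "/usr/lib/x86_64-linux-gnu",
        "/usr/lib/aarch64-linux-gnu",
        "/usr/local/lib",
    ]
    # Keep only directories that actually contain a NCCL shared library.
    return [d for d in candidates if glob.glob(os.path.join(d, "libnccl.so*"))]


# Example wiring into the JIT flags (mpi_lib_path comes from the existing code):
# extra_ldflags = [f"-L{mpi_lib_path}", "-lmpi"] \
#     + [f"-L{d}" for d in find_nccl_lib_dirs()] + ["-lnccl"]
```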

Comment on lines 1 to 130
#if ENABLE_MULTI_DEVICE

std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap()
{
static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap = {
{nvinfer1::DataType::kFLOAT, ncclFloat32},
{nvinfer1::DataType::kHALF, ncclFloat16},
{nvinfer1::DataType::kBF16, ncclBfloat16},
{nvinfer1::DataType::kFP8, ncclInt8},
{nvinfer1::DataType::kBOOL, ncclInt8},
{nvinfer1::DataType::kINT32, ncclInt32},
{nvinfer1::DataType::kINT64, ncclInt64},
{nvinfer1::DataType::kUINT8, ncclUint8},
{nvinfer1::DataType::kINT8, ncclInt8},
};
return &dtypeMap;
}

namespace
{

// Get NCCL unique ID for a group of ranks.
ncclUniqueId getUniqueId(std::set<int> const& group)
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
ncclUniqueId id;
if (rank == *group.begin())
{
NCCLCHECK_THROW(ncclGetUniqueId(&id));
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
{
COMM_SESSION.sendValue(id, *it, tensorrt_llm::mpi::MpiTag::kDefault);
}
}
else
{
COMM_SESSION.recvValue(id, *group.begin(), tensorrt_llm::mpi::MpiTag::kDefault);
}
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return id;
}
} // namespace

std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
std::ostringstream oss;
int index = 0;
for (auto const& rank : group)
{
if (index != 0)
{
oss << ",";
}
oss << rank;
index++;
}
auto groupStr = oss.str();
auto it = commMap.find(group);
if (it != commMap.end())
{
auto ncclComm = it->second;
TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
return ncclComm;
}

TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
ncclUniqueId id = getUniqueId(group);
int groupRank = 0;
for (auto const& currentRank : group)
{
if (rank == currentRank)
break;
++groupRank;
}
TLLM_CHECK(static_cast<size_t>(groupRank) < group.size());
std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t,
[](ncclComm_t* comm)
{
ncclCommDestroy(*comm);
delete comm;
});
// Need static connection initialization for accurate KV cache size estimation
#if defined(_WIN32)
if (getenv("NCCL_RUNTIME_CONNECT") == nullptr)
_putenv_s("NCCL_RUNTIME_CONNECT", "0");
#else
setenv("NCCL_RUNTIME_CONNECT", "0", 0);
#endif // _WIN32
NCCLCHECK_THROW(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
commMap[group] = ncclComm;
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE

critical

This file appears to be identical to csrc/nv_internal/tensorrt_llm/runtime/opUtils.cpp. Code duplication can lead to maintenance issues. Please consolidate them into a single file. Based on the include paths used in flashinfer/comm.py (nv_internal/cpp/common/opUtils.cpp), this location seems more appropriate if tensorrt_llm/common/opUtils.h is the intended header.

Comment on lines 1 to 130
(quoted file content omitted here; it is identical to the opUtils.cpp snippet shown above)

critical

This file appears to be identical to csrc/nv_internal/cpp/common/opUtils.cpp. Please remove this duplicated file to avoid maintenance issues. The JIT compilation in flashinfer/comm.py references the other path.

@@ -16,17 +16,19 @@

import ctypes
import functools
import math
from dataclasses import dataclass

high

The accumulate function from itertools is used in restore_full_output (line 288) but not imported. Please add from itertools import accumulate.
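For context, a small illustration of what accumulate provides here. This is not the PR's restore_full_output, just a sketch of turning per-rank chunk sizes into slice offsets.

```python
from itertools import accumulate

import torch

# Illustration only: running end offsets from hypothetical per-rank lengths,
# then slicing a flattened, gathered tensor back into per-rank pieces.
sizes = [4, 4, 3, 3]
ends = list(accumulate(sizes))   # [4, 8, 11, 14]
starts = [0] + ends[:-1]
flat = torch.arange(sum(sizes))
parts = [flat[s:e] for s, e in zip(starts, ends)]
```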

def gen_comm_module() -> JitSpec:
mpi_include_path, mpi_lib_path = get_mpi_include_lib_path()
mpi_lib_path = str(mpi_lib_path[0])
print(mpi_include_path, mpi_lib_path)

medium

This print statement seems to be for debugging. It should be removed or guarded by a debug flag in production code.

Suggested change
print(mpi_include_path, mpi_lib_path)
# print(mpi_include_path, mpi_lib_path) # Consider removing or using a logger for debugging
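If the path information is still useful for debugging, a hedged alternative inside gen_comm_module is to route it through the standard logging module rather than dropping it entirely (mpi_include_path and mpi_lib_path are the variables from the snippet above):

```python
import logging

logger = logging.getLogger(__name__)

# Same information as the print call, but only emitted when debug logging is enabled.
logger.debug("MPI include path: %s, MPI lib path: %s", mpi_include_path, mpi_lib_path)
```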

TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

} // namespace tensorrt_llm::mpi

medium

The file is missing a newline character at the end.

Suggested change
} // namespace tensorrt_llm::mpi

Comment on lines 163 to 166
if tp_size != 1 or pp_size != 1 or tp_size != 1:
raise ValueError(
f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}."
)

medium

The condition if tp_size != 1 or pp_size != 1 or tp_size != 1: has a duplicated check for tp_size != 1. It seems the intention might have been to check cp_size != 1 as well.

Suggested change
if tp_size != 1 or pp_size != 1 or tp_size != 1:
raise ValueError(
f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}."
)
if tp_size != 1 or pp_size != 1 or cp_size != 1: # Assuming cp_size was intended for the last check
raise ValueError(
f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}."
)

TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE

medium

The file is missing a newline character at the end. It's a common convention to end files with a newline.

Suggested change
#endif // ENABLE_MULTI_DEVICE

print(f"Running test for world_size={world_size}")
_run_reduce_scatter_worker(rank, world_size, dtype)

print(f"reduce_scatter with tp = {world_size}: OK")

medium

The file is missing a newline character at the end.

Suggested change
print(f"reduce_scatter with tp = {world_size}: OK")

@wenscarl wenscarl marked this pull request as ready for review June 15, 2025 20:49
@wenscarl wenscarl requested a review from yzh119 June 15, 2025 20:49
@wenscarl wenscarl requested a review from yzh119 June 16, 2025 04:12
@yzh119
Collaborator

yzh119 commented Jun 16, 2025

The CI error is because mpi is missing from our dockerfile.
cc @yongwww would you mind taking a look and installing mpi in our docker image?

We can add mpi installation to https://github.com/flashinfer-ai/flashinfer/blob/main/docker/Dockerfile.ci_gpu by:

conda install mpi4py

After the updated dockerfile is merged, we can manually trigger
https://github.com/flashinfer-ai/flashinfer/actions/workflows/release-ci-docker.yml

to update the docker image on dockerhub.

@yongwww yongwww mentioned this pull request Jun 16, 2025
yzh119 pushed a commit that referenced this pull request Jun 17, 2025
<!-- .github/pull_request_template.md -->

## 📌 Description

Install the python packages for CI docker: mpi4py, pynvml. They will be
used for the comm ops.

## 🔍 Related Issues

#1145,
#1134

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->
Anerudhan pushed a commit to Anerudhan/flashinfer that referenced this pull request Jun 28, 2025
(commit description identical to the one quoted above)