[comm] TRT-LLM's Multi-Node NVLink All-Reduce Kernel #1213

Merged (6 commits, Jul 11, 2025)
csrc/trtllm_mnnvl_allreduce.cu (new file: 68 additions, 0 deletions)
@@ -0,0 +1,68 @@
#include "flashinfer/comm/trtllm_mnnvl_allreduce.cuh"
#include "pytorch_extension_utils.h"

using namespace flashinfer::trtllm_mnnvl_allreduce;

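// Type-dispatch helper: maps the runtime at::ScalarType to a concrete device
// type (float, half, or __nv_bfloat16), binds it to `c_type`, and invokes the
// trailing lambda once with that binding.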
#define DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(scalar_type, c_type, ...) \
[&] { \
switch (scalar_type) { \
case at::ScalarType::Float: { \
using c_type = float; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Half: { \
using c_type = half; \
return __VA_ARGS__(); \
} \
case at::ScalarType::BFloat16: { \
using c_type = __nv_bfloat16; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported dtype in DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE: ", \
scalar_type); \
} \
}()

void trtllm_mnnvl_all_reduce(at::Tensor& in, at::Tensor& out, int64_t multicast_buffer_ptr,
int64_t buffer_ptrs_dev, int64_t buffer_M,
at::Tensor& buffer_flags_mnnvl, int64_t nranks, int64_t rank,
bool wait_for_results, bool launch_with_pdl) {
const c10::cuda::OptionalCUDAGuard device_guard(in.device());
auto stream = at::cuda::getCurrentCUDAStream();

DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.scalar_type(), c_type, [&] {
// Extract parameters from tensors
int64_t num_tokens = in.size(0);
int64_t token_dim = in.size(1);

// Validate input parameters
TORCH_CHECK(nranks >= 2 && nranks <= 64, "nranks must be between 2 and 64, got ", nranks);
TORCH_CHECK(rank >= 0 && rank < nranks, "rank must be between 0 and nranks-1, got ", rank);

// Create the parameters struct
AllReduceParams<c_type> params;
params.nranks = nranks;
params.rank = rank;
params.buffer_M = buffer_M;
params.num_tokens = num_tokens;
params.token_dim = token_dim;
params.buffer_ptrs_dev = reinterpret_cast<void**>(buffer_ptrs_dev);
params.multicast_ptr = reinterpret_cast<void*>(multicast_buffer_ptr);
params.buffer_flags = buffer_flags_mnnvl.data_ptr();
params.wait_for_results = wait_for_results;
params.launch_with_pdl = launch_with_pdl;
params.input = in.data_ptr();
params.output = out.data_ptr();
params.stream = stream.stream();

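    // Two-shot all-reduce: ranks first accumulate partial sums
    // (reduce-scatter phase), then propagate the reduced rows back to every
    // rank (all-gather phase) through the multi-node NVLink multicast buffer.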
auto status = twoshot_allreduce_dispatch_world_size<c_type>(params);
TORCH_CHECK(status == cudaSuccess,
            "twoshot_allreduce_dispatch_world_size failed with error: ",
            cudaGetErrorString(status));
});
}

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
m.def("trtllm_mnnvl_all_reduce", &trtllm_mnnvl_all_reduce);
}
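For illustration, a minimal Python sketch of a call into the op registered above, assuming the extension has been compiled and the bound op is reachable as `all_reduce_op`; every workspace value below (`multicast_buffer_ptr`, `buffer_ptrs_dev`, `buffer_M`, `buffer_flags`) is a hypothetical placeholder that would in practice come from the MNNVL workspace setup:

import torch

# Hypothetical placeholders -- real values come from the MNNVL workspace
# allocation, not from literals like these.
multicast_buffer_ptr = 0  # int64 device pointer to the multicast buffer
buffer_ptrs_dev = 0       # int64 pointer to the per-rank buffer pointer array
buffer_M = 0              # int64 row capacity of the workspace buffer (assumed)
buffer_flags = torch.zeros(4, dtype=torch.int32, device="cuda")  # shape assumed

nranks, rank = 2, 0  # world size must be in [2, 64], rank in [0, nranks)
inp = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")  # [num_tokens, token_dim]
out = torch.empty_like(inp)

all_reduce_op(
    inp, out,
    multicast_buffer_ptr, buffer_ptrs_dev, buffer_M, buffer_flags,
    nranks, rank,
    True,   # wait_for_results: block until `out` is populated
    False,  # launch_with_pdl: leave programmatic dependent launch off
)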
flashinfer/comm/__init__.py (6 additions, 0 deletions)
@@ -30,6 +30,12 @@
from .trtllm_ar import (
    trtllm_moe_finalize_allreduce_fusion as trtllm_moe_finalize_allreduce_fusion,
)
from .trtllm_mnnvl_ar import (
    gen_trtllm_mnnvl_comm_module,
    get_allreduce_mnnvl_workspace,
    mpi_barrier,
    trtllm_mnnvl_all_reduce,
)
from .vllm_ar import all_reduce as vllm_all_reduce
from .vllm_ar import dispose as vllm_dispose
from .vllm_ar import gen_vllm_comm_module as gen_vllm_comm_module

Review thread on the new import block:

Collaborator: Should we include mnnvl comm in aot build?
Contributor (Author): In the past @yzh119 suggested not to do this yet. Maybe we can do it as part of another PR?
Collaborator: Not necessary at this moment.
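For context, a hedged sketch of how these new exports presumably compose; this diff does not show their Python signatures, so the outline below is an assumption rather than the real API:

from flashinfer.comm import (
    get_allreduce_mnnvl_workspace,
    mpi_barrier,
    trtllm_mnnvl_all_reduce,
)

# Assumed flow (call signatures hypothetical):
#   1. get_allreduce_mnnvl_workspace(...) allocates the multicast buffer, the
#      per-rank buffer pointer array, and the buffer_flags tensor consumed by
#      the kernel.
#   2. mpi_barrier() keeps any rank from reducing before all workspaces exist.
#   3. trtllm_mnnvl_all_reduce(...) forwards to the C++ op above for a 2-D
#      fp32/fp16/bf16 tensor of shape [num_tokens, token_dim].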