diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu
new file mode 100644
index 000000000..9dd55470c
--- /dev/null
+++ b/csrc/trtllm_mnnvl_allreduce.cu
@@ -0,0 +1,68 @@
+#include "flashinfer/comm/trtllm_mnnvl_allreduce.cuh"
+#include "pytorch_extension_utils.h"
+
+using namespace flashinfer::trtllm_mnnvl_allreduce;
+
+#define DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(scalar_type, c_type, ...)                    \
+  [&] {                                                                                          \
+    switch (scalar_type) {                                                                       \
+      case at::ScalarType::Float: {                                                              \
+        using c_type = float;                                                                    \
+        return __VA_ARGS__();                                                                    \
+      }                                                                                          \
+      case at::ScalarType::Half: {                                                               \
+        using c_type = half;                                                                     \
+        return __VA_ARGS__();                                                                    \
+      }                                                                                          \
+      case at::ScalarType::BFloat16: {                                                           \
+        using c_type = __nv_bfloat16;                                                            \
+        return __VA_ARGS__();                                                                    \
+      }                                                                                          \
+      default:                                                                                   \
+        TORCH_CHECK(false, "Unsupported dtype in DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE: ", \
+                    scalar_type);                                                                \
+    }                                                                                            \
+  }()
+
+void trtllm_mnnvl_all_reduce(at::Tensor& in, at::Tensor& out, int64_t multicast_buffer_ptr,
+                             int64_t buffer_ptrs_dev, int64_t buffer_M,
+                             at::Tensor& buffer_flags_mnnvl, int64_t nranks, int64_t rank,
+                             bool wait_for_results, bool launch_with_pdl) {
+  const c10::cuda::OptionalCUDAGuard device_guard(in.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.scalar_type(), c_type, [&] {
+    // Extract parameters from tensors
+    int64_t num_tokens = in.size(0);
+    int64_t token_dim = in.size(1);
+
+    // Validate input parameters
+    TORCH_CHECK(nranks >= 2 && nranks <= 64, "nranks must be between 2 and 64, got ", nranks);
+    TORCH_CHECK(rank >= 0 && rank < nranks, "rank must be between 0 and nranks-1, got ", rank);
+
+    // Create the parameters struct
+    AllReduceParams<c_type> params;
+    params.nranks = nranks;
+    params.rank = rank;
+    params.buffer_M = buffer_M;
+    params.num_tokens = num_tokens;
+    params.token_dim = token_dim;
+    params.buffer_ptrs_dev = reinterpret_cast<void**>(buffer_ptrs_dev);
+    params.multicast_ptr = reinterpret_cast<void*>(multicast_buffer_ptr);
+    params.buffer_flags = buffer_flags_mnnvl.data_ptr();
+    params.wait_for_results = wait_for_results;
+    params.launch_with_pdl = launch_with_pdl;
+    params.input = in.data_ptr();
+    params.output = out.data_ptr();
+    params.stream = stream.stream();
+
+    auto status = twoshot_allreduce_dispatch_world_size<c_type>(params);
+    TORCH_CHECK(status == cudaSuccess,
+                "twoshot_allreduce_dispatch_world_size failed with error code ",
+                cudaGetErrorString(status));
+  });
+}
+
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
+  m.def("trtllm_mnnvl_all_reduce", &trtllm_mnnvl_all_reduce);
+}
diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py
index 1fffcb658..378661e9a 100644
--- a/flashinfer/comm/__init__.py
+++ b/flashinfer/comm/__init__.py
@@ -30,6 +30,12 @@ from .trtllm_ar import (
     trtllm_moe_finalize_allreduce_fusion as trtllm_moe_finalize_allreduce_fusion,
 )
+from .trtllm_mnnvl_ar import (
+    gen_trtllm_mnnvl_comm_module,
+    get_allreduce_mnnvl_workspace,
+    mpi_barrier,
+    trtllm_mnnvl_all_reduce,
+)
 from .vllm_ar import all_reduce as vllm_all_reduce
 from .vllm_ar import dispose as vllm_dispose
 from .vllm_ar import gen_vllm_comm_module as gen_vllm_comm_module
diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py
index b06861ea9..33ca46cdb 100644
--- a/flashinfer/comm/mnnvl.py
+++ b/flashinfer/comm/mnnvl.py
@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
# Code imported from TensorRT-LLM/tensorrt_llm/_mnnvl_utils.py +import ctypes import logging import platform import sys from dataclasses import dataclass +from typing import List import pynvml import torch @@ -24,12 +26,89 @@ from mpi4py import MPI from ..cuda_utils import checkCudaErrors -from .dlpack_utils import pack_strided_memory +from .dlpack_utils import create_dlpack_capsule, pack_strided_memory from .mapping import Mapping # mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here OMPI_COMM_TYPE_HOST = 9 +# Constants from C++ header +SIGNAL_PAD_SIZE = 2048 # kSIGNAL_PAD_SIZE from header + +MNNVL_DEBUG = False + + +def round_up(val: int, gran: int) -> int: + """Efficient implementation assuming gran is a power of 2""" + return (val + gran - 1) & ~(gran - 1) + + +def create_tensor_from_cuda_memory( + ptr: int, shape: tuple, dtype: torch.dtype, device_id: int +) -> torch.Tensor: + """ + Create a PyTorch tensor from a CUDA memory pointer using DLPack. + + Args: + ptr: CUDA memory pointer address as integer + shape: Desired tensor shape + dtype: PyTorch data type + device_id: CUDA device ID + + Returns: + PyTorch tensor that wraps the CUDA memory + """ + # Calculate total size in elements + numel = 1 + for dim in shape: + numel *= dim + + # Get element size in bytes + element_size = torch.tensor([], dtype=dtype).element_size() + total_size_bytes = numel * element_size + + # Create DLPack capsule for contiguous memory (stride = element_size, num_segments = numel) + capsule_wrapper = create_dlpack_capsule( + ptr, element_size, element_size, numel, dtype, device_id + ) + + # Convert to tensor and reshape + tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule) + tensor._capsule_wrapper = capsule_wrapper # Keep reference to prevent GC + + # Reshape to desired shape + return tensor.view(shape) + + +def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool: + """ + Test if CUDA memory at ptr is accessible by trying to read/write a small amount. 
+ + Args: + ptr: CUDA memory pointer + size: Size of memory region + device_id: CUDA device ID + + Returns: + True if memory is accessible, False otherwise + """ + try: + # Test with a small 4-byte read/write + test_size = min(4, size) + host_data = bytearray(test_size) + + # Try to copy from device to host + checkCudaErrors(cuda.cuMemcpyDtoH(host_data, ptr, test_size)) + + # Try to copy back from host to device + checkCudaErrors(cuda.cuMemcpyHtoD(ptr, host_data, test_size)) + + print(f"DEBUG: Memory access test PASSED for ptr=0x{ptr:x}") + return True + except Exception as e: + print(f"DEBUG: Memory access test FAILED for ptr=0x{ptr:x}: {e}") + return False + class MpiComm: _comm: MPI.Intracomm = MPI.COMM_WORLD @@ -308,3 +387,484 @@ def supports_mnnvl() -> bool: if not "aarch64" in arch: return False return MnnvlMemory.support_nvlink(True) + + +class McastDeviceMemory: + """Python port of McastDeviceMemory from TensorRT-LLM""" + + def __init__( + self, + buf_size: int, + group_size: int, + group_rank: int, + device_idx: int, + is_multi_node: bool = True, + ): + cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) + + primary_ctx = checkCudaErrors(cuda.cuDevicePrimaryCtxRetain(cu_device)) + checkCudaErrors(cuda.cuCtxSetCurrent(primary_ctx)) + + current_context = checkCudaErrors(cuda.cuCtxGetCurrent()) + + # Set CUDA device + import cuda.cudart as cudart + + checkCudaErrors(cudart.cudaSetDevice(device_idx)) + + self.is_multi_node = is_multi_node + self.device_idx = device_idx + self.group_size = group_size + self.group_rank = group_rank + self.buf_size = buf_size + self.signal_pad_offset = 0 + self.allocation_size = 0 + + # CUDA memory handles and pointers + self.mc_ptr = 0 # CUdeviceptr mMcPtr + self.uc_ptrs: List[int] = [] # std::vector mUcPtrs + self.signal_pads_dev: List[int] = [] # std::vector mSignalPadsDev + self.mc_handle = 0 # CUmemGenericAllocationHandle mMcHandle + self.uc_handles: List[int] = ( + [] + ) # std::vector mUcHandles + + # Signal pad constants + self.SIGNAL_PAD_ALIGNMENT = 16 + self.SIGNAL_PAD_SIZE = SIGNAL_PAD_SIZE + + # Check if device supports multicasting + multicast_supported = checkCudaErrors( + cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, + device_idx, + ) + ) + if multicast_supported == 0: + raise RuntimeError( + "[McastDeviceMemory] Device does not support multicasting." + ) + + # Calculate signal pad offset with alignment (matching C++ exactly) + self.signal_pad_offset = round_up(buf_size, self.SIGNAL_PAD_ALIGNMENT) + + logging.info( + f"[McastDeviceMemory] Rank: {group_rank}, Group size: {group_size}, " + f"mnNvlink: {is_multi_node}, device_idx: {device_idx}, " + f"Signal pad offset: {self.signal_pad_offset}" + ) + + if self.is_multi_node: + # Check if fabric handle is supported + fabric_handle_supported = checkCudaErrors( + cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, + device_idx, + ) + ) + if fabric_handle_supported == 0: + raise RuntimeError( + "[McastDeviceMemory] Device does not support fabric handle." 
+ ) + + current_context = checkCudaErrors(cuda.cuCtxGetCurrent()) + + self._alloc_mn_mcast_mem(buf_size) + else: + # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem + raise NotImplementedError("Single-node NVLS allocation not implemented yet") + + # Initialize signal pads + self.signal_pads_dev = [0] * self.group_size + for i in range(self.group_size): + self.signal_pads_dev[i] = self.uc_ptrs[i] + self.signal_pad_offset + if i == self.group_rank: + checkCudaErrors( + cuda.cuMemsetD8(self.signal_pads_dev[i], 0, self.SIGNAL_PAD_SIZE) + ) + + def __del__(self): + """Destructor - cleanup allocated memory""" + + # Check if we're in a valid state for cleanup + if not hasattr(self, "is_multi_node"): + return + + if not self.is_multi_node: + return + + # Skip cleanup during Python finalization to avoid segfaults + # Especially cause the CUDA context could be destroyed at this point. + if sys.is_finalizing(): + return + + # Verify CUDA context is still valid + try: + cuda.cuCtxGetCurrent() + except Exception as e: + print(f"Destructor: CUDA context invalid, skipping cleanup: {e}") + return + + # Unmap UC regions and release their handles + if hasattr(self, "uc_handles") and self.uc_handles: + for rank in range(self.group_size): + if self.uc_handles[rank] != 0: + try: + # Release the handle + checkCudaErrors(cuda.cuMemRelease(self.uc_handles[rank])) + # Unmap the vmem + if rank < len(self.uc_ptrs) and self.uc_ptrs[rank]: + checkCudaErrors( + cuda.cuMemUnmap( + self.uc_ptrs[rank], self.allocation_size + ) + ) + except Exception as e: + print( + f"Destructor: Failed to release UC handle for rank {rank}: {e}" + ) + + # Free the UC address space + if hasattr(self, "uc_base_ptr") and self.uc_base_ptr: + checkCudaErrors( + cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size) + ) + + # Release MC handle + if hasattr(self, "mc_handle") and self.mc_handle and self.mc_handle != 0: + try: + checkCudaErrors(cuda.cuMemUnmap(self.mc_ptr, self.allocation_size)) + checkCudaErrors( + cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size) + ) + checkCudaErrors(cuda.cuMemRelease(self.mc_handle)) + except Exception as e: + print(f"Destructor: Failed to release MC handle: {e}") + + def get_signal_pad_ptrs_dev(self) -> List[int]: + """Get the raw array of signal pad pointers to all ranks (including self)""" + return self.signal_pads_dev + + def get_buffer_ptrs_dev(self) -> List[int]: + """Get the raw array of unicast pointers to all ranks (including self)""" + return self.uc_ptrs + + def get_unicast_ptr(self, rank: int) -> int: + """Get the raw unicast pointer to a given rank""" + if rank >= len(self.uc_ptrs): + raise ValueError(f"Rank {rank} out of range (0-{len(self.uc_ptrs)-1})") + + data_ptr = self.uc_ptrs[rank] + # Note: In C++, this would call tensorrt_llm::common::registerMcastDevMemBuffer + # For Python port, we skip this registration for now + return data_ptr + + def get_multicast_ptr(self) -> int: + """Get the raw multicast pointer""" + # Note: In C++, this would call tensorrt_llm::common::registerMcastDevMemBuffer + # For Python port, we skip this registration for now + return int(self.mc_ptr) + + def get_rank(self) -> int: + """Get the rank of this device in the group""" + return self.group_rank + + def get_world_size(self) -> int: + """Get the total number of devices in the group""" + return self.group_size + + def _alloc_mn_mcast_mem(self, buf_size: int): + """Allocate multi-node multicast memory using MNNVL""" + + # Verify CUDA context + try: + current_device = 
checkCudaErrors(cuda.cuCtxGetDevice()) + current_context = checkCudaErrors(cuda.cuCtxGetCurrent()) + + if int(current_device) != self.device_idx: + print( + f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}" + ) + except Exception as e: + print(f"Error checking CUDA context: {e}") + + # Get MPI communicator + comm = MpiComm() + + # Set up allocation properties + handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + + allocation_prop = cuda.CUmemAllocationProp() + allocation_prop.requestedHandleTypes = handle_type + allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED + allocation_prop.location = cuda.CUmemLocation() + allocation_prop.location.type = ( + cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + ) + allocation_prop.location.id = self.device_idx + + allocation_prop.allocFlags.gpuDirectRDMACapable = 1 + + # Get allocation granularity + alloc_granularity = checkCudaErrors( + cuda.cuMemGetAllocationGranularity( + allocation_prop, + cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM, + ) + ) + + # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity); + self.allocation_size = round_up( + buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity + ) + + # Set up multicast properties + mc_prop = cuda.CUmulticastObjectProp() + mc_prop.numDevices = self.group_size + mc_prop.size = self.allocation_size + mc_prop.handleTypes = handle_type + + # Get multicast granularity + mc_granularity = checkCudaErrors( + cuda.cuMulticastGetGranularity( + mc_prop, + cuda.CUmulticastGranularity_flags.CU_MULTICAST_GRANULARITY_RECOMMENDED, + ) + ) + + self.allocation_size = round_up(self.allocation_size, mc_granularity) + + # Initialize UC handles list + self.uc_handles = [0] * self.group_size + + # Allocate local GPU memory + self.uc_handles[self.group_rank] = checkCudaErrors( + cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) + ) + + # Export local handle to fabric handle + my_fabric_handle = checkCudaErrors( + cuda.cuMemExportToShareableHandle( + self.uc_handles[self.group_rank], + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + 0, + ) + ) + + # All-gather fabric handles + all_fabric_handles = comm.allgather(my_fabric_handle.data) + cuda.cuCtxSynchronize() + + # Import remote handles + for p in range(self.group_size): + if p != self.group_rank: + self.uc_handles[p] = checkCudaErrors( + cuda.cuMemImportFromShareableHandle( + all_fabric_handles[p], + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + ) + ) + + # Initialize multicasting + if self.group_rank == 0: + # Create multicast object + self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop)) + + # Export multicast handle + mc_fabric_handle = checkCudaErrors( + cuda.cuMemExportToShareableHandle( + self.mc_handle, + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + 0, + ) + ) + else: + mc_fabric_handle = None + + # Broadcast multicast handle + mc_fabric_handle_data = comm.bcast( + mc_fabric_handle.data if mc_fabric_handle else None, root=0 + ) + # Sync device to ensure broadcast is complete + cuda.cuCtxSynchronize() + # Import multicast handle for non-root ranks + if self.group_rank != 0: + self.mc_handle = checkCudaErrors( + cuda.cuMemImportFromShareableHandle( + mc_fabric_handle_data, + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + ) + ) + + # Add device to multicast + checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx)) + + # Bind memory addresses + 
self.uc_ptrs = [0] * self.group_size + + # Reserve address space for UC pointers + total_uc_size = self.allocation_size * self.group_size + self.total_uc_size = total_uc_size + uc_base_ptr = checkCudaErrors( + cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0) + ) + self.uc_base_ptr = uc_base_ptr # Store for cleanup + + # Set up memory access descriptor + access_desc = cuda.CUmemAccessDesc() + access_desc.location = cuda.CUmemLocation() + access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + access_desc.location.id = self.device_idx + access_desc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE + + # Map UC memory + for i in range(self.group_size): + offset = self.allocation_size * i + self.uc_ptrs[i] = int(uc_base_ptr) + offset + checkCudaErrors( + cuda.cuMemMap( + self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 + ) + ) + + # Set memory access permissions + checkCudaErrors( + cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) + ) + + # Bind MC pointer + self.mc_ptr = checkCudaErrors( + cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0) + ) + checkCudaErrors( + cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0) + ) + checkCudaErrors( + cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1) + ) + + # Bind memory to multicast + checkCudaErrors( + cuda.cuMulticastBindMem( + self.mc_handle, + 0, # mcOffset + self.uc_handles[self.group_rank], + 0, # memOffset + self.allocation_size, + 0, # flags + ) + ) + + def get_multicast_ptr_as_int64(self) -> int: + """Get multicast pointer as int64 (legacy compatibility)""" + return self.get_multicast_ptr() + + def get_buffer_ptrs_dev_as_int64(self) -> int: + """Get buffer pointers device as int64 (returning first UC pointer for now) (legacy compatibility)""" + return self.uc_ptrs[0] if self.uc_ptrs else 0 + + def lamport_initialize(self, rank: int, dtype: torch.dtype): + if dtype == torch.bfloat16: + neg_zero = 0x8000 + dsize = 2 + memset_func = cuda.cuMemsetD16 + elif dtype == torch.float32: + neg_zero = 0x80000000 + dsize = 4 + memset_func = cuda.cuMemsetD32 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + # Calculate number of elements that fit in allocation_size + num_elements = self.allocation_size // dsize + + checkCudaErrors( + memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements) + ) + + +class McastGPUBuffer: + """ + Wrapper class for McastDeviceMemory to facilitate PyTorch tensor creation. + It manages a buffer accessible via unicast or multicast for multi-node communication. + + Python port of McastGPUBuffer from TensorRT-LLM + """ + + def __init__( + self, + buf_size: int, + group_size: int, + group_rank: int, + device: torch.device, + mn_nvlink: bool = True, + ): + """ + Constructor for McastGpuBuffer. 
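+
+        Typically constructed indirectly via
+        ``flashinfer.comm.get_allreduce_mnnvl_workspace``. A direct-construction
+        sketch (variable names below are placeholders, not defaults):
+
+            buf = McastGPUBuffer(
+                buffer_size_in_bytes,
+                mapping.tp_size,
+                mapping.tp_rank,
+                torch.device("cuda", mapping.local_rank),
+                True,  # mn_nvlink
+            )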
+
+        Args:
+            buf_size: The total size of the buffer in bytes
+            group_size: The number of ranks in the communication group
+            group_rank: The rank of the local process within the group
+            device: The CUDA device for buffer allocation
+            mn_nvlink: Flag indicating if multi-node NVLink is used
+        """
+        self.mcast_device_memory = McastDeviceMemory(
+            buf_size, group_size, group_rank, device.index, mn_nvlink
+        )
+        self.buf_size = buf_size
+        self.local_device = device
+
+    def lamport_initialize(self, rank: int, dtype: torch.dtype):
+        self.mcast_device_memory.lamport_initialize(rank, dtype)
+
+    def get_mc_buffer(
+        self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0
+    ) -> torch.Tensor:
+        """
+        Returns a PyTorch tensor view of the multicast buffer portion.
+
+        Args:
+            sizes: The desired shape (dimensions) of the tensor
+            dtype: The data type of the tensor elements
+            storage_offset: The offset in elements from the start of the buffer
+
+        Returns:
+            A PyTorch tensor wrapping the multicast buffer section
+        """
+        raise NotImplementedError("Not implemented yet")
+
+    def get_multicast_ptr(self) -> int:
+        """Get the raw multicast pointer"""
+        return self.mcast_device_memory.get_multicast_ptr()
+
+    def get_multicast_ptr_as_int64(self) -> int:
+        """Get the multicast pointer as int64"""
+        return self.get_multicast_ptr()
+
+    def get_buffer_ptrs_dev(self) -> List[int]:
+        """Get the buffer pointers device array"""
+        return self.mcast_device_memory.get_buffer_ptrs_dev()
+
+    def get_buffer_ptrs_dev_as_int64(self) -> int:
+        """Get the buffer pointers device as int64 (returning first UC pointer)"""
+        ptrs = self.get_buffer_ptrs_dev()
+        assert ptrs is not None
+        return ptrs[0] if ptrs else 0
+
+    def get_buffer_ptrs_dev_as_ctypes_ptr(self) -> int:
+        """
+        Get buffer pointers as a ctypes array pointer (equivalent to C++ void**).
+        Returns the address of a ctypes array that can be cast to int64_t and back to void**.
+
+        This matches the C++ pattern:
+        reinterpret_cast<int64_t>(reinterpret_cast<void**>(mUcPtrs.data()))
+        """
+        # Create ctypes array of void pointers
+        ArrayType = ctypes.c_void_p * len(self.mcast_device_memory.uc_ptrs)
+        self._buffer_ptrs_array = ArrayType(
+            *self.mcast_device_memory.uc_ptrs
+        )  # Keep reference to prevent GC
+
+        # Return the address of this array (equivalent to .data() in C++)
+        return ctypes.cast(self._buffer_ptrs_array, ctypes.c_void_p).value
diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py
new file mode 100644
index 000000000..800c504b6
--- /dev/null
+++ b/flashinfer/comm/trtllm_mnnvl_ar.py
@@ -0,0 +1,200 @@
+"""
+MNNVL (Multi-Node NVLink) communication operations for FlashInfer.
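+
+Example (illustrative sketch only, not executed here): assumes MPI has been initialized
+with one process per GPU, ``mapping`` is this rank's tensor-parallel ``Mapping``, and
+``inp`` is the local ``[num_tokens, hidden_dim]`` input shard on the current device:
+
+    mcast_buf, buffer_flags, max_num_elements = get_allreduce_mnnvl_workspace(
+        mapping, torch.bfloat16
+    )
+    out = torch.empty_like(inp)
+    trtllm_mnnvl_all_reduce(
+        inp,
+        out,
+        mcast_buf.get_multicast_ptr_as_int64(),
+        mcast_buf.get_buffer_ptrs_dev_as_ctypes_ptr(),
+        max_num_elements // inp.shape[-1],  # buffer_M
+        buffer_flags,
+        mapping.tp_size,
+        mapping.tp_rank,
+        True,   # wait_for_results
+        False,  # launch_with_pdl
+    )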
+ +""" + +import functools +import math +import os +from types import SimpleNamespace +from typing import List, Tuple + +import torch +from mpi4py import MPI + +from flashinfer.comm.mapping import Mapping + +from ..jit import JitSpec +from ..jit import env as jit_env +from ..jit import gen_jit_spec, sm100a_nvcc_flags +from ..utils import register_custom_op +from .mnnvl import McastGPUBuffer + + +def mpi_barrier(): + """MPI barrier - could potentially be replaced with dist.barrier()""" + MPI.COMM_WORLD.Barrier() + + +def gen_trtllm_mnnvl_comm_module() -> JitSpec: + return gen_jit_spec( + "trtllm_mnnvl_comm", + [ + jit_env.FLASHINFER_CSRC_DIR / "trtllm_mnnvl_allreduce.cu", + ], + ) + + +@functools.cache +def get_trtllm_mnnvl_comm_module(): + module = gen_trtllm_mnnvl_comm_module().build_and_load() + + @register_custom_op( + "flashinfer::trtllm_mnnvl_all_reduce", + mutates_args=[ + "inp", + "out", + "multicast_buffer_ptr", + "buffer_ptrs_dev", + "buffer_mnnvl", + "buffer_flags_mnnvl", + "nranks", + "rank", + "wait_for_results", + "launch_with_pdl", + ], + ) + def trtllm_mnnvl_all_reduce( + inp: torch.Tensor, + out: torch.Tensor, + multicast_buffer_ptr: int, # Pointer address as integer + buffer_ptrs_dev: int, # Pointer address as integer + buffer_mnnvl: torch.Tensor, + buffer_flags_mnnvl: torch.Tensor, + nranks: int, + rank: int, + wait_for_results: bool, + launch_with_pdl: bool, + ) -> None: + module.trtllm_mnnvl_all_reduce( + inp, + out, + multicast_buffer_ptr, + buffer_ptrs_dev, + buffer_mnnvl, + buffer_flags_mnnvl, + nranks, + rank, + wait_for_results, + launch_with_pdl, + ) + + return SimpleNamespace( + trtllm_mnnvl_all_reduce=trtllm_mnnvl_all_reduce, + ) + + +def get_allreduce_mnnvl_workspace( + mapping: Mapping, dtype: torch.dtype +) -> Tuple[McastGPUBuffer, torch.Tensor, int]: + """Get workspace buffers needed for multi-node NVLink all-reduce operation. + + This function allocates and initializes the workspace buffers required for performing + multi-node NVLink all-reduce operations. It creates: + 1. A multicast GPU buffer for communication between nodes + 2. A flags tensor to track buffer state + 3. Maximum number of elements that can fit in the buffer + + The buffer size is calculated to efficiently handle common hidden dimensions + (2048, 4096, 5120, 7168, 8192) by using their LCM of 286720. 
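+
+    Worked example (bfloat16, illustrative): the per-element stride is 3 * 2 * 2 = 12
+    bytes, so one 286720-element granule occupies 286720 * 12 = 3,440,640 bytes.
+    Rounding the ~12 MB target up to a whole number of granules gives
+    4 * 3,440,640 = 13,762,560 bytes, i.e. max_num_elements = 4 * 286720 = 1,146,880.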
+ + Args: + mapping: Tensor parallel mapping configuration containing rank info + dtype: Data type of the tensors being reduced + + Returns: + Tuple containing: + - McastGPUBuffer: Multicast buffer for inter-node communication + - torch.Tensor: Buffer flags tensor tracking state + - int: Maximum number of elements that can fit in buffer + """ + force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" + + # buffer shape: [3, 2, buffer_tokens, hidden_dim] + stride = 3 * 2 * dtype.itemsize + # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 + # max_num_elements must be a multiple of 286720 + lcm_hidden_dim = 286720 + TARGET_WORKSPACE_SIZE_BYTES = 12_000_000 + buffer_size_in_bytes = math.ceil( + TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) + ) * (lcm_hidden_dim * stride) + max_num_elements = buffer_size_in_bytes // stride + + mcast_buffer = McastGPUBuffer( + buffer_size_in_bytes, + mapping.tp_size, + mapping.tp_rank, + torch.device("cuda", mapping.local_rank), + mapping.is_multi_node() or force_mn, + ) + + # Initialize the unicast buffer with -0.0 + mcast_buffer.lamport_initialize(mapping.tp_rank, dtype) + + # CPU barrier since we assume this should not be called in cuda graph + torch.cuda.synchronize() + mpi_barrier() + + # This is a buffer to maintain the state of this allreduce Op + # [Buffer_ptr, Clear_ptr, Buffer_size, atomic access counter] + buffer_flags = torch.tensor( + [0, 2, max_num_elements, 0], + dtype=torch.uint32, + device=torch.device("cuda", mapping.local_rank), + ) + + return ( + mcast_buffer, + buffer_flags, + max_num_elements, + ) + + +def trtllm_mnnvl_all_reduce( + inp: torch.Tensor, + out: torch.Tensor, + multicast_buffer_ptr: int, # Pointer address as integer + buffer_ptrs_dev: int, # Pointer address as integer + buffer_M: int, + buffer_flags_mnnvl: torch.Tensor, + nranks: int, + rank: int, + wait_for_results: bool, + launch_with_pdl: bool, +) -> None: + """Perform a multi-node NVLink all-reduce operation across multiple GPUs. + + This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL) + technology to efficiently combine tensors across multiple GPUs and nodes. + + There are 3 steps: + 1. scatter each GPU's input shard to the right unicast buffer + 2. perform all-reduce on each GPU + 3. broadcast the result to all GPUs + + Args: + inp: Local Input Shard + out: Output tensor to store the result + multicast_buffer_ptr: Pointer to the multicast buffer as an integer + buffer_ptrs_dev: Pointer to device buffer pointers as an integer + buffer_M: Maximum number of elements // hidden_dim + buffer_flags_mnnvl: Tensor containing buffer state flags + nranks: Total number of ranks participating in the all-reduce + rank: Current process rank + wait_for_results: If True, store the result to out + launch_with_pdl: If True, launch using Programmatic Dependent Launch + """ + module = get_trtllm_mnnvl_comm_module() + module.trtllm_mnnvl_all_reduce( + inp, + out, + multicast_buffer_ptr, + buffer_ptrs_dev, + buffer_M, + buffer_flags_mnnvl, + nranks, + rank, + wait_for_results, + launch_with_pdl, + ) diff --git a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh new file mode 100644 index 000000000..032536221 --- /dev/null +++ b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +#include "../exception.h" +#include "../logging.h" +namespace flashinfer { +namespace trtllm_mnnvl_allreduce { + +template +struct AllReduceParams { + int nranks; + int rank; + int buffer_M; + int num_tokens; + int token_dim; + void** buffer_ptrs_dev; + void* multicast_ptr; + void* buffer_flags; + bool wait_for_results; + bool launch_with_pdl; + + void* input; + void* output; + cudaStream_t stream; +}; + +__device__ bool isNegZero(float v) { return v == 0.f && signbit(v); } + +__device__ bool isNegZero(__nv_bfloat16 val) { return isNegZero(__bfloat162float(val)); } + +template +inline __device__ float toFloat(T val) { + return val; +} + +template <> +inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template +inline __device__ T fromFloat(float val) { + return val; +} + +template <> +inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val) { + return __float2bfloat16(val); +} + +template +__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, + int num_tokens, int buffer_M, int token_dim, int rank, + uint32_t* buffer_flags, bool wait_for_results) { + int elt = blockIdx.y * blockDim.x + threadIdx.x; + + if (elt >= token_dim) return; + int token = blockIdx.x; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + + uint32_t* offset_access_ptr = &buffer_flags[3]; + // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather + uint32_t buffer_size = (buffer_flags[2] << 1); + uint32_t input_offset = buffer_flags[0] * buffer_size; + uint32_t clear_offset = buffer_flags[1] * buffer_size; + + if (wait_for_results) { + __syncthreads(); + if (threadIdx.x == 0) { + atomicAdd(offset_access_ptr, 1); + } + } + + if (elt < token_dim) { + // Scatter token + int dest_rank = token % WORLD_SIZE; + int dest_token_offset = token / WORLD_SIZE; + T val = shard_ptr[token * token_dim + elt]; + if (isNegZero(val)) val = fromFloat(0.f); + input_ptrs[dest_rank][input_offset + dest_token_offset * token_dim * WORLD_SIZE + + rank * token_dim + elt] = val; + + // Reduce and broadcast + + int global_token = token * WORLD_SIZE + rank; + if (global_token < num_tokens) { + float accum = 0.f; + + T values[WORLD_SIZE]; + + for (int r = 0; r < WORLD_SIZE; r++) { + input_ptrs[rank][clear_offset + token * token_dim * WORLD_SIZE + r * token_dim + elt] = + fromFloat(-0.f); + } + + while (1) { + bool valid = true; + for (int r = 0; r < WORLD_SIZE; r++) { + T volatile* lamport_ptr = + (T volatile*)&input_ptrs[rank][input_offset + token * token_dim * WORLD_SIZE + + r * token_dim + elt]; + values[r] = *lamport_ptr; + valid &= !isNegZero(values[r]); + } + if (valid) break; + } + for (int r = 0; r < WORLD_SIZE; r++) { + accum += toFloat(values[r]); + } + mcast_ptr[input_offset + buffer_M * token_dim + global_token * token_dim + elt] = + 
fromFloat(accum); + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + + input_ptrs[rank][clear_offset + buffer_M * token_dim + token * token_dim + elt] = + fromFloat(-0.f); + + // Optionally wait for results if the next layer isn't doing the Lamport check + if (wait_for_results) { + T volatile* lamport_ptr = (T volatile*)&input_ptrs[rank][input_offset + buffer_M * token_dim + + token * token_dim + elt]; + T val = *lamport_ptr; + while (isNegZero(val)) val = *lamport_ptr; + + // Copy if requested + if (output_ptr) output_ptr[token * token_dim + elt] = val; + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) { + // Make sure all blocks have finished reading the offsets, 2-D grid + while (*reinterpret_cast(offset_access_ptr) < gridDim.x * gridDim.y) { + } + buffer_flags[0] = (buffer_flags[0] + 1) % 3; + buffer_flags[1] = (buffer_flags[1] + 1) % 3; + *(offset_access_ptr) = 0; + } + } +} + +// Template-based dispatch functions following the same pattern as trtllm_allreduce.cuh +template +cudaError_t twoshot_allreduce_dispatch(AllReduceParams& params) { + int const num_threads = 128; + int const num_blocks = (params.token_dim + num_threads - 1) / num_threads; + + dim3 grid(params.num_tokens, num_blocks); + + FLASHINFER_LOG_DEBUG( + "[MNNVL TwoShot AllReduce] twoshot allreduce on rank %d, world_size: %d, buffer_M: %d, " + "num_tokens: %d, token_dim: " + "%d, wait_for_results: %d, launch_with_pdl: %d", + params.rank, params.nranks, params.buffer_M, params.num_tokens, params.token_dim, + params.wait_for_results, params.launch_with_pdl); + + cudaLaunchConfig_t config; + cudaLaunchAttribute attrs[1]; + config.dynamicSmemBytes = 0; + config.stream = params.stream; + config.gridDim = grid; + config.blockDim = num_threads; + config.attrs = attrs; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = params.launch_with_pdl ? 1 : 0; + config.numAttrs = 1; + + cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel, + reinterpret_cast(params.output), reinterpret_cast(params.input), + reinterpret_cast(params.buffer_ptrs_dev), + reinterpret_cast(params.multicast_ptr), params.num_tokens, params.buffer_M, + params.token_dim, params.rank, + reinterpret_cast(params.buffer_flags), params.wait_for_results); + + return cudaSuccess; +} + +template +cudaError_t twoshot_allreduce_dispatch_world_size(AllReduceParams& params) { + FLASHINFER_LOG_DEBUG("twoshot_allreduce_dispatch_world_size"); + switch (params.nranks) { + case 2: + return twoshot_allreduce_dispatch(params); + case 4: + return twoshot_allreduce_dispatch(params); + case 8: + return twoshot_allreduce_dispatch(params); + case 16: + return twoshot_allreduce_dispatch(params); + case 32: + return twoshot_allreduce_dispatch(params); + case 64: + return twoshot_allreduce_dispatch(params); + default: + FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nranks) + + ". 
Supported sizes: {2, 4, 8, 16, 32, 64}"); + return cudaErrorInvalidValue; + } +} + +} // namespace trtllm_mnnvl_allreduce +} // namespace flashinfer diff --git a/tests/test_trtllm_mnnvl_allreduce.py b/tests/test_trtllm_mnnvl_allreduce.py new file mode 100644 index 000000000..3af6ca6fd --- /dev/null +++ b/tests/test_trtllm_mnnvl_allreduce.py @@ -0,0 +1,220 @@ +# Check torch version: +import os +import sys +import traceback + +import pytest +import torch +from mpi4py import MPI # Added MPI import + +import flashinfer.comm as comm +from flashinfer.comm.mapping import Mapping +from flashinfer.comm.mnnvl import McastDeviceMemory, McastGPUBuffer + + +@torch.inference_mode() +def row_linear_residual_norm_fusion_forward( + x: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + hidden_size: int, + dtype: torch.dtype, + tensor_parallel_size: int, + tensor_parallel_rank: int, + fusion: bool, + reference_output: tuple[torch.Tensor, ...], +): + + x = x.cuda() + residual = residual.cuda() + norm_weight = norm_weight.cuda() + reference_output = tuple(t.cuda() for t in reference_output) + + MPI.COMM_WORLD.barrier() + + mapping = Mapping( + world_size=tensor_parallel_size, + tp_size=tensor_parallel_size, + rank=tensor_parallel_rank, + ) + + def func( + input, + residual, + norm_weight, + eps, + enable_fusion, + multicast_ptr, + buffer_ptrs_dev, + max_num_elements_mnnvl, + ): + # For both fused and unfused cases: + shape = input.shape + + hidden_size = shape[-1] + + assert max_num_elements_mnnvl % hidden_size == 0 + + input = input.view(-1, shape[-1]) + output = torch.empty_like(input) + + buffer_M = max_num_elements_mnnvl // hidden_size + + if enable_fusion: + raise NotImplementedError("Fusion not implemented") + + else: + comm.trtllm_mnnvl_all_reduce( + input, + output, + multicast_ptr, + buffer_ptrs_dev, # Attempted to use this raw pointer + buffer_M, + buffer_flags_mnnvl, + tensor_parallel_size, + tensor_parallel_rank, + True, # wait_for_results + False, # launch_with_pdl + ) + return (output.view(shape),) + + # Get workspace buffers using MPI rank + mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( + comm.get_allreduce_mnnvl_workspace(mapping, dtype) + ) + + multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr_as_int64() + buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev_as_ctypes_ptr() + + try: + output = func( + x.clone(), + residual.clone(), + norm_weight, + eps, + fusion, + multicast_ptr, + buffer_ptrs_dev, + max_num_elements_mnnvl, + ) + + assert output[0].shape == reference_output[0].shape + + if tensor_parallel_rank == 0: + print("output[0] (first 10 values):", output[0].flatten()[:10]) + print( + "reference_output[0] (first 10 values):", + reference_output[0].flatten()[:10], + ) + + if fusion: + print("output[1] (first 10 values):", output[1].flatten()[:10]) + print( + "reference_output[1] (first 10 values):", + reference_output[1].flatten()[:10], + ) + + torch.testing.assert_close( + output[0], + reference_output[0], + rtol=0.05, + atol=0.15, + ) + + if fusion: + torch.testing.assert_close( + output[1], + reference_output[1], + rtol=0.05, + atol=0.15, + ) + + finally: + # Ensure cleanup happens even if assertions fail + del mcast_buffer_mnnvl + + +"""Main test function that runs on each MPI rank""" + + +# seq_lens = [1, 4, 32, 128] +@pytest.mark.parametrize("seq_len", [4]) +@pytest.mark.parametrize("fusion", [False]) +def test_mnnvl_allreduce_full(monkeypatch, seq_len: int, fusion: bool): + monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", 
"1") # force multi-node allreduce. + + # Get MPI info + rank = MPI.COMM_WORLD.Get_rank() + world_size = MPI.COMM_WORLD.Get_size() + + # Ensure we have exactly 2 ranks for this test + if world_size < 2: + if rank == 0: + print(f"ERROR: This test requires at least 2 MPI ranks, got {world_size}") + sys.exit(1) + + # Set CUDA device based on rank + torch.cuda.set_device(rank) + + if rank == 0: + print(f"Running MNNVL AllReduce test with {world_size} ranks") + print(f"Rank {rank} using GPU {torch.cuda.current_device()}") + + hidden_size = 7168 + dtype = torch.bfloat16 + tensor_parallel_size = world_size + eps = 1e-5 + + torch.manual_seed(42) + + try: + if rank == 0: + print( + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}" + ) + + # Generate test data (same on all ranks due to same seed) + x_full = torch.randn((tensor_parallel_size, seq_len, hidden_size), dtype=dtype) + residual = torch.randn((seq_len, hidden_size), dtype=dtype) + norm_weight = torch.randn((hidden_size,), dtype=dtype) + + # Each rank gets its slice of the input + x = x_full[rank, :, :] + + # Compute reference output based on fusion mode + if fusion: + raise NotImplementedError("Fusion not implemented") + else: + # Non-fused case: Only AllReduce + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + reference_output = (allreduce_result,) + + # Run the test + row_linear_residual_norm_fusion_forward( + x, + residual, + norm_weight, + eps, + hidden_size, + dtype, + tensor_parallel_size, + rank, + fusion, + reference_output, + ) + + # Synchronize before next test + comm.mpi_barrier() + + print(f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}") + + except Exception as e: + print(f"FAILED[rank={rank}]: seq_len={seq_len}, fusion={fusion} failed: {e}") + if rank == 0: + traceback.print_exc() + # Don't exit immediately, let other tests run + comm.mpi_barrier() + + # Final synchronization and results + comm.mpi_barrier()