diff --git a/csrc/nv_internal/cpp/common/opUtils.cpp b/csrc/nv_internal/cpp/common/opUtils.cpp new file mode 100644 index 000000000..56250668b --- /dev/null +++ b/csrc/nv_internal/cpp/common/opUtils.cpp @@ -0,0 +1,111 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/opUtils.h" + +#include +#include +#include + +#include +#include +#include + +#include "cuda.h" +#include "tensorrt_llm/runtime/utils/mpiTags.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" + +#if ENABLE_MULTI_DEVICE + +std::unordered_map* getDtypeMap() { + static std::unordered_map dtypeMap = { + {nvinfer1::DataType::kFLOAT, ncclFloat32}, {nvinfer1::DataType::kHALF, ncclFloat16}, + {nvinfer1::DataType::kBF16, ncclBfloat16}, {nvinfer1::DataType::kFP8, ncclInt8}, + {nvinfer1::DataType::kBOOL, ncclInt8}, {nvinfer1::DataType::kINT32, ncclInt32}, + {nvinfer1::DataType::kINT64, ncclInt64}, {nvinfer1::DataType::kUINT8, ncclUint8}, + {nvinfer1::DataType::kINT8, ncclInt8}, + }; + return &dtypeMap; +} + +namespace { + +// Get NCCL unique ID for a group of ranks. +ncclUniqueId getUniqueId(std::set const& group) { + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); + ncclUniqueId id; + if (rank == *group.begin()) { + NCCLCHECK_THROW(ncclGetUniqueId(&id)); + for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it) { + COMM_SESSION.sendValue(id, *it, tensorrt_llm::mpi::MpiTag::kDefault); + } + } else { + COMM_SESSION.recvValue(id, *group.begin(), tensorrt_llm::mpi::MpiTag::kDefault); + } + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); + return id; +} +} // namespace + +std::shared_ptr getComm(std::set const& group) { + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); + static std::map, std::shared_ptr> commMap; + static std::mutex mutex; + std::lock_guard lock(mutex); + std::ostringstream oss; + int index = 0; + for (auto const& rank : group) { + if (index != 0) { + oss << ","; + } + oss << rank; + index++; + } + auto groupStr = oss.str(); + auto it = commMap.find(group); + if (it != commMap.end()) { + auto ncclComm = it->second; + TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank); + return ncclComm; + } + + TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank); + ncclUniqueId id = getUniqueId(group); + int groupRank = 0; + for (auto const& currentRank : group) { + if (rank == currentRank) break; + ++groupRank; + } + TLLM_CHECK(static_cast(groupRank) < group.size()); + std::shared_ptr ncclComm(new ncclComm_t, [](ncclComm_t* comm) { + ncclCommDestroy(*comm); + delete comm; + }); +// Need static connection initialization for accurate KV cache size estimation +#if defined(_WIN32) + if (getenv("NCCL_RUNTIME_CONNECT") == 
nullptr) _putenv_s("NCCL_RUNTIME_CONNECT", "0"); +#else + setenv("NCCL_RUNTIME_CONNECT", "0", 0); +#endif // _WIN32 + NCCLCHECK_THROW(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank)); + commMap[group] = ncclComm; + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); + return ncclComm; +} +#endif // ENABLE_MULTI_DEVICE diff --git a/csrc/nv_internal/cpp/runtime/utils/mpiUtils.cpp b/csrc/nv_internal/cpp/runtime/utils/mpiUtils.cpp new file mode 100644 index 000000000..1a242c42e --- /dev/null +++ b/csrc/nv_internal/cpp/runtime/utils/mpiUtils.cpp @@ -0,0 +1,571 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/runtime/utils/mpiUtils.h" + +#include +#include + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/runtime/common.h" +// #include "tensorrt_llm/runtime/iBuffer.h" + +#include +#include +#include +#include +#include +#ifndef _WIN32 +#include +#endif + +// We rely on SizeType32 being int32_t in some places with weak type checking, +// i.e. we're passing void ptr to some function. To prevent mysterious errors +// in the future, we trigger a compilation error here if SizeType32 isn't int32_t. 
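+// (SizeType32 is declared as std::int32_t in tensorrt_llm/runtime/common.h,
+// which is included above, so the assertion below is expected to hold.)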
+static_assert(std::is_same::value); + +namespace tensorrt_llm::mpi { + +MPI_Datatype getMpiDtype(MpiType dtype) { +#if ENABLE_MULTI_DEVICE + static std::unordered_map const dtype_map{ + {MpiType::kBYTE, MPI_BYTE}, {MpiType::kHALF, MPI_UINT16_T}, + {MpiType::kFLOAT, MPI_FLOAT}, {MpiType::kDOUBLE, MPI_DOUBLE}, + {MpiType::kBOOL, MPI_C_BOOL}, {MpiType::kINT8, MPI_INT8_T}, + {MpiType::kUINT8, MPI_UINT8_T}, {MpiType::kINT32, MPI_INT32_T}, + {MpiType::kUINT32, MPI_UINT32_T}, {MpiType::kINT64, MPI_INT64_T}, + {MpiType::kUINT64, MPI_UINT64_T}, {MpiType::kFP8, MPI_UINT8_T}, + {MpiType::kBF16, MPI_UINT16_T}, {MpiType::kCHAR, MPI_CHAR}, + }; + return dtype_map.at(dtype); +#else + TLLM_THROW("Multi device support is disabled."); +#endif +} + +MPI_Op getMpiOp(MpiOp op) { +#if ENABLE_MULTI_DEVICE + static std::unordered_map const op_map{ + {MpiOp::NULLOP, MPI_OP_NULL}, {MpiOp::MAX, MPI_MAX}, {MpiOp::MIN, MPI_MIN}, + {MpiOp::SUM, MPI_SUM}, {MpiOp::PROD, MPI_PROD}, {MpiOp::LAND, MPI_LAND}, + {MpiOp::BAND, MPI_BAND}, {MpiOp::LOR, MPI_LOR}, {MpiOp::BOR, MPI_BOR}, + {MpiOp::LXOR, MPI_LXOR}, {MpiOp::BXOR, MPI_BXOR}, {MpiOp::MINLOC, MPI_MINLOC}, + {MpiOp::MAXLOC, MPI_MAXLOC}, {MpiOp::REPLACE, MPI_REPLACE}, + }; + return op_map.at(op); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +namespace { + +bool mpiInitialized = false; +std::recursive_mutex mpiMutex; + +MpiComm initLocalSession() { +#if ENABLE_MULTI_DEVICE + MPI_Comm localComm = nullptr; + MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, + &localComm); + MpiComm localSession{localComm, false}; +#else + MpiComm localSession{COMM_SESSION, false}; +#endif // ENABLE_MULTI_DEVICE + return localSession; +} + +} // namespace + +std::vector getWorldRanks(MpiComm const& comm) { +#if ENABLE_MULTI_DEVICE + MPI_Group group = nullptr; + MPI_Group worldGroup = nullptr; + + MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); + MPICHECK(MPI_Comm_group(comm, &group)); + + int groupSize = 0; + MPICHECK(MPI_Group_size(group, &groupSize)); + std::vector ranks(groupSize); + std::vector worldRanks(groupSize); + std::iota(ranks.begin(), ranks.end(), 0); + + MPICHECK( + MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data())); + MPICHECK(MPI_Group_free(&group)); + MPICHECK(MPI_Group_free(&worldGroup)); +#else + std::vector worldRanks{0}; +#endif + return worldRanks; +} + +int getNumNodes() { +#if ENABLE_MULTI_DEVICE + TLLM_LOG_WARNING("Number of nodes was not provided, using MPI to determine number of nodes"); + + // Create a communicator for processes with the same hostname + MPI_Comm node_comm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &node_comm); + + // Get rank in node_comm + int node_rank; + MPI_Comm_rank(node_comm, &node_rank); + + // Count only rank 0 processes + int local_count = (node_rank == 0) ? 
1 : 0; + int num_nodes = 0; + + MPI_Allreduce(&local_count, &num_nodes, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD); + + MPI_Comm_free(&node_comm); + return num_nodes; +#else + return 1; +#endif +} + +void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent) { + // double-checked locking + if (mpiInitialized) { + return; + } + std::lock_guard lk(mpiMutex); + if (mpiInitialized) { + return; + } +#if ENABLE_MULTI_DEVICE + int initialized = 0; + TLLM_MPI_CHECK(MPI_Initialized(&initialized)); + if (!initialized) { + TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode); + int providedMode = 0; + auto requiredMode = static_cast(threadMode); + MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode)); + TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed"); + std::atexit([]() { MPI_Finalize(); }); + + /* + * We only catch SIGABRT and SIGSEGV because most, of not all errors in the worker will cause + * one of these 2 signals. Signals like SIGINT and SIGTERM should be issued to the parent and + * should terminate MPI workers correctly. + */ + for (int sig : {SIGABRT, SIGSEGV}) { + __sighandler_t previousHandler = nullptr; + if (forwardAbortToParent) { + previousHandler = std::signal(sig, [](int signal) { +#ifndef _WIN32 + pid_t parentProcessId = getppid(); + kill(parentProcessId, SIGKILL); +#endif + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); + }); + } else { + previousHandler = + std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); }); + } + TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed"); + } + + // ensure local MPI communicator is initialized + MpiComm::localSession(); + TLLM_LOG_INFO("Initialized MPI"); + } +#endif // ENABLE_MULTI_DEVICE + mpiInitialized = true; +} + +void MpiComm::barrier() const { +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Barrier(mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +#if ENABLE_MULTI_DEVICE +template >>> +size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args) { + constexpr auto maxP1 = static_cast(std::numeric_limits::max()) + 1; + if (TLLM_LIKELY(size < maxP1)) { + MPICHECK(func(buffer, size, dtype, args...)); + return 1; + } + + constexpr size_t alignment = 256; + int elementSize = 1; + MPICHECK(MPI_Type_size(dtype, &elementSize)); + elementSize = std::min(elementSize, alignment); + + // We cap at max alignment-bytes chunks that can be sent at once. 
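+ // For example, with a 4-byte dtype elementSize is 4 and step = 2^31 - 64
+ // elements, so every per-call count fits in a signed 32-bit int and each
+ // full chunk advances the buffer pointer by a multiple of 256 bytes.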
+ auto const step = maxP1 - (alignment / elementSize); + + using TCast = std::conditional_t, uint8_t const, uint8_t>; + size_t count = 0; + while (size != 0) { + auto currentStep = static_cast(std::min(size, step)); + MPICHECK(func(buffer, currentStep, dtype, args...)); + size -= currentStep; + size_t diff = static_cast(currentStep) * elementSize; + buffer = static_cast(buffer) + diff; + ++count; + } + + return count; +} +#endif // ENABLE_MULTI_DEVICE + +std::unique_ptr MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, + int root) const { + std::unique_ptr r = std::make_unique(); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + return r; +} + +// std::unique_ptr MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const +// { +// return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); +// } + +// void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const +// { +// #if ENABLE_MULTI_DEVICE +// invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm); +// #else +// TLLM_THROW("Multi device support is disabled."); +// #endif // ENABLE_MULTI_DEVICE +// } + +// void MpiComm::bcast(runtime::IBuffer& buf, int root) const +// { +// bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); +// } + +std::unique_ptr MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, + int dest, MpiTag tag) const { + TLLM_LOG_DEBUG("start MPI_Isend with dest %d, tag %d, size %d", dest, static_cast(tag), + size); + std::unique_ptr r = std::make_unique(); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, static_cast(tag), mComm, + &r->mRequest); +#else + TLLM_THROW("Multi device support is disabled."); +#endif + TLLM_LOG_DEBUG("end MPI_Isend with dest %d, tag %d, size %d", dest, static_cast(tag), size); + return r; +} + +// std::unique_ptr MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, MpiTag tag) +// const +// { +// return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); +// } + +void MpiComm::sendRawTag(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const { + TLLM_LOG_DEBUG("start MPI_Send with dest %d, tag %d, size %d", dest, tag, size); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + TLLM_LOG_DEBUG("end MPI_Send with dest %d, tag %d, size %d", dest, tag, size); +} + +void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, MpiTag tag) const { + sendRawTag(buffer, size, dtype, dest, static_cast(tag)); +} + +// void MpiComm::send(runtime::IBuffer const& buf, int dest, MpiTag tag) const +// { +// send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); +// } + +MPI_Status MpiComm::recvRawTag(void* buffer, size_t size, MpiType dtype, int source, + int tag) const { + TLLM_LOG_DEBUG("start MPI_Recv with source %d, tag %d, size %d", source, tag, size); + MPI_Status status{}; +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + TLLM_LOG_DEBUG("end MPI_Recv with source %d, tag %d, size %d", source, tag, size); + return status; +} + +MPI_Status MpiComm::recv(void* 
buffer, size_t size, MpiType dtype, int source, MpiTag tag) const { + return recvRawTag(buffer, size, dtype, source, static_cast(tag)); +} + +// MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, MpiTag tag) const +// { +// return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag); +// } + +MpiComm MpiComm::split(int color, int key) const { + MPI_Comm splitComm = nullptr; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + return MpiComm{splitComm, true}; +} + +MpiComm const& MpiComm::setRawSessionByFortran(int64_t fortranHandle) { +#if ENABLE_MULTI_DEVICE + auto comm = MpiComm{MPI_Comm_f2c(fortranHandle), false}; +#else + TLLM_THROW("Multi device support is disabled."); + auto comm = MpiComm(nullptr, false); +#endif // ENABLE_MULTI_DEVICE + return MpiComm::setSession(std::move(comm)); +} + +void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, + MpiOp op) const { +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const { +#if ENABLE_MULTI_DEVICE + MPICHECK( + MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf, + std::vector const& recvcounts, std::vector const& displs, + MpiType recvtype) const { +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), + displs.data(), getMpiDtype(recvtype), mComm)); + +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::mprobeRawTag(int source, int tag, MPI_Message* msg, MPI_Status* status) const { +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::mprobe(int source, MpiTag tag, MPI_Message* msg, MPI_Status* status) const { + mprobeRawTag(source, static_cast(tag), msg, status); +} + +bool MpiComm::improbe(int source, MpiTag tag, MPI_Message* msg, MPI_Status* status) const { +#if ENABLE_MULTI_DEVICE + int flag{0}; + MPICHECK(MPI_Improbe(source, static_cast(tag), mComm, &flag, msg, status)); + return flag != 0; +#else + TLLM_THROW("Multi device support is disabled."); + return false; +#endif +} + +bool MpiComm::iprobe(int source, MpiTag tag, MPI_Status* status) const { +#if ENABLE_MULTI_DEVICE + int flag{0}; + MPICHECK(MPI_Iprobe(source, static_cast(tag), mComm, &flag, status)); + return flag != 0; +#else + TLLM_THROW("Multi device support is disabled."); + return false; +#endif +} + +void MpiComm::recvPoll(int source, MpiTag tag, int periodMs) const { + MPI_Status status; + while (!iprobe(source, tag, &status)) { + std::this_thread::sleep_for(std::chrono::milliseconds(periodMs)); + } +} + +int MpiComm::getRank() const { + int rank = 0; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_rank(mComm, &rank)); +#endif + return rank; +} + +int MpiComm::getSize() const { + int world_size = 1; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_size(mComm, &world_size)); 
+#endif + return world_size; +} + +MpiComm const& MpiComm::world() { + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm commWorld{MPI_COMM_WORLD, false}; + initialize(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return commWorld; +} + +MpiComm& MpiComm::mutableSession() { + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm commSession{MPI_COMM_WORLD, false}; + initialize(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return commSession; +} + +MpiComm& MpiComm::mutableLocalSession() { + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm localSession = initLocalSession(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return localSession; +} + +void MpiComm::refreshLocalSession() { +#if ENABLE_MULTI_DEVICE + static std::mutex mutex; + std::unique_lock lock(mutex); + auto initSessionRanks = getWorldRanks(MpiComm::session()); + auto localSessionRanks = getWorldRanks(MpiComm::localSession()); + + // Add to intersectionRanks in order of initSessionRanks + std::vector intersectionRanks; + std::unordered_set localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end()); + for (auto rank : initSessionRanks) { + if (localSessionRanksSet.find(rank) != localSessionRanksSet.end()) { + intersectionRanks.push_back(rank); + } + } + + MPI_Group worldGroup = nullptr; + MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); + MPI_Group localGroup = nullptr; + MPICHECK( + MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup)); + MPI_Comm localComm = nullptr; + MPICHECK( + MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm)); + MpiComm::mutableLocalSession().mFreeComm = true; + MpiComm::mutableLocalSession() = MpiComm{localComm, false}; + TLLM_LOG_INFO("Refreshed the MPI local session"); +#endif // ENABLE_MULTI_DEVICE +} + +MpiComm::MpiComm(MPI_Comm g, bool freeComm) : mComm{g}, mFreeComm{freeComm} { + TLLM_CHECK(mComm != MPI_COMM_NULL); +} + +MpiComm::~MpiComm() noexcept { +#if ENABLE_MULTI_DEVICE + if (mFreeComm && mComm) { + if (MPI_Comm_free(&mComm) != MPI_SUCCESS) { + TLLM_LOG_ERROR("MPI_Comm_free failed"); + } + } +#endif // ENABLE_MULTI_DEVICE +} + +MpiComm::MpiComm(MpiComm&& comm) noexcept : mComm{comm.mComm}, mFreeComm{comm.mFreeComm} { + comm.mFreeComm = false; +} + +MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept { + this->~MpiComm(); + mComm = comm.mComm; + mFreeComm = comm.mFreeComm; + comm.mFreeComm = false; + return *this; +} + +MpiWaitThread::MpiWaitThread(std::string name, std::function funcWait, + std::function funcSetup) + : mName{name.c_str()}, mFuncWait{funcWait}, mFuncSetup{funcSetup} { + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + mThread = std::make_unique(&MpiWaitThread::sideThread, this); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +MpiWaitThread::~MpiWaitThread() { + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + waitStop(); + mShouldExit.store(true); + notifyStart(); + mThread->join(); + mThread.reset(nullptr); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::sideThread() { + if (mFuncSetup) { + mFuncSetup(); + } + while (!mShouldExit.load()) { + notifyStop(); + waitStart(); + mFuncWait(); + } +} + +void MpiWaitThread::waitStart() { + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::unique_lock lock(mMutex); + mCondVar.wait(lock, [this] { return mRunning; }); + 
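+ // Unblocked once notifyStart() sets mRunning to true under the same mutex.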
TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::waitStop() { + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::unique_lock lock(mMutex); + mCondVar.wait(lock, [this] { return !mRunning; }); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::notifyStart() { + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::lock_guard lock(mMutex); + mRunning = true; + mCondVar.notify_one(); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::notifyStop() { + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::lock_guard lock(mMutex); + mRunning = false; + mCondVar.notify_one(); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +} // namespace tensorrt_llm::mpi diff --git a/csrc/nv_internal/include/tensorrt_llm/common/opUtils.h b/csrc/nv_internal/include/tensorrt_llm/common/opUtils.h new file mode 100644 index 000000000..51785ef89 --- /dev/null +++ b/csrc/nv_internal/include/tensorrt_llm/common/opUtils.h @@ -0,0 +1,38 @@ + +#pragma once + +#include +#include +#include + +#include "tensorrt_llm/common/NvInferRuntime.h" +#include "tensorrt_llm/common/cublasMMWrapper.h" +#include "tensorrt_llm/common/workspace.h" +#if ENABLE_MULTI_DEVICE +#include +#endif // ENABLE_MULTI_DEVICE + +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifdef ENABLE_MULTI_DEVICE +#define NCCLCHECK_THROW(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (TLLM_UNLIKELY(r != ncclSuccess)) { \ + TLLM_THROW("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \ + } \ + } while (0) + +std::unordered_map* getDtypeMap(); + +std::shared_ptr getComm(std::set const& group); + +#endif // ENABLE_MULTI_DEVICE diff --git a/csrc/nv_internal/include/tensorrt_llm/runtime/common.h b/csrc/nv_internal/include/tensorrt_llm/runtime/common.h new file mode 100644 index 000000000..1df053e09 --- /dev/null +++ b/csrc/nv_internal/include/tensorrt_llm/runtime/common.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include + +namespace tensorrt_llm::runtime { + +#define FMT_DIM "%ld" + +// typedefs +// Note that we use signed size types as recommended by TensorRT: +// https://github.com/NVIDIA/TensorRT/blob/main/CODING-GUIDELINES.md#signed-vs-unsigned-integers +using SizeType32 = std::int32_t; +using SizeType64 = std::int64_t; + +enum class RequestType : std::int32_t { kCONTEXT = 0, kGENERATION = 1 }; + +// Token ID type +using TokenIdType = std::int32_t; + +using LoraTaskIdType = std::uint64_t; +using TokenExtraIdType = std::uint64_t; +using VecTokenExtraIds = std::vector; + +struct UniqueToken { + TokenIdType tokenId; + TokenExtraIdType tokenExtraId; + + bool operator==(UniqueToken const& other) const noexcept { + return (tokenId == other.tokenId && tokenExtraId == other.tokenExtraId); + } +}; + +using VecUniqueTokens = std::vector; + +template +using StringPtrMap = std::unordered_map>; + +} // namespace tensorrt_llm::runtime diff --git a/csrc/nv_internal/tensorrt_llm/runtime/opUtils.cpp b/csrc/nv_internal/tensorrt_llm/runtime/opUtils.cpp new file mode 100644 index 000000000..56250668b --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/runtime/opUtils.cpp @@ -0,0 +1,111 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/opUtils.h" + +#include +#include +#include + +#include +#include +#include + +#include "cuda.h" +#include "tensorrt_llm/runtime/utils/mpiTags.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" + +#if ENABLE_MULTI_DEVICE + +std::unordered_map* getDtypeMap() { + static std::unordered_map dtypeMap = { + {nvinfer1::DataType::kFLOAT, ncclFloat32}, {nvinfer1::DataType::kHALF, ncclFloat16}, + {nvinfer1::DataType::kBF16, ncclBfloat16}, {nvinfer1::DataType::kFP8, ncclInt8}, + {nvinfer1::DataType::kBOOL, ncclInt8}, {nvinfer1::DataType::kINT32, ncclInt32}, + {nvinfer1::DataType::kINT64, ncclInt64}, {nvinfer1::DataType::kUINT8, ncclUint8}, + {nvinfer1::DataType::kINT8, ncclInt8}, + }; + return &dtypeMap; +} + +namespace { + +// Get NCCL unique ID for a group of ranks. 
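+// The first rank of the group creates the ID and sends it point-to-point to
+// every other member over the MPI session; the remaining ranks block on the
+// matching receive before the NCCL communicator is initialized.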
+ncclUniqueId getUniqueId(std::set const& group) { + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); + ncclUniqueId id; + if (rank == *group.begin()) { + NCCLCHECK_THROW(ncclGetUniqueId(&id)); + for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it) { + COMM_SESSION.sendValue(id, *it, tensorrt_llm::mpi::MpiTag::kDefault); + } + } else { + COMM_SESSION.recvValue(id, *group.begin(), tensorrt_llm::mpi::MpiTag::kDefault); + } + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); + return id; +} +} // namespace + +std::shared_ptr getComm(std::set const& group) { + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); + static std::map, std::shared_ptr> commMap; + static std::mutex mutex; + std::lock_guard lock(mutex); + std::ostringstream oss; + int index = 0; + for (auto const& rank : group) { + if (index != 0) { + oss << ","; + } + oss << rank; + index++; + } + auto groupStr = oss.str(); + auto it = commMap.find(group); + if (it != commMap.end()) { + auto ncclComm = it->second; + TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank); + return ncclComm; + } + + TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank); + ncclUniqueId id = getUniqueId(group); + int groupRank = 0; + for (auto const& currentRank : group) { + if (rank == currentRank) break; + ++groupRank; + } + TLLM_CHECK(static_cast(groupRank) < group.size()); + std::shared_ptr ncclComm(new ncclComm_t, [](ncclComm_t* comm) { + ncclCommDestroy(*comm); + delete comm; + }); +// Need static connection initialization for accurate KV cache size estimation +#if defined(_WIN32) + if (getenv("NCCL_RUNTIME_CONNECT") == nullptr) _putenv_s("NCCL_RUNTIME_CONNECT", "0"); +#else + setenv("NCCL_RUNTIME_CONNECT", "0", 0); +#endif // _WIN32 + NCCLCHECK_THROW(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank)); + commMap[group] = ncclComm; + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); + return ncclComm; +} +#endif // ENABLE_MULTI_DEVICE diff --git a/csrc/nv_internal/tensorrt_llm/runtime/utils/mpiTags.h b/csrc/nv_internal/tensorrt_llm/runtime/utils/mpiTags.h new file mode 100644 index 000000000..683bfaaaf --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/runtime/utils/mpiTags.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +namespace tensorrt_llm::mpi { + +enum class MpiTag : int { + kDefault = 0, + + // DecoderStepAsyncSend + kDecoderStepNewOutputTokensHost = 0, + kDecoderStepFinishedSumHost = 1, + kDecoderStepSequenceLengthsHost = 2, + kDecoderStepCumLogProbsHost = 3, + kDecoderStepLogProbsHost = 4, + kDecoderStepCacheIndirectionOutput = 5, + kDecoderStepAcceptedLengthsCumSumDevice = 6, + kDecoderStepAcceptedPackedPathsDevice = 7, + kDecoderStepFinishReasonsHost = 8, + + // DecoderSlotAsyncSend + kDecoderSlotOutputIds = 9, + kDecoderSlotSequenceLengths = 10, + kDecoderSlotCumLogProbs = 11, + kDecoderSlotLogProbs = 12, + + // CancelledRequestsAsyncSend + kCancelledRequestsNumReq = 13, + kCancelledRequestsIds = 14, + + // RequestWithIdAsyncSend + kRequestWithIdNumReq = 15, + kRequestWithIdVecSize = 16, + kRequestWithIdPacked = 17, + + // Executor + kExecutorNumActiveRequests = 18, + kExecutorLowestPriorityActiveHasValue = 19, + kExecutorLowestPriorityActive = 20, + kExecutorShouldExit = 21, + + // TrtGptModelInflightBatching + kTrtGptModelInflightBatchingContextLogits = 22, + kTrtGptModelInflightBatchingGenerationLogits = 23, + + // Orchestrator + kOrchestratorId = 127, + kOrchestratorData = 1023, + kOrchestratorStatsId = 128, + kOrchestratorStatsData = 1024, + + // LogitsThread + kSpecDecLogitsId = 129, + kSpecDecLogitsData = 1025, +}; + +} // namespace tensorrt_llm::mpi diff --git a/csrc/nv_internal/tensorrt_llm/runtime/utils/mpiUtils.h b/csrc/nv_internal/tensorrt_llm/runtime/utils/mpiUtils.h new file mode 100644 index 000000000..20a0e297c --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/runtime/utils/mpiUtils.h @@ -0,0 +1,421 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "tensorrt_llm/runtime/utils/mpiTags.h" +#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h" + +#ifdef ENABLE_FP8 +#include +#endif +#ifdef ENABLE_BF16 +#include +#endif + +#include +#include +#include +#include +#include +#include +#include + +#if ENABLE_MULTI_DEVICE +#include +#else +// Dummy defines to avoid #if in wider places. +typedef void* MPI_Datatype; +typedef void* MPI_Comm; +typedef void* MPI_Request; +typedef void* MPI_Message; +typedef void* MPI_Op; + +typedef struct MPI_Status { + int dummy; +} MPI_Status; + +#define MPI_THREAD_SINGLE 0 +#define MPI_THREAD_FUNNELED 1 +#define MPI_THREAD_SERIALIZED 2 +#define MPI_THREAD_MULTIPLE 3 +#define MPI_COMM_WORLD ((MPI_Comm)0x44000000) +#define MPI_COMM_NULL ((MPI_Comm)0x04000000) +#endif // ENABLE_MULTI_DEVICE + +#include +#include + +#define MPICHECK(cmd) TLLM_MPI_CHECK(cmd) + +namespace tensorrt_llm::runtime { +class IBuffer; +} + +// A wrapper module of the MPI library. +namespace tensorrt_llm::mpi { + +// A wrapper of MPI data type. 
MpiType::{data_type} +enum class MpiType { + kBYTE, + kHALF, + kFLOAT, + kDOUBLE, + kBOOL, + kINT8, + kUINT8, + kINT32, + kUINT32, + kINT64, + kUINT64, + kFP8, + kBF16, + kCHAR, +}; + +//! \brief For converting a C++ data type to a TensorRT data type. +template +struct MpiTypeConverter {}; + +template <> +struct MpiTypeConverter { + static constexpr auto value = MpiType::kBYTE; +}; + +template <> +struct MpiTypeConverter + +{ + static constexpr auto value = MpiType::kHALF; +}; + +template <> +struct MpiTypeConverter { + static constexpr auto value = MpiType::kFLOAT; +}; + +template <> +struct MpiTypeConverter { + static constexpr auto value = MpiType::kDOUBLE; +}; + +template <> +struct MpiTypeConverter { + static constexpr auto value = MpiType::kBOOL; +}; + +template <> +struct MpiTypeConverter { + static constexpr auto value = MpiType::kINT8; +}; + +template <> +struct MpiTypeConverter + +{ + static constexpr auto value = MpiType::kUINT8; +}; + +template <> +struct MpiTypeConverter { + static constexpr auto value = MpiType::kINT32; +}; + +template <> +struct MpiTypeConverter { + static constexpr auto value = MpiType::kUINT32; +}; + +template <> +struct MpiTypeConverter { + static constexpr auto value = MpiType::kINT64; +}; + +template <> +struct MpiTypeConverter { + static constexpr auto value = MpiType::kUINT64; +}; + +template <> +struct MpiTypeConverter { + static constexpr auto value = MpiType::kCHAR; +}; + +#ifdef ENABLE_FP8 +template <> +struct MpiTypeConverter<__nv_fp8_e4m3> { + static constexpr auto value = MpiType::kFP8; +}; +#endif + +#ifdef ENABLE_BF16 +template <> +struct MpiTypeConverter<__nv_bfloat16> { + static constexpr auto value = MpiType::kBF16; +}; +#endif + +// A wrapper of MPI_Op type. +enum class MpiOp { + NULLOP, + MAX, + MIN, + SUM, + PROD, + LAND, + BAND, + LOR, + BOR, + LXOR, + BXOR, + MINLOC, + MAXLOC, + REPLACE, +}; + +// A wrapper of the level of MPI thread support +enum class MpiThreadSupport : int { + THREAD_SINGLE = MPI_THREAD_SINGLE, + THREAD_FUNNELED = MPI_THREAD_FUNNELED, + THREAD_SERIALIZED = MPI_THREAD_SERIALIZED, + THREAD_MULTIPLE = MPI_THREAD_MULTIPLE, +}; + +class MpiRequest { + public: + MpiRequest() = default; + + ~MpiRequest() = default; + + void wait() { +#if ENABLE_MULTI_DEVICE + // TODO: Don't ignore return status + TLLM_MPI_CHECK(MPI_Wait(&mRequest, MPI_STATUS_IGNORE)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif + } + + void cancel() { +#if ENABLE_MULTI_DEVICE + TLLM_MPI_CHECK(MPI_Cancel(&mRequest)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif + } + + MPI_Request mRequest{}; +}; + +MPI_Datatype getMpiDtype(MpiType dtype); + +class MpiComm { + public: + explicit MpiComm(MPI_Comm g, bool freeComm); + ~MpiComm() noexcept; + + // no copy + MpiComm(MpiComm const&) = delete; + MpiComm& operator=(MpiComm const&) = delete; + + // move + MpiComm(MpiComm&&) noexcept; + MpiComm& operator=(MpiComm&&) noexcept; + + [[nodiscard]] int getRank() const; + [[nodiscard]] int getSize() const; + + operator MPI_Comm() const // NOLINT(*-explicit-constructor) + { + return mComm; + } + + //! \brief Returns the MPI world communicator. + static MpiComm const& world(); + + //! \brief Corresponds to `world()` by default, but can be overridden per process. + static MpiComm const& session() { return mutableSession(); } + + //! \brief Returns the MPI local communicator. 
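+ //! (i.e. the ranks of the session that run on the same host).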
+ static MpiComm const& localSession() { return mutableLocalSession(); } + + static MpiComm const& setSession(MpiComm comm) { + auto& session = mutableSession(); + session = std::move(comm); + refreshLocalSession(); + return session; + } + + static MpiComm const& setRawSessionByFortran(int64_t fortranHandle); + + [[nodiscard]] MpiComm split(int color, int key) const; + + std::unique_ptr bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const; + + std::unique_ptr bcastAsync(runtime::IBuffer& buf, int root) const; + + void bcast(void* buffer, size_t size, MpiType dtype, int root) const; + + void bcast(runtime::IBuffer& buf, int root) const; + + template + void bcastValue(T& value, int root) const { + if constexpr (std::is_fundamental_v>) { + bcast(&value, 1, MpiTypeConverter>::value, root); + } else { + bcast(&value, sizeof(T), MpiType::kBYTE, root); + } + } + + template + void bcast(std::vector& vec, int root) const { + auto const rank = getRank(); + auto vecSize = (rank == root) ? static_cast(vec.size()) : int64_t(0); + bcast(&vecSize, 1, MpiType::kINT64, root); + vec.resize(vecSize); + if (vec.empty()) { + return; + } + + size_t bcastSize = vec.size() * sizeof(T); + if constexpr (std::is_fundamental_v>) { + bcastSize = vec.size(); + } + + // To prevent overflowing int32_t limit + size_t const maxChunkSize = std::numeric_limits::max(); + for (size_t pos = 0; pos < bcastSize; pos += maxChunkSize) { + auto chunkSize = std::min(bcastSize - pos, maxChunkSize); + auto intChunkSize = static_cast(chunkSize); + if constexpr (std::is_fundamental_v>) { + bcast(vec.data() + pos, intChunkSize, MpiTypeConverter>::value, root); + } else { + bcast(reinterpret_cast(vec.data()) + pos, intChunkSize, MpiType::kBYTE, root); + } + } + } + + std::unique_ptr sendAsync(void const* buffer, std::size_t size, MpiType dtype, + int dest, MpiTag tag) const; + std::unique_ptr sendAsync(runtime::IBuffer const& buf, int dest, MpiTag tag) const; + //! \deprecated This function is discouraged. Use the one with MpiTag enum instead. + void sendRawTag(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const; + void send(void const* buffer, std::size_t size, MpiType dtype, int dest, MpiTag tag) const; + void send(runtime::IBuffer const& buf, int dest, MpiTag tag) const; + + template + void sendValue(T const& value, int dest, MpiTag tag) const { + if constexpr (std::is_fundamental_v>) { + send(&value, 1, MpiTypeConverter>::value, dest, tag); + } else { + send(&value, sizeof(T), MpiType::kBYTE, dest, tag); + } + } + + //! \deprecated This function is discouraged. Use the one with MpiTag enum instead. 
+ MPI_Status recvRawTag(void* buffer, size_t size, MpiType dtype, int source, int tag) const; + MPI_Status recv(void* buffer, size_t size, MpiType dtype, int source, MpiTag tag) const; + MPI_Status recv(runtime::IBuffer& buf, int source, MpiTag tag) const; + + template + MPI_Status recvValue(T& value, int source, MpiTag tag) const { +#if ENABLE_MULTI_DEVICE + if constexpr (std::is_fundamental_v>) { + return recv(&value, 1, MpiTypeConverter>::value, source, tag); + } else { + return recv(&value, sizeof(T), MpiType::kBYTE, source, tag); + } +#else + TLLM_THROW("Multi device support is disabled."); +#endif + } + + void allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const; + void allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const; + + void allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf, + std::vector const& recvcounts, std::vector const& displs, + MpiType recvtype) const; + + void barrier() const; + + //! \deprecated This function is discouraged. Use the one with MpiTag enum instead. + void mprobeRawTag(int source, int tag, MPI_Message* msg, MPI_Status* status) const; + void mprobe(int source, MpiTag tag, MPI_Message* msg, MPI_Status* status) const; + bool improbe(int source, MpiTag tag, MPI_Message* msg, MPI_Status* status) const; + + //! \brief Returns if a message with the specified source and tag is available + bool iprobe(int source, MpiTag tag, MPI_Status* status) const; + + //! \brief Poll every periodMs until a message is available + void recvPoll(int source, MpiTag tag, int periodMs) const; + + bool operator==(MpiComm const& rhs) const { return mComm == rhs.mComm; } + + bool operator!=(MpiComm const& rhs) const { return !(rhs == *this); } + + private: + //! \brief Corresponds to `world()` by default, but can be overridden per process. + static MpiComm& mutableSession(); + + //! \brief Returns the MPI local communicator. + static MpiComm& mutableLocalSession(); + + static void refreshLocalSession(); + + MPI_Comm mComm; + bool mFreeComm; +}; + +std::vector getWorldRanks(MpiComm const& comm); + +int getNumNodes(); + +void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_MULTIPLE, + bool forwardAbortToParent = false); + +class MpiWaitThread { + public: + explicit MpiWaitThread(std::string name, std::function funcWait, + std::function funcSetup = nullptr); + ~MpiWaitThread(); + + void waitStop(); + void notifyStart(); + + private: + void sideThread(); + + void waitStart(); + void notifyStop(); + + std::string mName; + std::function mFuncWait; + std::function mFuncSetup; + std::unique_ptr mThread; + std::mutex mMutex; + std::condition_variable mCondVar; + bool mRunning{true}; + std::atomic mShouldExit{false}; +}; + +} // namespace tensorrt_llm::mpi + +#define COMM_SESSION tensorrt_llm::mpi::MpiComm::session() +#define LOCAL_COMM_SESSION tensorrt_llm::mpi::MpiComm::localSession() diff --git a/csrc/nv_internal/tensorrt_llm/runtime/utils/multiDeviceUtils.h b/csrc/nv_internal/tensorrt_llm/runtime/utils/multiDeviceUtils.h new file mode 100644 index 000000000..6c7cbdd50 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/runtime/utils/multiDeviceUtils.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/stringUtils.h" + +#if ENABLE_MULTI_DEVICE +#include +#include + +#define TLLM_MPI_CHECK(cmd) \ + do { \ + auto e = cmd; \ + TLLM_CHECK_WITH_INFO(e == MPI_SUCCESS, "Failed: MPI error %s:%d '%d'", __FILE__, __LINE__, e); \ + } while (0) + +#define TLLM_NCCL_CHECK(cmd) \ + do { \ + ncclResult_t r = cmd; \ + TLLM_CHECK_WITH_INFO(r == ncclSuccess, "Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \ + ncclGetErrorString(r)); \ + } while (0) +#endif // ENABLE_MULTI_DEVICE diff --git a/csrc/trtllm_allgather.cu b/csrc/trtllm_allgather.cu new file mode 100644 index 000000000..48fca80ac --- /dev/null +++ b/csrc/trtllm_allgather.cu @@ -0,0 +1,153 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include +#include +#include + +#include "tensorrt_llm/common/NvInferRuntime.h" +#include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/runtime/torchUtils.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" +// #include +#include + +#include "pytorch_extension_utils.h" +#if ENABLE_MULTI_DEVICE +#include +#endif // ENABLE_MULTI_DEVICE + +namespace torch_ext { +#if ENABLE_MULTI_DEVICE + +namespace { + +class AllgatherOp { + public: + AllgatherOp(std::set group) : mGroup(std::move(group)) {} + + ~AllgatherOp() = default; + + int initialize() { + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank()); + mNcclComm = getComm(mGroup); + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank()); + return 0; + } + + torch::Tensor run(torch::Tensor input, torch::optional> sizes) { + TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before used"); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + auto type = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type()); + std::vector outputShape = input.sizes().vec(); + if (sizes.has_value()) { + outputShape[0] = + std::accumulate(sizes.value().begin(), sizes.value().end(), 0, std::plus<>{}); + } else { + outputShape[0] *= mGroup.size(); + } + auto output = torch::empty(outputShape, input.options()); + if (sizes.has_value()) { + size_t numel_base = + std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{}); + int64_t split_offset = 0; + ncclGroupStart(); + for (int root = 0; root < static_cast(mGroup.size()); ++root) { + auto split_size = sizes.value()[root]; + NCCLCHECK_THROW(ncclBroadcast( + input.data_ptr(), + output.index({torch::indexing::Slice(split_offset, torch::indexing::None)}) + .mutable_data_ptr(), + numel_base * split_size, (*getDtypeMap())[type], root, *mNcclComm, stream)); + split_offset += split_size; + } + ncclGroupEnd(); + } else { + NCCLCHECK_THROW(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(), + (*getDtypeMap())[type], *mNcclComm, stream)); + } + return output; + } + + std::vector run_list(torch::TensorList input_list, + torch::optional> sizes) { + std::vector output_list; + output_list.reserve(input_list.size()); + ncclGroupStart(); + for (auto const& input : input_list) { + auto output = run(input, sizes); + output_list.push_back(output); + } + ncclGroupEnd(); + return output_list; + } + + private: + std::set mGroup; + std::shared_ptr mNcclComm; +}; + +} // namespace + +#endif // ENABLE_MULTI_DEVICE + +torch::Tensor allgather(torch::Tensor input, torch::optional> sizes, + torch::List group_) { +#if ENABLE_MULTI_DEVICE + std::set group; + for (int64_t rank : group_) { + group.insert(static_cast(rank)); + } + AllgatherOp op(group); + op.initialize(); + auto output = op.run(input, sizes); + return output; +#else + return input; +#endif // ENABLE_MULTI_DEVICE +} + +std::vector allgather_list(torch::TensorList input_list, + torch::optional> sizes, + torch::List group_) { +#if ENABLE_MULTI_DEVICE + std::set group; + for (int64_t rank : group_) { + group.insert(static_cast(rank)); + } + AllgatherOp op(group); + op.initialize(); + auto output_list = op.run_list(input_list, sizes); + return output_list; +#else + return input_list.vec(); +#endif // ENABLE_MULTI_DEVICE +} + +} // namespace torch_ext + +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + m.def("trtllm_allgather(Tensor input, int[]? 
sizes, int[] group) -> Tensor"); + m.def("trtllm_allgather_list(Tensor[] input_list, int[]? sizes, int[] group) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("trtllm_allgather", &torch_ext::allgather); + m.impl("trtllm_allgather_list", &torch_ext::allgather_list); +} diff --git a/csrc/trtllm_reducescatter.cu b/csrc/trtllm_reducescatter.cu new file mode 100644 index 000000000..e2c651c70 --- /dev/null +++ b/csrc/trtllm_reducescatter.cu @@ -0,0 +1,156 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "tensorrt_llm/common/NvInferRuntime.h" +#include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/runtime/torchUtils.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" +// #include +#include "pytorch_extension_utils.h" +#if ENABLE_MULTI_DEVICE +#include +#endif // ENABLE_MULTI_DEVICE + +#include +#include +#include + +namespace torch_ext { +#if ENABLE_MULTI_DEVICE + +namespace { + +class ReducescatterOp { + public: + ReducescatterOp(std::set group) : mGroup(std::move(group)) {} + + ~ReducescatterOp() = default; + + int initialize() { + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank()); + mNcclComm = getComm(mGroup); + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank()); + return 0; + } + + torch::Tensor run(torch::Tensor const& input, torch::optional> sizes) { + TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before used"); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + auto type = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type()); + std::vector outputShape = input.sizes().vec(); + if (sizes.has_value()) { + auto rank = COMM_SESSION.getRank(); + int groupRank = 0; + for (auto const& currentRank : mGroup) { + if (rank == currentRank) break; + ++groupRank; + } + TLLM_CHECK(static_cast(groupRank) < mGroup.size()); + outputShape[0] = sizes.value()[groupRank]; + } else { + outputShape[0] = outputShape[0] / mGroup.size(); + } + auto output = torch::empty(outputShape, input.options()); + if (sizes.has_value()) { + size_t numel_base = + std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{}); + int64_t split_offset = 0; + ncclGroupStart(); + for (int root = 0; root < static_cast(mGroup.size()); ++root) { + auto split_size = sizes.value()[root]; + NCCLCHECK_THROW(ncclReduce( + input.index({torch::indexing::Slice(split_offset, torch::indexing::None)}).data_ptr(), + output.mutable_data_ptr(), numel_base * split_size, (*getDtypeMap())[type], ncclSum, + root, *mNcclComm, stream)); + split_offset += split_size; + } + ncclGroupEnd(); + } else { + NCCLCHECK_THROW(ncclReduceScatter(input.data_ptr(), output.mutable_data_ptr(), output.numel(), + (*getDtypeMap())[type], ncclSum, *mNcclComm, stream)); + } + return output; + 
} + + std::vector run_list(torch::TensorList input_list, + torch::optional> sizes) noexcept { + std::vector output_list; + output_list.reserve(input_list.size()); + ncclGroupStart(); + for (auto const& input : input_list) { + auto output = run(input, sizes); + output_list.push_back(output); + } + ncclGroupEnd(); + return output_list; + } + + private: + std::set mGroup; + std::shared_ptr mNcclComm; +}; + +} // namespace + +#endif // ENABLE_MULTI_DEVICE + +extern torch::Tensor reducescatter(torch::Tensor input, torch::optional> sizes, + torch::List group_) { +#if ENABLE_MULTI_DEVICE + std::set group; + for (int64_t rank : group_) { + group.insert(static_cast(rank)); + } + ReducescatterOp op(group); + op.initialize(); + auto output = op.run(input, sizes); + return output; +#else + return input; +#endif // ENABLE_MULTI_DEVICE +} + +extern std::vector reducescatter_list(torch::TensorList input_list, + torch::optional> sizes, + torch::List group_) { +#if ENABLE_MULTI_DEVICE + std::set group; + for (int64_t rank : group_) { + group.insert(static_cast(rank)); + } + ReducescatterOp op(group); + op.initialize(); + auto output_list = op.run_list(input_list, sizes); + return output_list; +#else + return input_list.vec(); +#endif // ENABLE_MULTI_DEVICE +} + +} // namespace torch_ext + +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + m.def("reducescatter(Tensor input, int[]? sizes, int[] group) -> Tensor"); + m.def("reducescatter_list(Tensor[] input_list, int[]? sizes, int[] group) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("reducescatter", &torch_ext::reducescatter); + m.impl("reducescatter_list", &torch_ext::reducescatter_list); +} diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index b92d32b89..c716e8a2c 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -52,6 +52,7 @@ from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper from .gemm import bmm_fp8 as bmm_fp8 from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper +from .moe_mapping import Mapping as MoE_Mapping from .norm import fused_add_rmsnorm as fused_add_rmsnorm from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm from .norm import gemma_rmsnorm as gemma_rmsnorm diff --git a/flashinfer/comm.py b/flashinfer/comm.py index 3876fd170..737e9c49b 100644 --- a/flashinfer/comm.py +++ b/flashinfer/comm.py @@ -16,17 +16,22 @@ import ctypes import functools +import math +import os from dataclasses import dataclass +from itertools import accumulate from types import SimpleNamespace -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from flashinfer import MoE_Mapping + from .jit import JitSpec from .jit import env as jit_env -from .jit import gen_jit_spec +from .jit import gen_jit_spec, sm100a_nvcc_flags from .utils import register_custom_op # NOTE(Zihao): we should use cuda-python instead of ctypes cuda runtime bindings. 
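A minimal usage sketch of the new collectives, assuming the processes are launched with mpirun so that COMM_SESSION covers all participating ranks, and that the JIT-built extension exposes the registered ops as attributes (as `module.trtllm_allgather` is used later in this file); the `ops` handle, shapes, and rank values below are illustrative, not part of this diff:

import torch
from flashinfer.comm import gen_comm_module  # added to flashinfer/comm.py by this diff

# Hypothetical handle to the compiled extension; build_and_load() follows
# flashinfer's JitSpec convention and is an assumption of this sketch.
ops = gen_comm_module().build_and_load()

group = [0, 1, 2, 3]   # world ranks taking part in the collective
sizes = [2, 3, 1, 4]   # per-rank row counts along dim 0 (may differ per rank)
rank = 1               # this process's position within `group`

# Every rank runs the same code (SPMD); each rank is assumed to have already
# selected its own CUDA device.
local = torch.randn(sizes[rank], 8, device="cuda")

# Variable-size allgather: the output has sum(sizes) == 10 rows. With
# sizes=None, all ranks must contribute identically shaped tensors and the
# output has input.shape[0] * len(group) rows instead.
gathered = ops.trtllm_allgather(local, sizes, group)   # [10, 8]

# Reduce-scatter is the inverse: each rank gets the summed slice matching its
# own entry in `sizes`.
reduced = ops.reducescatter(gathered, sizes, group)    # [3, 8] on rank 1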
@@ -231,13 +236,139 @@ def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: cudart = CudaRTLibrary() +def get_mpi_include_lib_path(): + import pathlib + import shlex + import subprocess + + cmd = ["mpicc", "-show"] + output = subprocess.check_output(cmd, text=True) + # Parse the output to extract include and library paths + parts = shlex.split(output) + include_dirs = [] + lib_dirs = [] + + i = 0 + while i < len(parts): + if parts[i] == "-I" and i + 1 < len(parts): + include_dirs.append(pathlib.Path(parts[i + 1])) + i += 2 + elif parts[i].startswith("-I"): + include_dirs.append(pathlib.Path(parts[i][2:])) + i += 1 + elif parts[i] == "-L" and i + 1 < len(parts): + lib_dirs.append(pathlib.Path(parts[i + 1])) + i += 2 + elif parts[i].startswith("-L"): + lib_dirs.append(pathlib.Path(parts[i][2:])) + i += 1 + else: + i += 1 + + # Return the first include directory found, or None if none found + include_dir = include_dirs[0] if include_dirs else None + + return include_dir, lib_dirs + + +def get_nccl_include_lib_path(): + import pathlib + import subprocess + + # Try to find NCCL paths from environment variables first + nccl_include_path = os.environ.get("NCCL_INCLUDE_PATH") + nccl_lib_path = os.environ.get("NCCL_LIB_PATH") + + if nccl_include_path and nccl_lib_path: + return nccl_include_path, [nccl_lib_path] + + # Find NCCL library using find command + try: + lib_output = subprocess.check_output( + ["find", "/usr", "-name", "libnccl*.so"], text=True + ).strip() + if lib_output: + lib_path = pathlib.Path(lib_output.split("\n")[0]).parent + # Find corresponding include path + include_output = subprocess.check_output( + ["find", "/usr/include", "-name", "nccl.h"], text=True + ).strip() + if include_output: + include_path = pathlib.Path(include_output.split("\n")[0]).parent + return include_path, [lib_path] + except subprocess.CalledProcessError: + pass + + raise RuntimeError( + "Could not find NCCL include or library paths. " + "Please set NCCL_INCLUDE_PATH and NCCL_LIB_PATH environment variables." 
+ ) + + +def get_output_info(input: torch.Tensor, dim: int) -> List[int]: + dim = dim % input.ndim + output_shape = [val if idx != dim else -1 for idx, val in enumerate(input.shape)] + numel_base = -math.prod(output_shape) + return {"output_shape": output_shape, "numel_base": numel_base} + + +def filter_valid_input( + input_list: List[torch.Tensor], +) -> Tuple[List[torch.Tensor], List[bool]]: + func_valid = lambda x: x is not None + valid_list = list(map(func_valid, input_list)) + input_list = list(filter(func_valid, input_list)) + return input_list, valid_list + + +def restore_full_output( + output_list: List[torch.Tensor], valid_list: List[bool] +) -> List[torch.Tensor]: + index_list = list(accumulate(map(int, valid_list))) + output_list = list( + map( + lambda valid, index: output_list[index - 1] if valid else None, + valid_list, + index_list, + ) + ) + return output_list + + def gen_comm_module() -> JitSpec: + mpi_include_path, mpi_lib_path = get_mpi_include_lib_path() + nccl_include_path, nccl_lib_path = get_nccl_include_lib_path() + mpi_lib_path = str(mpi_lib_path[0]) + nccl_lib_path = str(nccl_lib_path[0]) + print(mpi_include_path, mpi_lib_path) + print(nccl_include_path, nccl_lib_path) return gen_jit_spec( "comm", [ jit_env.FLASHINFER_CSRC_DIR / "flashinfer_comm_ops.cu", jit_env.FLASHINFER_CSRC_DIR / "custom_all_reduce.cu", jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce.cu", + jit_env.FLASHINFER_CSRC_DIR / "trtllm_allgather.cu", + jit_env.FLASHINFER_CSRC_DIR / "trtllm_reducescatter.cu", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/opUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/runtime/utils/mpiUtils.cpp", + ], + extra_include_paths=[ + jit_env.FLASHINFER_CSRC_DIR / "nv_internal", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", + mpi_include_path, + nccl_include_path, + ], + extra_ldflags=[f"-L{mpi_lib_path}", "-lmpi", f"-L{nccl_lib_path}", "-lnccl"], + extra_cuda_cflags=sm100a_nvcc_flags + + [ + "-DENABLE_MULTI_DEVICE", + ], + extra_cflags=[ + "-DENABLE_MULTI_DEVICE", ], ) @@ -396,6 +527,155 @@ def trtllm_custom_all_reduce( lamport_peer_comm_buffer_ptrs_2, ) + @register_custom_op( + "flashinfer::all_gather", + mutates_args=[], + ) + def all_gather( + input: Union[torch.Tensor, List[torch.Tensor]], + mapping: MoE_Mapping, + dim: int = -1, + sizes: Optional[List[int]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + if mapping.tp_size == 1: + return input + + if sizes is not None: + assert len(sizes) == len(mapping.tp_group) + if isinstance(input, torch.Tensor): + assert input.shape[dim] == sizes[mapping.tp_rank] + else: + assert all( + [ + val.shape[dim] == sizes[mapping.tp_rank] + for val in input + if val is not None + ] + ) + # 'sizes' is not needed if all inputs in the same TP group have the same shape + for split_size in sizes[1:]: + if split_size != sizes[0]: + break + else: + sizes = None + + # Inputs are reshaped in this way to pass necessary shape information to the allgather op + if isinstance(input, torch.Tensor): + torch_op = module.trtllm_allgather + output_info = get_output_info(input, dim) + input = input.contiguous().view(-1, output_info["numel_base"]) + else: + input, valid = filter_valid_input(input) + torch_op = module.trtllm_allgather_list + output_info = [get_output_info(val, dim) for val in input] 
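all_gather (and reduce_scatter below) flatten every tensor to 2-D before calling into the extension: get_output_info records the original shape with -1 at the collective dim, and numel_base is the product of the remaining dims. A standalone sketch of that bookkeeping for the dim == 0 case (other dims additionally go through the chunk/split-and-cat step in convert_output); names here are local to the sketch:

```python
# Bookkeeping used by the wrappers, shown for dim == 0.
import math

import torch

x = torch.arange(24).reshape(4, 2, 3)  # tensor to gather/scatter along dim 0
dim = 0
output_shape = [-1 if i == dim else s for i, s in enumerate(x.shape)]  # [-1, 2, 3]
numel_base = -math.prod(output_shape)                                  # 2 * 3 = 6
flat = x.contiguous().view(-1, numel_base)                             # (4, 6), 2-D for the collective
restored = flat.view(output_shape)                                     # back to (4, 2, 3)
assert torch.equal(restored, x)
```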
+ input = [ + val.contiguous().view(-1, val_info["numel_base"]) + for val, val_info in zip(input, output_info) + ] + + output = torch_op( + input, + sizes, + mapping.tp_group, + ) + + def convert_output(x, x_info): + if dim == 0: + x = x.view(x_info["output_shape"]) + else: + if sizes is None: + x_list = x.chunk(mapping.tp_size) + else: + x_list = x.split(sizes) + x = torch.cat( + [x.reshape(x_info["output_shape"]) for x in x_list], dim=dim + ) + return x + + if isinstance(input, torch.Tensor): + output = convert_output(output, output_info) + else: + output = [ + convert_output(val, val_info) + for val, val_info in zip(output, output_info) + ] + output = restore_full_output(output, valid) + return output + + @register_custom_op( + "flashinfer::reduces_catter", + mutates_args=[], + ) + def reduce_scatter( + input: Union[torch.Tensor, List[torch.Tensor]], + mapping: MoE_Mapping, + dim: int = -1, + sizes: Optional[List[int]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + if mapping.tp_size == 1: + return input + + if sizes is not None: + assert len(sizes) == len(mapping.tp_group) + sum_split_size = sum(sizes) + if isinstance(input, torch.Tensor): + assert input.shape[dim] == sum_split_size + else: + assert all( + [ + val.shape[dim] == sum_split_size + for val in input + if val is not None + ] + ) + # 'sizes' is not needed if all outputs in the same TP group have the same shape + for split_size in sizes[1:]: + if split_size != sizes[0]: + break + else: + sizes = None + + def convert_input(x, x_info): + # Inputs are reshaped in this way to pass necessary shape information to the reducescatter op + if dim == 0: + x = x.contiguous().view(-1, x_info["numel_base"]) + else: + if sizes is None: + x_list = x.chunk(mapping.tp_size, dim=dim) + else: + x_list = x.split(sizes, dim=dim) + x = torch.cat([x.reshape(-1, x_info["numel_base"]) for x in x_list]) + return x + + if isinstance(input, torch.Tensor): + torch_op = module.reducescatter + output_info = get_output_info(input, dim) + input = convert_input(input, output_info) + else: + input, valid = filter_valid_input(input) + torch_op = module.reducescatter_list + output_info = [get_output_info(val, dim) for val in input] + input = [ + convert_input(val, val_info) + for val, val_info in zip(input, output_info) + ] + + output = torch_op( + input, + sizes, + mapping.tp_group, + ) + + if isinstance(input, torch.Tensor): + output = output.view(output_info["output_shape"]) + else: + output = [ + val.view(val_info["output_shape"]) + for val, val_info in zip(output, output_info) + ] + output = restore_full_output(output, valid) + return output + return SimpleNamespace( init_custom_ar=init_custom_ar, dispose=dispose, @@ -404,6 +684,8 @@ def trtllm_custom_all_reduce( register_graph_buffers=register_graph_buffers, meta_size=meta_size, all_reduce=all_reduce, + all_gather=all_gather, + reduce_scatter=reduce_scatter, trtllm_lamport_initialize=trtllm_lamport_initialize, trtllm_lamport_initialize_all=trtllm_lamport_initialize_all, trtllm_custom_all_reduce=trtllm_custom_all_reduce, @@ -723,3 +1005,31 @@ def trtllm_custom_all_reduce( lamport_peer_comm_buffer_ptrs_1, lamport_peer_comm_buffer_ptrs_2, ) + + +def all_gather( + input: Union[torch.Tensor, List[torch.Tensor]], + mapping: MoE_Mapping, + dim: int = -1, + sizes: Optional[List[int]] = None, +) -> Union[torch.Tensor, List[torch.Tensor]]: + return get_comm_module().all_gather( + input, + mapping, + dim, + sizes, + ) + + +def reduce_scatter( + input: Union[torch.Tensor, List[torch.Tensor]], + mapping: 
MoE_Mapping, + dim: int = -1, + sizes: Optional[List[int]] = None, +) -> Union[torch.Tensor, List[torch.Tensor]]: + return get_comm_module().reduce_scatter( + input, + mapping, + dim, + sizes, + ) diff --git a/flashinfer/moe_mapping.py b/flashinfer/moe_mapping.py new file mode 100644 index 000000000..b41b7ab33 --- /dev/null +++ b/flashinfer/moe_mapping.py @@ -0,0 +1,479 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + + +class Mapping(object): + """ + A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2 + + 2 tp groups: + + - [0, 1, 2, 3] + - [4, 5, 6, 7] + + 4 pp groups: + + - [0, 4] + - [1, 5] + - [2, 6] + - [3, 7] + + A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1 + + 2 tp groups: + + - [0, 1, 2, 3] + - [4, 5, 6, 7] + + 4 cp groups: + + - [0, 4] + - [1, 5] + - [2, 6] + - [3, 7] + + A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4 + + 4 moe_tp groups: + + - [0, 4] + - [1, 5] + - [2, 6] + - [3, 7] + + 2 moe_ep groups: + + - [0, 1, 2, 3] + - [4, 5, 6, 7] + + 2 nodes with 16 GPUs, moe_tp_size = 2, moe_ep_size = 4, pp_size = 2 + + 8 moe_tp groups: + + - [0 4] + - [1 5] + - [2 6] + - [3 7] + - [8 12] + - [9 13] + - [10 14] + - [11 15] + + 4 moe_ep groups: + + - [0, 1, 2, 3] + - [4, 5, 6, 7] + - [8, 9, 10, 11] + - [12, 13, 14, 15] + + 8 pp groups: + + - [0 8] + - [1 9] + - [2 10] + - [3 11] + - [4 12] + - [5 13] + - [6 14] + - [7 15] + + 2 nodes with 8 GPUs, tp_size 2, pp_size 2, cp_size 2 + + 4 tp groups: + - [0, 1] + - [2, 3] + - [4, 5] + - [6, 7] + + 4 pp groups: + - [0, 4] + - [1, 5] + - [2, 6] + - [3, 7] + + 4 cp groups: + - [0, 2] + - [1, 3] + - [4, 6] + - [5, 7] + """ + + def __init__( + self, + world_size=1, + rank=0, + gpus_per_node=8, + *, + cp_size=1, + cp_config=None, + tp_size=1, + pp_size=1, + moe_cluster_size=-1, # -1 means no moe + moe_tp_size=-1, # -1 means no moe + moe_ep_size=-1, # -1 means no moe + attn_tp_size=-1, + attn_cp_size=-1, + auto_parallel=False, + enable_attention_dp=False, + ): + # set default values for non-moe cases + # or where only one MOE parallelism size is specified + if moe_cluster_size == -1: + moe_cluster_size = 1 + + if moe_tp_size == -1 and moe_ep_size == -1: + moe_tp_size = tp_size // moe_cluster_size + moe_ep_size = 1 + + elif moe_tp_size == -1: + moe_tp_size = tp_size // (moe_ep_size * moe_cluster_size) + + elif moe_ep_size == -1: + moe_ep_size = tp_size // (moe_tp_size * moe_cluster_size) + + if attn_tp_size == -1 and attn_cp_size == -1: + # fallback to ulysses + attn_tp_size = tp_size * cp_size + attn_cp_size = 1 + + elif attn_tp_size == -1: + attn_tp_size = cp_size * tp_size // attn_cp_size + + elif attn_cp_size == -1: + attn_cp_size = cp_size * tp_size // attn_tp_size + + if attn_cp_size != 1: + raise ValueError( + f"attn_cp_size must be 1 for now, but got {attn_tp_size}, {attn_cp_size}." 
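The group layouts described in the Mapping docstring can be checked directly; the sketch below mirrors its first and third examples, using the MoE_Mapping alias exported from flashinfer/__init__.py (single process, no GPUs needed):

```python
from flashinfer import MoE_Mapping

# 8 GPUs, tp_size = 4, pp_size = 2 (first docstring example)
m = MoE_Mapping(world_size=8, rank=0, gpus_per_node=8, tp_size=4, pp_size=2)
print(m.tp_groups)            # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(m.pp_groups)            # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(m.tp_group, m.tp_rank)  # [0, 1, 2, 3] 0

# 8 GPUs, moe_tp_size = 2, moe_ep_size = 4 (third docstring example)
m = MoE_Mapping(world_size=8, rank=5, gpus_per_node=8, tp_size=8,
                moe_tp_size=2, moe_ep_size=4)
print(m.moe_tp_groups)        # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(m.moe_ep_groups)        # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(m.moe_tp_rank, m.moe_ep_rank)  # 1 1
```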
+ ) + + if auto_parallel: + if tp_size != 1 or pp_size != 1 or tp_size != 1: + raise ValueError( + f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}." + ) + else: + if tp_size * pp_size * cp_size != world_size: + raise ValueError( + f"world_size must equal to tp_size * pp_size * cp_size, but got {world_size} != {tp_size} * {pp_size} * {cp_size}." + ) + + moe_tp_ep_size = moe_tp_size * moe_ep_size + moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size + if moe_tp_cluster_ep_size != tp_size: + raise ValueError( + f"tp_size must equal to moe_tp_size * moe_ep_size * moe_cluster_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size} * {moe_cluster_size}" + ) + + attn_tp_cp_size = attn_tp_size * attn_cp_size + if attn_tp_cp_size != tp_size * cp_size: + raise ValueError( + f"tp_size * cp_size must equal to attn_tp_size * attn_cp_size, but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}" + ) + + if moe_ep_size != 1 and cp_size > 1: + raise NotImplementedError("CP don't support MoE tp/ep yet") + + self.tp_size = tp_size + self.cp_size = cp_size + self.cp_config = cp_config if cp_config is not None else {} + self.pp_size = pp_size + self.moe_tp_size = moe_tp_size + self.moe_ep_size = moe_ep_size + self.moe_cluster_size = moe_cluster_size + self.attn_tp_size = attn_tp_size + self.attn_cp_size = attn_cp_size + self.auto_parallel = auto_parallel + self.world_size = world_size + self.enable_attention_dp = enable_attention_dp + self.rank = rank + self.gpus_per_node = gpus_per_node + self.pp_groups = [] + self.cp_groups = [] + self.tp_groups = [] + self.moe_cluster_groups = [] + self.moe_tp_groups = [] + self.moe_ep_groups = [] + + if moe_cluster_size > 1: + assert moe_ep_size == 1 + + # init pp group + for i in range(tp_size * cp_size): + ranks = range(i, world_size, tp_size * cp_size) + self.pp_groups.append(list(ranks)) + + # init cp group + for i in range(pp_size): + for j in range(tp_size): + ranks = range( + i * tp_size * cp_size + j, (i + 1) * tp_size * cp_size + j, tp_size + ) + self.cp_groups.append(list(ranks)) + + # init tp group + for i in range(pp_size): + for j in range(cp_size): + ranks = range( + i * tp_size * cp_size + j * tp_size, + i * tp_size * cp_size + (j + 1) * tp_size, + ) + self.tp_groups.append(list(ranks)) + + # init moe tp group + for i in range(pp_size): + for j in range(moe_cluster_size * moe_ep_size): + ranks = range( + i * moe_tp_cluster_ep_size + j, + (i + 1) * moe_tp_cluster_ep_size, + moe_cluster_size * moe_ep_size, + ) + self.moe_tp_groups.append(list(ranks)) + + # init moe cluster group + for i in range(pp_size): + for j in range(moe_tp_size): + ranks = range( + i * moe_tp_cluster_ep_size + j * moe_cluster_size, + i * moe_tp_cluster_ep_size + (j + 1) * moe_cluster_size, + ) + self.moe_cluster_groups.append(list(ranks)) + + # init moe ep group + for i in range(pp_size): + for j in range(moe_tp_size): + for k in range(moe_cluster_size): + ranks = range( + i * moe_tp_cluster_ep_size + + j * moe_cluster_size * moe_ep_size + + k * moe_ep_size, + i * moe_tp_cluster_ep_size + + j * moe_cluster_size * moe_ep_size + + (k + 1) * moe_ep_size, + ) + self.moe_ep_groups.append(list(ranks)) + + def __eq__(self, other): + if not isinstance(other, Mapping): + return NotImplemented + + return ( + self.world_size == other.world_size + and self.rank == other.rank + and self.gpus_per_node == other.gpus_per_node + and self.cp_size == other.cp_size + and self.tp_size == other.tp_size + and self.moe_cluster_size 
== other.moe_cluster_size + and self.pp_size == other.pp_size + and self.moe_tp_size == other.moe_tp_size + and self.moe_ep_size == other.moe_ep_size + and self.attn_tp_size == other.attn_tp_size + and self.attn_cp_size == other.attn_cp_size + and self.auto_parallel == other.auto_parallel + ) + + def __hash__(self): + return hash( + ( + self.world_size, + self.rank, + self.gpus_per_node, + self.cp_size, + self.tp_size, + self.pp_size, + self.moe_tp_size, + self.moe_cluster_size, + self.moe_ep_size, + self.attn_tp_size, + self.attn_cp_size, + self.auto_parallel, + ) + ) + + @property + def rank(self): + return self._rank + + @rank.setter + def rank(self, rank: int): + # TODO(qijun): skip check for enable_attention_dp temporarily, will support attention_dp_size + if not self.enable_attention_dp: + if not isinstance(rank, int) or rank < 0 and rank >= self.world_size: + raise ValueError( + f"Rank should be an integer between 0 and {self.world_size-1}, but got {rank}." + ) + self._rank = rank + + @property + def tp_rank(self): + return 0 if self.auto_parallel else self.rank % self.tp_size + + @property + def pp_rank(self): + return 0 if self.auto_parallel else self.rank // (self.tp_size * self.cp_size) + + @property + def cp_rank(self): + return ( + 0 + if self.auto_parallel + else self.rank % (self.tp_size * self.cp_size) // self.tp_size + ) + + @property + def moe_tp_rank(self): + return self.tp_rank // (self.moe_ep_size * self.moe_cluster_size) + + @property + def moe_cluster_rank(self): + return self.tp_rank % self.moe_cluster_size + + @property + def moe_ep_rank(self): + return self.tp_rank % self.moe_ep_size + + @property + def tp_group(self): + return self.tp_groups[self.pp_rank * self.cp_size + self.cp_rank] + + @property + def pp_group(self): + return self.pp_groups[self.cp_rank * self.tp_size + self.tp_rank] + + @property + def cp_group(self): + return self.cp_groups[self.pp_rank * self.tp_size + self.tp_rank] + + @property + def moe_tp_group(self): + return self.moe_tp_groups[ + self.pp_rank * self.moe_cluster_size * self.moe_ep_size + + self.moe_cluster_rank * self.moe_ep_size + + self.moe_ep_rank + ] + + @property + def moe_cluster_group(self): + return self.moe_cluster_groups[ + self.pp_rank * self.moe_tp_size * self.moe_ep_size + + self.moe_tp_rank * self.moe_ep_size + + self.moe_ep_rank + ] + + @property + def moe_ep_group(self): + return self.moe_ep_groups[ + self.pp_rank * self.moe_cluster_size * self.moe_tp_size + + self.tp_rank * self.moe_cluster_size + + self.moe_cluster_rank + ] + + @property + def node_rank(self): + return self.rank // self.gpus_per_node + + @property + def local_rank(self): + return self.rank % self.gpus_per_node + + def has_cp(self): + return self.cp_size > 1 + + def get_node_rank(self, rank: int): + return rank // self.gpus_per_node + + def get_local_rank(self, rank: int): + return rank % self.gpus_per_node + + def is_multi_node(self): + return self.world_size > self.gpus_per_node + + def has_tp(self): + return self.tp_size > 1 + + def is_last_pp_rank(self): + return self.pp_rank == self.pp_size - 1 + + def is_second_last_pp_rank(self): + return self.pp_rank == self.pp_size - 2 + + def is_first_pp_rank(self): + return self.pp_rank == 0 + + def has_pp(self): + return self.pp_size > 1 + + def prev_pp_rank(self): + p = self.rank - self.tp_size * self.cp_size + if p < 0: + p = p + self.world_size + return p + + def next_pp_rank(self): + p = self.rank + self.tp_size * self.cp_size + if p >= self.world_size: + p = p - self.world_size + return p + + def 
has_moe_cluster(self): + return self.moe_cluster_size > 1 + + def has_moe_tp(self): + return self.moe_tp_size > 1 + + def has_moe_ep(self): + return self.moe_ep_size > 1 + + # TODO: Add support of uneven/arbitrary layer segmentation + def pp_layers(self, num_layers: int) -> List[int]: + layers_per_pipeline_stage = num_layers // self.pp_size + if self.pp_rank == self.pp_size - 1: + layers_range = range(self.pp_rank * layers_per_pipeline_stage, num_layers) + else: + layers_range = range( + self.pp_rank * layers_per_pipeline_stage, + (self.pp_rank + 1) * layers_per_pipeline_stage, + ) + return list(layers_range) + + def ep_experts(self, num_experts: int) -> List[int]: + assert self.cp_size == 1 + experts_per_rank = num_experts // self.moe_ep_size + experts_range = range( + self.moe_ep_rank * experts_per_rank, + (self.moe_ep_rank + 1) * experts_per_rank, + ) + return list(experts_range) + + @classmethod + def from_dict(cls, mapping: dict): + return cls(**mapping) + + def to_dict(self): + return { + "world_size": self.world_size, + "rank": self.rank, + "gpus_per_node": self.gpus_per_node, + "cp_size": self.cp_size, + "tp_size": self.tp_size, + "pp_size": self.pp_size, + "moe_tp_size": self.moe_tp_size, + "moe_cluster_size": self.moe_cluster_size, + "moe_ep_size": self.moe_ep_size, + "attn_tp_size": self.attn_tp_size, + "attn_cp_size": self.attn_cp_size, + "auto_parallel": self.auto_parallel, + } diff --git a/tests/test_trtllm_collectives.py b/tests/test_trtllm_collectives.py new file mode 100644 index 000000000..7cb7be67d --- /dev/null +++ b/tests/test_trtllm_collectives.py @@ -0,0 +1,121 @@ +# import multiprocessing as mp +# import socket +from typing import Any + +import pytest +import torch + +import flashinfer.comm as comm +from flashinfer import MoE_Mapping + +# import torch.distributed as dist + + +RANDOM_SEED = 42 +NUM_REPEATS = 8 # To test input as a list + + +def _run_reduce_scatter_worker(rank, world_size, dtype): + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + mapp = MoE_Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=world_size, + tp_size=world_size, + ) + hidden_dim = 32 + + sizes = [world_size * (i + 1) for i in range(world_size)] + total_size = sum(sizes) + shape = (world_size, total_size, hidden_dim) + + input_tensors = [ + torch.randn(shape, dtype=dtype, device=device) for _ in range(NUM_REPEATS) + ] + expected_output = [i.sum(dim=0) for i in input_tensors] + input_rs = [i[rank,:,:] for i in input_tensors] + + output = comm.reduce_scatter( + input_rs, + mapp, + dim=0, + sizes=sizes, + ) + + for i in range(NUM_REPEATS): + start = sum(sizes[:rank]) + end = start + sizes[rank] + torch.testing.assert_close( + output[i], expected_output[i][start:end,:], atol=1e-2, rtol=3e-2 + ) + + +def _run_allgather_worker(world_size, rank, hidden_dim, dtype): + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + device = torch.device(f"cuda:{rank}") + mapp = MoE_Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=world_size, + tp_size=world_size, + ) + sizes = [world_size * (i + 1) for i in range(world_size)] + total_size = sum(sizes) + shape_ref = (total_size, hidden_dim) + + out_ref = torch.randn(shape_ref, dtype=dtype, device=device) + start = sum(sizes[:rank]) + end = start + sizes[rank] + inp = out_ref[start:end, :] + out = comm.all_gather( + inp, + mapp, + dim=0, + sizes=sizes, + ) + torch.testing.assert_close(out, out_ref, atol=1e-3, rtol=3e-2) + + +@pytest.mark.parametrize("hidden_dim", [64, 128]) 
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_trtllm_all_gather(hidden_dim, dtype): + torch.manual_seed(RANDOM_SEED) + from mpi4py import MPI + + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.size + assert world_size > 0 + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + raise ValueError( + f"world_size {world_size} is greater than available_gpus {available_gpus}" + ) + print(f"Running test for world_size={world_size}") + _run_allgather_worker(world_size, rank, hidden_dim, dtype) + + print(f"all_gather with tp = {world_size}: OK") + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_trtllm_reduce_scatter(dtype): + torch.manual_seed(RANDOM_SEED) + from mpi4py import MPI + + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.size + assert world_size > 0 + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + raise ValueError( + f"world_size {world_size} is greater than available_gpus {available_gpus}" + ) + print(f"Running test for world_size={world_size}") + _run_reduce_scatter_worker(rank, world_size, dtype) + + print(f"reduce_scatter with tp = {world_size}: OK")
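Both tests seed the RNG identically on every rank, so each process regenerates the full input and computes the reference locally without extra communication; they are presumably meant to be launched with one MPI process per GPU (for example via mpirun), though the launch command itself is not part of this change. A CPU-only sketch of the reduce_scatter reference for world_size == 2:

```python
# Why the test can check results without gathering: every "rank" regenerates the
# same `full` tensor from the shared seed, so the reduced reference is local.
import torch

world_size, hidden_dim = 2, 32
sizes = [world_size * (i + 1) for i in range(world_size)]  # [2, 4]
total_size = sum(sizes)

torch.manual_seed(42)                                      # same seed on every "rank"
full = torch.randn(world_size, total_size, hidden_dim)     # identical on all ranks
expected = full.sum(dim=0)                                 # reference, no communication needed

for rank in range(world_size):
    start = sum(sizes[:rank])
    end = start + sizes[rank]
    # comm.reduce_scatter(full[rank], mapping, dim=0, sizes=sizes) on rank `rank`
    # is compared against expected[start:end]
    print(rank, expected[start:end].shape)  # 0 torch.Size([2, 32]) / 1 torch.Size([4, 32])
```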