Skip to content

Port AllGather/ReduceScatter from TensorRT-LLM #1145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions csrc/nv_internal/cpp/common/opUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights
* reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/common/opUtils.h"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include <functional>
#include <mutex>
#include <thread>

#include "cuda.h"
#include "tensorrt_llm/runtime/utils/mpiTags.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#if ENABLE_MULTI_DEVICE

std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap() {
static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap = {
{nvinfer1::DataType::kFLOAT, ncclFloat32}, {nvinfer1::DataType::kHALF, ncclFloat16},
{nvinfer1::DataType::kBF16, ncclBfloat16}, {nvinfer1::DataType::kFP8, ncclInt8},
{nvinfer1::DataType::kBOOL, ncclInt8}, {nvinfer1::DataType::kINT32, ncclInt32},
{nvinfer1::DataType::kINT64, ncclInt64}, {nvinfer1::DataType::kUINT8, ncclUint8},
{nvinfer1::DataType::kINT8, ncclInt8},
};
return &dtypeMap;
}

namespace {

// Get NCCL unique ID for a group of ranks.
ncclUniqueId getUniqueId(std::set<int> const& group) {
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
ncclUniqueId id;
if (rank == *group.begin()) {
NCCLCHECK_THROW(ncclGetUniqueId(&id));
for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it) {
COMM_SESSION.sendValue(id, *it, tensorrt_llm::mpi::MpiTag::kDefault);
}
} else {
COMM_SESSION.recvValue(id, *group.begin(), tensorrt_llm::mpi::MpiTag::kDefault);
}
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return id;
}
} // namespace

std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group) {
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
std::ostringstream oss;
int index = 0;
for (auto const& rank : group) {
if (index != 0) {
oss << ",";
}
oss << rank;
index++;
}
auto groupStr = oss.str();
auto it = commMap.find(group);
if (it != commMap.end()) {
auto ncclComm = it->second;
TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
return ncclComm;
}

TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
ncclUniqueId id = getUniqueId(group);
int groupRank = 0;
for (auto const& currentRank : group) {
if (rank == currentRank) break;
++groupRank;
}
TLLM_CHECK(static_cast<size_t>(groupRank) < group.size());
std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t, [](ncclComm_t* comm) {
ncclCommDestroy(*comm);
delete comm;
});
// Need static connection initialization for accurate KV cache size estimation
#if defined(_WIN32)
if (getenv("NCCL_RUNTIME_CONNECT") == nullptr) _putenv_s("NCCL_RUNTIME_CONNECT", "0");
#else
setenv("NCCL_RUNTIME_CONNECT", "0", 0);
#endif // _WIN32
NCCLCHECK_THROW(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
commMap[group] = ncclComm;
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE
Loading