[UR][CUDA][HIP] Refactor setKernelParams #18518

Merged: 4 commits, May 27, 2025
Changes from 1 commit
[UR][CUDA][HIP] Refactor setKernelParams
- Remove unused Context parameters
- Avoid unnecessary copy in `guessLocalWorkSize`
- Simplify the control flow in setKernelParams
- Move cached properties fetching code to constructors
- Query HIP for occupancy in `guessLocalWorkSize`
npmiller committed May 20, 2025
commit 226e80ee1fcc9040661426088b0cceea1f11d9b9
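
As an editorial aside, here is a minimal sketch of the occupancy-driven local-size guess mentioned in the last bullet above. It shows the CUDA driver-API form that appears in this commit's `enqueue.cpp` hunk; the HIP adapter is described as following the same pattern with the corresponding HIP occupancy query. The helper name, the 1-D restriction, and the divisor-rounding loop are illustrative simplifications, not the adapter's exact code.

```cpp
// Hedged sketch (1-D case): ask the driver for the occupancy-maximising block
// size, capped at the device's maximum block dimension, then round down to a
// divisor of the global size so the grid divides the work evenly.
#include <cuda.h>
#include <cstddef>

void guessThreadsPerBlock1D(CUfunction Func, size_t DynSharedMemBytes,
                            size_t MaxBlockDimX, size_t GlobalSizeX,
                            size_t &ThreadsPerBlockX) {
  int MinGrid = 0, MaxBlockSize = 0;
  // Occupancy query: no dynamic-shared-memory callback, fixed per-block bytes,
  // block size limited by the device's max X dimension. Error checking omitted.
  cuOccupancyMaxPotentialBlockSize(&MinGrid, &MaxBlockSize, Func,
                                   /*blockSizeToDynamicSMemSize=*/nullptr,
                                   DynSharedMemBytes,
                                   static_cast<int>(MaxBlockDimX));

  // Prefer the largest divisor of the global size that does not exceed the
  // occupancy-derived bound.
  size_t Threads = MaxBlockSize > 0 ? static_cast<size_t>(MaxBlockSize) : 1;
  while (Threads > 1 && GlobalSizeX % Threads != 0)
    --Threads;
  ThreadsPerBlockX = Threads;
}
```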
11 changes: 5 additions & 6 deletions unified-runtime/source/adapters/cuda/command_buffer.cpp
@@ -502,9 +502,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
uint32_t LocalSize = hKernel->getLocalSize();
CUfunction CuFunc = hKernel->get();
UR_CHECK_ERROR(setKernelParams(
hCommandBuffer->Context, hCommandBuffer->Device, workDim,
pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, hKernel, CuFunc,
ThreadsPerBlock, BlocksPerGrid));
hCommandBuffer->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
pLocalWorkSize, hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid));

// Set node param structure with the kernel related data
auto &ArgPointers = hKernel->getArgPointers();
@@ -1373,9 +1372,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
size_t BlocksPerGrid[3] = {1u, 1u, 1u};
CUfunction CuFunc = KernelData.Kernel->get();
auto Result = setKernelParams(
hCommandBuffer->Context, hCommandBuffer->Device, KernelData.WorkDim,
KernelData.GlobalWorkOffset, KernelData.GlobalWorkSize, LocalWorkSize,
KernelData.Kernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
hCommandBuffer->Device, KernelData.WorkDim, KernelData.GlobalWorkOffset,
KernelData.GlobalWorkSize, LocalWorkSize, KernelData.Kernel, CuFunc,
ThreadsPerBlock, BlocksPerGrid);
if (Result != UR_RESULT_SUCCESS) {
return Result;
}
4 changes: 4 additions & 0 deletions unified-runtime/source/adapters/cuda/device.hpp
@@ -150,6 +150,10 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
return MaxWorkItemSizes[index];
}

const size_t *getMaxWorkItemSizes() const noexcept {
return MaxWorkItemSizes;
}

size_t getMaxWorkGroupSize() const noexcept { return MaxWorkGroupSize; };

size_t getMaxRegsPerBlock() const noexcept { return MaxRegsPerBlock; };
116 changes: 46 additions & 70 deletions unified-runtime/source/adapters/cuda/enqueue.cpp
@@ -119,18 +119,14 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
GlobalSizeNormalized[i] = GlobalWorkSize[i];
}

size_t MaxBlockDim[3];
MaxBlockDim[0] = Device->getMaxWorkItemSizes(0);
MaxBlockDim[1] = Device->getMaxWorkItemSizes(1);
MaxBlockDim[2] = Device->getMaxWorkItemSizes(2);

int MinGrid, MaxBlockSize;
UR_CHECK_ERROR(cuOccupancyMaxPotentialBlockSize(
&MinGrid, &MaxBlockSize, Kernel->get(), NULL, Kernel->getLocalSize(),
MaxBlockDim[0]));
Device->getMaxWorkItemSizes(0)));

roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized,
MaxBlockDim, MaxBlockSize);
Device->getMaxWorkItemSizes(),
MaxBlockSize);
}

// Helper to verify out-of-registers case (exceeded block max registers).
@@ -145,7 +141,6 @@ bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,

// Helper to compute kernel parameters from workload
// dimensions.
// @param [in] Context handler to the target Context
// @param [in] Device handler to the target Device
// @param [in] WorkDim workload dimension
// @param [in] GlobalWorkOffset pointer workload global offsets
@@ -155,73 +150,56 @@ bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,
// @param [out] ThreadsPerBlock Number of threads per block we should run
// @param [out] BlocksPerGrid Number of blocks per grid we should run
ur_result_t
setKernelParams([[maybe_unused]] const ur_context_handle_t Context,
const ur_device_handle_t Device, const uint32_t WorkDim,
setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
const size_t *GlobalWorkOffset, const size_t *GlobalWorkSize,
const size_t *LocalWorkSize, ur_kernel_handle_t &Kernel,
CUfunction &CuFunc, size_t (&ThreadsPerBlock)[3],
size_t (&BlocksPerGrid)[3]) {
size_t MaxWorkGroupSize = 0u;
bool ProvidedLocalWorkGroupSize = LocalWorkSize != nullptr;

try {
// Set the active context here as guessLocalWorkSize needs an active context
ScopedContext Active(Device);
{
size_t *MaxThreadsPerBlock = Kernel->MaxThreadsPerBlock;
size_t *ReqdThreadsPerBlock = Kernel->ReqdThreadsPerBlock;
MaxWorkGroupSize = Device->getMaxWorkGroupSize();

if (ProvidedLocalWorkGroupSize) {
auto IsValid = [&](int Dim) {
if (ReqdThreadsPerBlock[Dim] != 0 &&
LocalWorkSize[Dim] != ReqdThreadsPerBlock[Dim])
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;

if (MaxThreadsPerBlock[Dim] != 0 &&
LocalWorkSize[Dim] > MaxThreadsPerBlock[Dim])
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;

if (LocalWorkSize[Dim] > Device->getMaxWorkItemSizes(Dim))
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
// Checks that local work sizes are a divisor of the global work sizes
// which includes that the local work sizes are neither larger than
// the global work sizes and not 0.
if (0u == LocalWorkSize[Dim])
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
if (0u != (GlobalWorkSize[Dim] % LocalWorkSize[Dim]))
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
ThreadsPerBlock[Dim] = LocalWorkSize[Dim];
return UR_RESULT_SUCCESS;
};

size_t KernelLocalWorkGroupSize = 1;
for (size_t Dim = 0; Dim < WorkDim; Dim++) {
auto Err = IsValid(Dim);
if (Err != UR_RESULT_SUCCESS)
return Err;
// If no error then compute the total local work size as a product of
// all dims.
KernelLocalWorkGroupSize *= LocalWorkSize[Dim];
}

if (size_t MaxLinearThreadsPerBlock = Kernel->MaxLinearThreadsPerBlock;
MaxLinearThreadsPerBlock &&
MaxLinearThreadsPerBlock < KernelLocalWorkGroupSize) {
if (LocalWorkSize != nullptr) {
size_t KernelLocalWorkGroupSize = 1;
for (size_t i = 0; i < WorkDim; i++) {
if (Kernel->ReqdThreadsPerBlock[i] &&
Kernel->ReqdThreadsPerBlock[i] != LocalWorkSize[i])
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}

if (hasExceededMaxRegistersPerBlock(Device, Kernel,
KernelLocalWorkGroupSize)) {
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
}
} else {
guessLocalWorkSize(Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
Kernel);
if (Kernel->MaxThreadsPerBlock[i] &&
Kernel->MaxThreadsPerBlock[i] < LocalWorkSize[i])
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;

if (LocalWorkSize[i] > Device->getMaxWorkItemSizes(i))
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
// Checks that local work sizes are a divisor of the global work sizes
// which includes that the local work sizes are neither larger than
// the global work sizes and not 0.
if (0u == LocalWorkSize[i] ||
0u != (GlobalWorkSize[i] % LocalWorkSize[i]))
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;

ThreadsPerBlock[i] = LocalWorkSize[i];

// Compute the total local work size as a product of all dims.
KernelLocalWorkGroupSize *= LocalWorkSize[i];
}

if (Kernel->MaxLinearThreadsPerBlock &&
Kernel->MaxLinearThreadsPerBlock < KernelLocalWorkGroupSize) {
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}

if (hasExceededMaxRegistersPerBlock(Device, Kernel,
KernelLocalWorkGroupSize)) {
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
}
} else {
guessLocalWorkSize(Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
Kernel);
}

if (MaxWorkGroupSize <
if (Device->getMaxWorkGroupSize() <
ThreadsPerBlock[0] * ThreadsPerBlock[1] * ThreadsPerBlock[2]) {
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
@@ -407,10 +385,9 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,

// This might return UR_RESULT_ERROR_ADAPTER_SPECIFIC, which cannot be handled
// using the standard UR_CHECK_ERROR
if (ur_result_t Ret =
setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
if (ur_result_t Ret = setKernelParams(
hQueue->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
pLocalWorkSize, hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
Ret != UR_RESULT_SUCCESS)
return Ret;

@@ -595,10 +572,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(

// This might return UR_RESULT_ERROR_ADAPTER_SPECIFIC, which cannot be handled
// using the standard UR_CHECK_ERROR
if (ur_result_t Ret =
setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
if (ur_result_t Ret = setKernelParams(
hQueue->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
pLocalWorkSize, hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
Ret != UR_RESULT_SUCCESS)
return Ret;

3 changes: 1 addition & 2 deletions unified-runtime/source/adapters/cuda/enqueue.hpp
@@ -55,8 +55,7 @@ bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,
size_t BlockSize);

ur_result_t
setKernelParams(const ur_context_handle_t Context,
const ur_device_handle_t Device, const uint32_t WorkDim,
setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
const size_t *GlobalWorkOffset, const size_t *GlobalWorkSize,
const size_t *LocalWorkSize, ur_kernel_handle_t &Kernel,
CUfunction &CuFunc, size_t (&ThreadsPerBlock)[3],
34 changes: 3 additions & 31 deletions unified-runtime/source/adapters/cuda/kernel.cpp
@@ -90,17 +90,7 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
return ReturnValue(size_t(MaxThreads));
}
case UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE: {
size_t GroupSize[3] = {0, 0, 0};
const auto &ReqdWGSizeMDMap =
hKernel->getProgram()->KernelReqdWorkGroupSizeMD;
const auto ReqdWGSizeMD = ReqdWGSizeMDMap.find(hKernel->getName());
if (ReqdWGSizeMD != ReqdWGSizeMDMap.end()) {
const auto ReqdWGSize = ReqdWGSizeMD->second;
GroupSize[0] = std::get<0>(ReqdWGSize);
GroupSize[1] = std::get<1>(ReqdWGSize);
GroupSize[2] = std::get<2>(ReqdWGSize);
}
return ReturnValue(GroupSize, 3);
return ReturnValue(hKernel->ReqdThreadsPerBlock, 3);
}
case UR_KERNEL_GROUP_INFO_LOCAL_MEM_SIZE: {
// OpenCL LOCAL == CUDA SHARED
@@ -124,28 +114,10 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
return ReturnValue(uint64_t(Bytes));
}
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE: {
size_t MaxGroupSize[3] = {0, 0, 0};
const auto &MaxWGSizeMDMap =
hKernel->getProgram()->KernelMaxWorkGroupSizeMD;
const auto MaxWGSizeMD = MaxWGSizeMDMap.find(hKernel->getName());
if (MaxWGSizeMD != MaxWGSizeMDMap.end()) {
const auto MaxWGSize = MaxWGSizeMD->second;
MaxGroupSize[0] = std::get<0>(MaxWGSize);
MaxGroupSize[1] = std::get<1>(MaxWGSize);
MaxGroupSize[2] = std::get<2>(MaxWGSize);
}
return ReturnValue(MaxGroupSize, 3);
return ReturnValue(hKernel->MaxThreadsPerBlock, 3);
}
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE: {
size_t MaxLinearGroupSize = 0;
const auto &MaxLinearWGSizeMDMap =
hKernel->getProgram()->KernelMaxLinearWorkGroupSizeMD;
const auto MaxLinearWGSizeMD =
MaxLinearWGSizeMDMap.find(hKernel->getName());
if (MaxLinearWGSizeMD != MaxLinearWGSizeMDMap.end()) {
MaxLinearGroupSize = MaxLinearWGSizeMD->second;
}
return ReturnValue(MaxLinearGroupSize);
return ReturnValue(hKernel->MaxLinearThreadsPerBlock);
}
default:
break;
56 changes: 37 additions & 19 deletions unified-runtime/source/adapters/cuda/kernel.hpp
@@ -258,25 +258,43 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base {
Context{Context}, Program{Program}, RefCount{1} {
urProgramRetain(Program);
urContextRetain(Context);
/// Note: this code assumes that there is only one device per context
ur_result_t RetError = urKernelGetGroupInfo(
this, Program->getDevice(),
UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE,
sizeof(ReqdThreadsPerBlock), ReqdThreadsPerBlock, nullptr);
(void)RetError;
assert(RetError == UR_RESULT_SUCCESS);
/// Note: this code assumes that there is only one device per context
RetError = urKernelGetGroupInfo(
this, Program->getDevice(),
UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE,
sizeof(MaxThreadsPerBlock), MaxThreadsPerBlock, nullptr);
assert(RetError == UR_RESULT_SUCCESS);
/// Note: this code assumes that there is only one device per context
RetError = urKernelGetGroupInfo(
this, Program->getDevice(),
UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE,
sizeof(MaxLinearThreadsPerBlock), &MaxLinearThreadsPerBlock, nullptr);
assert(RetError == UR_RESULT_SUCCESS);

// Get reqd work group size
const auto &ReqdWGSizeMDMap = Program->KernelReqdWorkGroupSizeMD;
const auto ReqdWGSizeMD = ReqdWGSizeMDMap.find(Name);
if (ReqdWGSizeMD != ReqdWGSizeMDMap.end()) {
const auto ReqdWGSize = ReqdWGSizeMD->second;
ReqdThreadsPerBlock[0] = std::get<0>(ReqdWGSize);
ReqdThreadsPerBlock[1] = std::get<1>(ReqdWGSize);
ReqdThreadsPerBlock[2] = std::get<2>(ReqdWGSize);
} else {
ReqdThreadsPerBlock[0] = 0;
ReqdThreadsPerBlock[1] = 0;
ReqdThreadsPerBlock[2] = 0;
}

// Get max work group size
const auto &MaxWGSizeMDMap = Program->KernelMaxWorkGroupSizeMD;
const auto MaxWGSizeMD = MaxWGSizeMDMap.find(Name);
if (MaxWGSizeMD != MaxWGSizeMDMap.end()) {
const auto MaxWGSize = MaxWGSizeMD->second;
MaxThreadsPerBlock[0] = std::get<0>(MaxWGSize);
MaxThreadsPerBlock[1] = std::get<1>(MaxWGSize);
MaxThreadsPerBlock[2] = std::get<2>(MaxWGSize);
} else {
MaxThreadsPerBlock[0] = 0;
MaxThreadsPerBlock[1] = 0;
MaxThreadsPerBlock[2] = 0;
}

// Get max linear work group size
MaxLinearThreadsPerBlock = 0;
const auto MaxLinearWGSizeMD =
Program->KernelMaxLinearWorkGroupSizeMD.find(Name);
if (MaxLinearWGSizeMD != Program->KernelMaxLinearWorkGroupSizeMD.end()) {
MaxLinearThreadsPerBlock = MaxLinearWGSizeMD->second;
}

UR_CHECK_ERROR(
cuFuncGetAttribute(&RegsPerThread, CU_FUNC_ATTRIBUTE_NUM_REGS, Func));
}
22 changes: 12 additions & 10 deletions unified-runtime/source/adapters/hip/device.hpp
@@ -28,9 +28,7 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
uint32_t DeviceIndex;

int MaxWorkGroupSize{0};
int MaxBlockDimX{0};
int MaxBlockDimY{0};
int MaxBlockDimZ{0};
size_t MaxBlockDim[3];
int MaxCapacityLocalMem{0};
int MaxChosenLocalMem{0};
int ManagedMemSupport{0};
@@ -45,12 +43,18 @@

UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice));

int MaxDim;
UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxBlockDimX, hipDeviceAttributeMaxBlockDimX, HIPDevice));
&MaxDim, hipDeviceAttributeMaxBlockDimX, HIPDevice));
MaxBlockDim[0] = size_t(MaxDim);
UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxBlockDimY, hipDeviceAttributeMaxBlockDimY, HIPDevice));
&MaxDim, hipDeviceAttributeMaxBlockDimY, HIPDevice));
MaxBlockDim[1] = size_t(MaxDim);
UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxBlockDimZ, hipDeviceAttributeMaxBlockDimZ, HIPDevice));
&MaxDim, hipDeviceAttributeMaxBlockDimZ, HIPDevice));
MaxBlockDim[2] = size_t(MaxDim);

UR_CHECK_ERROR(hipDeviceGetAttribute(
&MaxCapacityLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock,
HIPDevice));
@@ -107,11 +111,9 @@

int getMaxWorkGroupSize() const noexcept { return MaxWorkGroupSize; };

int getMaxBlockDimX() const noexcept { return MaxBlockDimX; };

int getMaxBlockDimY() const noexcept { return MaxBlockDimY; };
size_t getMaxBlockDim(int dim) const noexcept { return MaxBlockDim[dim]; };
Contributor: Is noexcept on something that accesses MaxBlockDim[dim] correct? Mightn't this technically throw an exception when out of bounds?

Contributor Author (npmiller): It's a size_t C array, not a vector, so I don't believe this would ever throw an exception. Maybe with extra Windows bounds checking or something? But even then I'd expect that to be asserting instead.

Contributor: Yeah, perhaps. Unfortunately https://en.cppreference.com/w/cpp/container/language/array.html is down so I can't see the C++ behaviour, but another (less trusted) source said that out-of-range behaviour on a C-style array is undefined, so I don't know if we can guarantee it won't throw an exception.
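
For context on the thread above, a small standalone illustration (not part of the PR): indexing a raw C array goes through no exception machinery, since an out-of-range index is undefined behaviour rather than a thrown exception, while std::array::at() is the checked alternative that does throw. The function names below are hypothetical.

```cpp
// Illustrative only: unchecked C-array indexing vs. checked std::array access.
#include <array>
#include <cstddef>

// Mirrors the accessor under discussion: no bounds check, no exceptions.
// Dim outside [0, 2] is undefined behaviour, not a throw, so noexcept cannot
// turn an out-of-range access into std::terminate via an exception path.
size_t readUnchecked(const size_t (&Dims)[3], int Dim) noexcept {
  return Dims[Dim];
}

// Checked counterpart: std::array::at() throws std::out_of_range for Dim >= 3,
// so it must not be marked noexcept.
size_t readChecked(const std::array<size_t, 3> &Dims, size_t Dim) {
  return Dims.at(Dim);
}
```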


int getMaxBlockDimZ() const noexcept { return MaxBlockDimZ; };
const size_t *getMaxBlockDim() const noexcept { return MaxBlockDim; };

int getMaxCapacityLocalMem() const noexcept { return MaxCapacityLocalMem; };
