Skip to content

Commit 0ecdc99

Browse files
committed
cuda: Update CUDA APIs to work with CUDA 13.0
Replaced cudaGetDriverEntryPoint, which is deprecated in CUDA 13.0, with cudaGetDriverEntryPointByVersion. For all other CUDA APIs, added the version to both the declaration and the resolution so that the correct API versions are called. Signed-off-by: Sunita Bhaskaran <[email protected]>
1 parent af21e6c commit 0ecdc99

File tree

1 file changed

+40
-18
lines changed

1 file changed

+40
-18
lines changed

src/nccl_ofi_cuda.cpp

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,30 @@
1414
#include "nccl_ofi_log.h"
1515
#include "nccl_ofi_param.h"
1616

17-
#define DECLARE_CUDA_FUNCTION(function) static PFN_##function pfn_##function = NULL
18-
#define RESOLVE_CUDA_FUNCTION(function) \
19-
do { \
17+
#define DECLARE_CUDA_FUNCTION(function, version) static PFN_##function##_v##version pfn_##function = NULL
18+
19+
#if CUDART_VERSION >= 13000
20+
#define RESOLVE_CUDA_FUNCTION(function, version) do { \
21+
enum cudaDriverEntryPointQueryResult result; \
22+
cudaError_t err = \
23+
cudaGetDriverEntryPointByVersion(#function, (void **)&pfn_##function, version, cudaEnableDefault, &result); \
24+
if (err != cudaSuccess) { \
25+
switch (result) { \
26+
case cudaDriverEntryPointSymbolNotFound: \
27+
NCCL_OFI_WARN("Failed to resolve CUDA function %s", #function); \
28+
break; \
29+
case cudaDriverEntryPointVersionNotSufficent: \
30+
NCCL_OFI_WARN("Insufficient driver to use CUDA function %s", #function); \
31+
break; \
32+
case cudaDriverEntryPointSuccess: \
33+
default: \
34+
NCCL_OFI_WARN("Unexpected cudaDriverEntryPointQueryResutlt value %d", (int)result); \
35+
break; \
36+
} \
37+
} \
38+
} while (0);
39+
#else
40+
#define RESOLVE_CUDA_FUNCTION(function, version) do { \
2041
enum cudaDriverEntryPointQueryResult result; \
2142
cudaError_t err = \
2243
cudaGetDriverEntryPoint(#function, (void **)&pfn_##function, cudaEnableDefault, &result); \
@@ -35,14 +56,15 @@
3556
} \
3657
} \
3758
} while (0);
59+
#endif
3860

39-
DECLARE_CUDA_FUNCTION(cuCtxGetDevice);
40-
DECLARE_CUDA_FUNCTION(cuDeviceGetAttribute);
41-
DECLARE_CUDA_FUNCTION(cuMemGetHandleForAddressRange);
42-
DECLARE_CUDA_FUNCTION(cuMemGetAddressRange);
43-
DECLARE_CUDA_FUNCTION(cuMemAlloc);
44-
DECLARE_CUDA_FUNCTION(cuMemFree);
45-
DECLARE_CUDA_FUNCTION(cuMemcpy);
61+
DECLARE_CUDA_FUNCTION(cuCtxGetDevice, 2000);
62+
DECLARE_CUDA_FUNCTION(cuDeviceGetAttribute, 2000);
63+
DECLARE_CUDA_FUNCTION(cuMemGetHandleForAddressRange, 11070);
64+
DECLARE_CUDA_FUNCTION(cuMemGetAddressRange, 3020);
65+
DECLARE_CUDA_FUNCTION(cuMemAlloc, 3020);
66+
DECLARE_CUDA_FUNCTION(cuMemFree, 3020);
67+
DECLARE_CUDA_FUNCTION(cuMemcpy, 4000);
4668

4769
int nccl_net_ofi_cuda_init(void)
4870
{
@@ -70,13 +92,13 @@ int nccl_net_ofi_cuda_init(void)
7092
driverVersion,
7193
runtimeVersion);
7294

73-
RESOLVE_CUDA_FUNCTION(cuCtxGetDevice);
74-
RESOLVE_CUDA_FUNCTION(cuDeviceGetAttribute);
75-
RESOLVE_CUDA_FUNCTION(cuMemGetHandleForAddressRange);
76-
RESOLVE_CUDA_FUNCTION(cuMemGetAddressRange);
77-
RESOLVE_CUDA_FUNCTION(cuMemAlloc);
78-
RESOLVE_CUDA_FUNCTION(cuMemFree);
79-
RESOLVE_CUDA_FUNCTION(cuMemcpy);
95+
RESOLVE_CUDA_FUNCTION(cuCtxGetDevice, 2000);
96+
RESOLVE_CUDA_FUNCTION(cuDeviceGetAttribute, 2000);
97+
RESOLVE_CUDA_FUNCTION(cuMemGetHandleForAddressRange, 11070);
98+
RESOLVE_CUDA_FUNCTION(cuMemGetAddressRange, 3020);
99+
RESOLVE_CUDA_FUNCTION(cuMemAlloc, 3020);
100+
RESOLVE_CUDA_FUNCTION(cuMemFree, 3020);
101+
RESOLVE_CUDA_FUNCTION(cuMemcpy, 4000);
80102

81103
if (HAVE_CUDA_GDRFLUSH_SUPPORT && nccl_net_ofi_cuda_have_gdr_support_attr() && ofi_nccl_cuda_flush_enable()) {
82104
NCCL_OFI_WARN("CUDA flush enabled");
@@ -129,7 +151,7 @@ int nccl_net_ofi_cuda_mem_alloc(void **ptr, size_t size)
129151

130152
int nccl_net_ofi_cuda_mem_free(void *ptr)
131153
{
132-
CUresult ret = pfn_cuMemFree((CUdeviceptr)ptr);
154+
CUresult ret = pfn_cuMemFree((CUdeviceptr)ptr);
133155
return ret == CUDA_SUCCESS ? 0 : -EINVAL;
134156
}
135157

0 commit comments

Comments
 (0)