66#include " config.h"
77
88#include < errno.h>
9+ #include < dlfcn.h>
10+ #include < memory>
911#include < cudaTypedefs.h>
1012#include < cuda_runtime_api.h>
1113
1416#include " nccl_ofi_log.h"
1517#include " nccl_ofi_param.h"
1618
19+ /* CUDA Runtime function pointers - only for functions without driver equivalents */
20+ static cudaError_t (*pfn_cudaRuntimeGetVersion)(int *runtimeVersion) = NULL;
21+
22+ /* Both entry point functions for cross-version compatibility */
23+ static cudaError_t (*pfn_cudaGetDriverEntryPointByVersion)(const char *symbol, void **funcPtr, unsigned int cudaVersion, unsigned long long flags, enum cudaDriverEntryPointQueryResult *driverStatus) = NULL;
24+ static cudaError_t (*pfn_cudaGetDriverEntryPoint)(const char *symbol, void **funcPtr, unsigned long long flags, enum cudaDriverEntryPointQueryResult *driverStatus) = NULL;
25+
#if ENABLE_CUDART_DYNAMIC

/*
 * Deleter for handles obtained from dlopen(): passes a non-null handle to
 * dlclose() and does nothing for a null one.
 */
struct DlcloseDeleter {
	void operator()(void *handle) const
	{
		if (handle == nullptr) {
			return;
		}
		dlclose(handle);
	}
};

/*
 * File-scope owner of the CUDA runtime library handle.  Because it is a
 * unique_ptr with DlcloseDeleter, dlclose() is invoked automatically when
 * this object is destroyed (intended to happen when the plugin is unloaded).
 */
static std::unique_ptr<void, DlcloseDeleter> cudaruntime_lib;

#endif
39+
/*
 * Declare a file-local function pointer named pfn_<function>, typed as
 * PFN_<function>_v<version> and initialised to NULL.  The caller supplies
 * the terminating semicolon.
 */
#define DECLARE_CUDA_FUNCTION(function, version) \
	static PFN_##function##_v##version pfn_##function = NULL
1841
/*
 * Resolve the CUDA driver function `function` into pfn_<function>.
 *
 * Simple function resolution with fallback for cross-version compatibility:
 *   1. If pfn_cudaGetDriverEntryPointByVersion is non-NULL, ask it for the
 *      symbol at the requested `version`.
 *   2. If that did not yield a usable pointer and pfn_cudaGetDriverEntryPoint
 *      is non-NULL, retry through the legacy (unversioned) lookup.
 * A lookup counts as successful only when it returns cudaSuccess AND stores
 * a non-NULL pointer.
 *
 * NOTE: when neither lookup succeeds this macro logs a warning and executes
 * `return -ENOTSUP;` in the *enclosing* function, so it may only be used
 * inside functions that return int.  The warning reports the status of the
 * last lookup attempted (or the initial placeholder values when no lookup
 * function was available).
 *
 * The expansion keeps its historical trailing ';' after `while (0)`, so
 * existing call sites are unaffected; as a consequence it must not be used
 * as the body of an unbraced if that has an else branch.
 */
#define RESOLVE_CUDA_FUNCTION(function, version) do { \
	enum cudaDriverEntryPointQueryResult result = cudaDriverEntryPointSymbolNotFound; \
	cudaError_t err = cudaErrorUnknown; \
	bool resolved = false; \
	/* Try versioned entry point first (CUDA 13+ preferred) */ \
	if (pfn_cudaGetDriverEntryPointByVersion != NULL) { \
		err = pfn_cudaGetDriverEntryPointByVersion(#function, (void **)&pfn_##function, version, cudaEnableDefault, &result); \
		if (err == cudaSuccess && pfn_##function != NULL) { \
			resolved = true; \
		} \
	} \
	/* Fallback to legacy entry point for CUDA 12 compatibility */ \
	if (!resolved && pfn_cudaGetDriverEntryPoint != NULL) { \
		err = pfn_cudaGetDriverEntryPoint(#function, (void **)&pfn_##function, cudaEnableDefault, &result); \
		if (err == cudaSuccess && pfn_##function != NULL) { \
			resolved = true; \
		} \
	} \
	if (!resolved) { \
		/* Enums are cast explicitly so the arguments match the %d conversions. */ \
		NCCL_OFI_WARN("Failed to resolve CUDA function %s (last error: %d, result: %d)", #function, (int)err, (int)result); \
		return -ENOTSUP; \
	} \
} while (0);
6066
/*
 * Look up the CUDA runtime symbol `sym` in the dlopen() handle `handle` and
 * store the result in pfn_<sym> (cast to that pointer's declared type).
 *
 * NOTE: if dlsym() returns NULL this macro logs a warning and executes
 * `return -ENOTSUP;` in the *enclosing* function, so it may only be used
 * inside functions that return int.
 *
 * The body is wrapped in do { } while (0) so the expansion is a single
 * statement and stays correct inside unbraced if/else; the caller supplies
 * the terminating semicolon.
 */
#define LOAD_CUDA_RUNTIME_SYM(handle, sym) do { \
	pfn_##sym = (decltype(pfn_##sym))dlsym((handle), #sym); \
	if (pfn_##sym == NULL) { \
		NCCL_OFI_WARN("Failed to load CUDA runtime symbol %s", #sym); \
		return -ENOTSUP; \
	} \
} while (0)
73+
74+ /* Use driver APIs wherever possible - they are version-stable */
6175DECLARE_CUDA_FUNCTION (cuDriverGetVersion, 2020 );
6276DECLARE_CUDA_FUNCTION (cuCtxGetDevice, 2000 );
6377DECLARE_CUDA_FUNCTION (cuDeviceGetAttribute, 2000 );
@@ -77,13 +91,58 @@ int nccl_net_ofi_cuda_init(void)
7791{
7892 int driverVersion = -1 ;
7993 int runtimeVersion = -1 ;
94+ cudaError_t res;
95+ CUresult cu_ret;
96+
97+ #if ENABLE_CUDART_DYNAMIC
98+ /* Dynamic loading for binaries when static library support disabled */
99+ /* Load library only once and keep it loaded for program lifetime */
100+ if (cudaruntime_lib == nullptr ) {
101+ (void ) dlerror (); /* Clear any previous errors */
102+ cudaruntime_lib = std::unique_ptr<void , DlcloseDeleter>(dlopen (" libcudart.so" , RTLD_NOW));
103+ if (!cudaruntime_lib) {
104+ NCCL_OFI_WARN (" Failed to find CUDA Runtime library: %s" , dlerror ());
105+ return -ENOTSUP;
106+ }
107+ }
80108
81- cudaError_t res = cudaRuntimeGetVersion (&runtimeVersion);
109+ LOAD_CUDA_RUNTIME_SYM (cudaruntime_lib.get (), cudaRuntimeGetVersion);
110+
111+ /* Get runtime version first to determine which entry point functions to load */
112+ res = pfn_cudaRuntimeGetVersion (&runtimeVersion);
82113 if (res != cudaSuccess) {
83114 NCCL_OFI_WARN (" Failed to query CUDA runtime version." );
84115 return -EINVAL;
85116 }
86117
118+ if (runtimeVersion >= 13000 ) {
119+ LOAD_CUDA_RUNTIME_SYM (cudaruntime_lib.get (), cudaGetDriverEntryPointByVersion);
120+ } else {
121+ LOAD_CUDA_RUNTIME_SYM (cudaruntime_lib.get (), cudaGetDriverEntryPoint);
122+ }
123+
124+ if (pfn_cudaGetDriverEntryPointByVersion == NULL && pfn_cudaGetDriverEntryPoint == NULL ) {
125+ NCCL_OFI_WARN (" No CUDA driver entry point functions available in runtime" );
126+ return -ENOTSUP;
127+ }
128+ #else
129+ /* Static CUDA runtime - use direct function calls */
130+ pfn_cudaRuntimeGetVersion = cudaRuntimeGetVersion;
131+
132+ /* Get runtime version first to determine which entry point functions to use */
133+ res = cudaRuntimeGetVersion (&runtimeVersion);
134+ if (res != cudaSuccess) {
135+ NCCL_OFI_WARN (" Failed to query CUDA runtime version." );
136+ return -EINVAL;
137+ }
138+
139+ #if CUDART_VERSION >= 13000
140+ pfn_cudaGetDriverEntryPointByVersion = cudaGetDriverEntryPointByVersion;
141+ #else
142+ pfn_cudaGetDriverEntryPoint = cudaGetDriverEntryPoint;
143+ #endif
144+ #endif
145+
87146 RESOLVE_CUDA_FUNCTION (cuDriverGetVersion, 2020 );
88147 RESOLVE_CUDA_FUNCTION (cuCtxGetDevice, 2000 );
89148 RESOLVE_CUDA_FUNCTION (cuDeviceGetAttribute, 2000 );
@@ -99,16 +158,16 @@ int nccl_net_ofi_cuda_init(void)
99158 RESOLVE_CUDA_FUNCTION (cuMemFree, 3020 );
100159 RESOLVE_CUDA_FUNCTION (cuMemcpy, 4000 );
101160
102- CUresult cu_ret = pfn_cuDriverGetVersion (&driverVersion);
161+ cu_ret = pfn_cuDriverGetVersion (&driverVersion);
103162 if (cu_ret != CUDA_SUCCESS) {
104163 NCCL_OFI_WARN (" Failed to query CUDA driver version." );
105164 return -EINVAL;
106165 }
107166
108167 NCCL_OFI_INFO (NCCL_INIT | NCCL_NET,
109- " Using CUDA driver version %d with runtime %d" ,
110- driverVersion,
111- runtimeVersion);
168+ " Using CUDA driver version %d with runtime %d" ,
169+ driverVersion,
170+ runtimeVersion);
112171
113172 if (HAVE_CUDA_GDRFLUSH_SUPPORT && nccl_net_ofi_cuda_have_gdr_support_attr () && ofi_nccl_cuda_flush_enable ()) {
114173 NCCL_OFI_WARN (" CUDA flush enabled" );
@@ -137,7 +196,6 @@ int nccl_net_ofi_cuda_flush_gpudirect_rdma_writes(void)
137196#endif
138197}
139198
140-
141199int nccl_net_ofi_cuda_mem_alloc (void **ptr, size_t size)
142200{
143201 CUdeviceptr d_ptr;
0 commit comments