14
14
limitations under the License.
15
15
"""
16
16
17
+ import os
17
18
import ctypes
18
19
import functools
19
20
import math
@@ -270,6 +271,40 @@ def get_mpi_include_lib_path():
270
271
return include_dir , lib_dirs
271
272
272
273
274
def _find_parent_dir(search_root, name_pattern):
    """Return the directory holding the first ``find`` match, or ``None``.

    Runs ``find <search_root> -name <name_pattern>`` and returns the parent
    directory of the first path it prints. Returns ``None`` when nothing
    matches or when the ``find`` executable cannot be launched.
    """
    import pathlib
    import subprocess

    try:
        # check=False: `find` exits non-zero when it merely hits an unreadable
        # directory, yet still prints every match it did locate. Treating that
        # as fatal would throw away a perfectly good result.
        result = subprocess.run(
            ["find", search_root, "-name", name_pattern],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,  # keep permission noise off the console
            text=True,
            check=False,
        )
    except OSError:
        # `find` is not installed / not executable on this platform.
        return None

    matches = [line for line in result.stdout.splitlines() if line.strip()]
    if not matches:
        return None
    return pathlib.Path(matches[0]).parent


def get_nccl_include_lib_path():
    """Locate the NCCL header directory and the NCCL library directory.

    Resolution order, applied independently to each of the two paths:

    1. The ``NCCL_INCLUDE_PATH`` / ``NCCL_LIB_PATH`` environment variables
       (an empty value counts as unset).
    2. A filesystem search: ``libnccl*.so`` under ``/usr`` for the library
       and ``nccl.h`` under ``/usr/include`` for the header.

    Returns:
        A tuple ``(include_dir, [lib_dir])``. Values taken from the
        environment are returned as the original strings; values discovered
        by searching are ``pathlib.Path`` objects.

    Raises:
        RuntimeError: If either path could not be determined.
    """
    include_dir = os.environ.get("NCCL_INCLUDE_PATH") or None
    lib_dir = os.environ.get("NCCL_LIB_PATH") or None

    if lib_dir is None:
        lib_dir = _find_parent_dir("/usr", "libnccl*.so")
    # Only look for the header once a library is known; without a library the
    # result is unusable anyway and the search would be wasted work.
    if lib_dir is not None and include_dir is None:
        include_dir = _find_parent_dir("/usr/include", "nccl.h")

    if include_dir is None or lib_dir is None:
        raise RuntimeError(
            "Could not find NCCL include or library paths. "
            "Please set NCCL_INCLUDE_PATH and NCCL_LIB_PATH environment variables."
        )
    return include_dir, [lib_dir]
306
+
307
+
273
308
def get_output_info (input : torch .Tensor , dim : int ) -> List [int ]:
274
309
dim = dim % input .ndim
275
310
output_shape = [val if idx != dim else - 1 for idx , val in enumerate (input .shape )]
@@ -302,8 +337,11 @@ def restore_full_output(
302
337
303
338
def gen_comm_module () -> JitSpec :
304
339
mpi_include_path , mpi_lib_path = get_mpi_include_lib_path ()
340
+ nccl_include_path , nccl_lib_path = get_nccl_include_lib_path ()
305
341
mpi_lib_path = str (mpi_lib_path [0 ])
342
+ nccl_lib_path = str (nccl_lib_path [0 ])
306
343
print (mpi_include_path , mpi_lib_path )
344
+ print (nccl_include_path , nccl_lib_path )
307
345
return gen_jit_spec (
308
346
"comm" ,
309
347
[
@@ -322,8 +360,9 @@ def gen_comm_module() -> JitSpec:
322
360
jit_env .FLASHINFER_CSRC_DIR / "nv_internal" ,
323
361
jit_env .FLASHINFER_CSRC_DIR / "nv_internal" / "include" ,
324
362
mpi_include_path ,
363
+ nccl_include_path ,
325
364
],
326
- extra_ldflags = ["-lnccl " , f"-L{ mpi_lib_path } " , "-lmpi " ],
365
+ extra_ldflags = [f"-L { mpi_lib_path } " , "-lmpi" , f"-L{ nccl_lib_path } " , "-lnccl " ],
327
366
extra_cuda_cflags = sm100a_nvcc_flags
328
367
+ [
329
368
"-DENABLE_MULTI_DEVICE" ,
0 commit comments