Skip to content

Commit aa0b78f

Browse files
committed
NCCL find path update
1 parent 75ab94b commit aa0b78f

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

flashinfer/comm.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
limitations under the License.
1515
"""
1616

17+
import os
1718
import ctypes
1819
import functools
1920
import math
@@ -270,6 +271,40 @@ def get_mpi_include_lib_path():
270271
return include_dir, lib_dirs
271272

272273

274+
def get_nccl_include_lib_path():
275+
import pathlib
276+
import subprocess
277+
278+
# Try to find NCCL paths from environment variables first
279+
nccl_include_path = os.environ.get("NCCL_INCLUDE_PATH")
280+
nccl_lib_path = os.environ.get("NCCL_LIB_PATH")
281+
282+
if nccl_include_path and nccl_lib_path:
283+
return nccl_include_path, [nccl_lib_path]
284+
285+
# Find NCCL library using find command
286+
try:
287+
lib_output = subprocess.check_output(
288+
["find", "/usr", "-name", "libnccl*.so"], text=True
289+
).strip()
290+
if lib_output:
291+
lib_path = pathlib.Path(lib_output.split("\n")[0]).parent
292+
# Find corresponding include path
293+
include_output = subprocess.check_output(
294+
["find", "/usr/include", "-name", "nccl.h"], text=True
295+
).strip()
296+
if include_output:
297+
include_path = pathlib.Path(include_output.split("\n")[0]).parent
298+
return include_path, [lib_path]
299+
except subprocess.CalledProcessError:
300+
pass
301+
302+
raise RuntimeError(
303+
"Could not find NCCL include or library paths. "
304+
"Please set NCCL_INCLUDE_PATH and NCCL_LIB_PATH environment variables."
305+
)
306+
307+
273308
def get_output_info(input: torch.Tensor, dim: int) -> List[int]:
274309
dim = dim % input.ndim
275310
output_shape = [val if idx != dim else -1 for idx, val in enumerate(input.shape)]
@@ -302,8 +337,11 @@ def restore_full_output(
302337

303338
def gen_comm_module() -> JitSpec:
304339
mpi_include_path, mpi_lib_path = get_mpi_include_lib_path()
340+
nccl_include_path, nccl_lib_path = get_nccl_include_lib_path()
305341
mpi_lib_path = str(mpi_lib_path[0])
342+
nccl_lib_path = str(nccl_lib_path[0])
306343
print(mpi_include_path, mpi_lib_path)
344+
print(nccl_include_path, nccl_lib_path)
307345
return gen_jit_spec(
308346
"comm",
309347
[
@@ -322,8 +360,9 @@ def gen_comm_module() -> JitSpec:
322360
jit_env.FLASHINFER_CSRC_DIR / "nv_internal",
323361
jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include",
324362
mpi_include_path,
363+
nccl_include_path,
325364
],
326-
extra_ldflags=["-lnccl", f"-L{mpi_lib_path}", "-lmpi"],
365+
extra_ldflags=[f"-L{mpi_lib_path}", "-lmpi", f"-L{nccl_lib_path}", "-lnccl"],
327366
extra_cuda_cflags=sm100a_nvcc_flags
328367
+ [
329368
"-DENABLE_MULTI_DEVICE",

0 commit comments

Comments
 (0)