Commit d070c0b

t-vi authored and facebook-github-bot committed
ROCm: enable cpp_extensions.load/load_inline (pytorch#35897)
Summary: This enables cpp_extensions.load/load_inline. It works by hipify-ing the CUDA sources. The corresponding tests are also enabled. CuDNN/MIOpen extensions aren't yet supported; I propose to not do this in this PR.

Pull Request resolved: pytorch#35897
Differential Revision: D20983279
Pulled By: ezyang
fbshipit-source-id: a5d0f5ac592d04488a6a46522c58e2ee0a6fd57c
1 parent ce54f0d commit d070c0b

4 files changed, +112 −28 lines changed

test/run_test.py

Lines changed: 0 additions & 1 deletion

@@ -121,7 +121,6 @@
     'distributed/rpc/test_dist_optimizer_spawn',
     'distributed/rpc/test_rpc_spawn',
     'test_cpp_extensions_aot_ninja',
-    'test_cpp_extensions_jit',
     'test_determination',
     'test_multiprocessing',
     'test_jit_simple',

test/test_cpp_extensions_jit.py

Lines changed: 5 additions & 2 deletions

@@ -12,12 +12,13 @@
 import torch
 import torch.backends.cudnn
 import torch.utils.cpp_extension
-from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


 TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
 TEST_CUDNN = False
-if TEST_CUDA:
+TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None and ROCM_HOME is not None
+if TEST_CUDA and torch.version.cuda is not None:  # skip the CUDNN test for ROCm
     CUDNN_HEADER_EXISTS = os.path.isfile(os.path.join(CUDA_HOME, "include/cudnn.h"))
     TEST_CUDNN = (
         TEST_CUDA and CUDNN_HEADER_EXISTS and torch.backends.cudnn.is_available()

@@ -109,6 +110,7 @@ def test_jit_cuda_extension(self):
             ],
             extra_cuda_cflags=["-O2"],
             verbose=True,
+            keep_intermediates=False,
         )

         x = torch.zeros(100, device="cuda", dtype=torch.float32)

@@ -184,6 +186,7 @@ def _check_cuobjdump_output(expected_values, is_ptx=False):
         os.environ['TORCH_CUDA_ARCH_LIST'] = old_envvar

     @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    @unittest.skipIf(TEST_ROCM, "disabled on rocm")
     def test_jit_cuda_archflags(self):
         # Test a number of combinations:
         # - the default for the machine we're testing on
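
For context, the keep_intermediates=False added above is passed through the test's JIT load call (torch.utils.cpp_extension.load) so that hipify's generated files are removed again after the build in CI. A minimal sketch of that call pattern on a ROCm build might look like the following; the source file names and extension name are hypothetical, not taken from the test suite:

    import torch
    from torch.utils.cpp_extension import load

    # "my_ext.cpp" / "my_ext_kernel.cu" are placeholder files; on a ROCm build the
    # .cu source is hipified in the build directory before being compiled by hipcc.
    module = load(
        name="my_rocm_extension",
        sources=["my_ext.cpp", "my_ext_kernel.cu"],
        extra_cuda_cflags=["-O2"],
        verbose=True,
        keep_intermediates=False,  # the keyword argument this commit adds
    )

    x = torch.zeros(100, device="cuda", dtype=torch.float32)
    # Functions bound in the sources become attributes of `module`, e.g. module.my_op(x).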

torch/utils/cpp_extension.py

Lines changed: 52 additions & 19 deletions

@@ -15,6 +15,8 @@
 import torch
 from .file_baton import FileBaton
 from ._cpp_extension_versioner import ExtensionVersioner
+from .hipify import hipify_python
+from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner

 from setuptools.command.build_ext import build_ext


@@ -814,7 +816,8 @@ def load(name,
          build_directory=None,
          verbose=False,
          with_cuda=None,
-         is_python_module=True):
+         is_python_module=True,
+         keep_intermediates=True):
     '''
     Loads a PyTorch C++ extension just-in-time (JIT).


@@ -895,7 +898,8 @@ def load(name,
         build_directory or _get_build_directory(name, verbose),
         verbose,
         with_cuda,
-        is_python_module)
+        is_python_module,
+        keep_intermediates=keep_intermediates)


 def load_inline(name,

@@ -910,7 +914,8 @@ def load_inline(name,
                 verbose=False,
                 with_cuda=None,
                 is_python_module=True,
-                with_pytorch_error_handling=True):
+                with_pytorch_error_handling=True,
+                keep_intermediates=True):
     '''
     Loads a PyTorch C++ extension just-in-time (JIT) from string sources.


@@ -1043,7 +1048,8 @@ def load_inline(name,
         build_directory,
         verbose,
         with_cuda,
-        is_python_module)
+        is_python_module,
+        keep_intermediates=keep_intermediates)


 def _jit_compile(name,

@@ -1055,7 +1061,11 @@ def _jit_compile(name,
                  build_directory,
                  verbose,
                  with_cuda,
-                 is_python_module):
+                 is_python_module,
+                 keep_intermediates=True):
+    if with_cuda is None:
+        with_cuda = any(map(_is_cuda_file, sources))
+    with_cudnn = any(['cudnn' in f for f in extra_ldflags or []])
     old_version = JIT_EXTENSION_VERSIONER.get_version(name)
     version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
         name,

@@ -1074,16 +1084,27 @@
     baton = FileBaton(os.path.join(build_directory, 'lock'))
     if baton.try_acquire():
         try:
-            _write_ninja_file_and_build_library(
-                name=name,
-                sources=sources,
-                extra_cflags=extra_cflags or [],
-                extra_cuda_cflags=extra_cuda_cflags or [],
-                extra_ldflags=extra_ldflags or [],
-                extra_include_paths=extra_include_paths or [],
-                build_directory=build_directory,
-                verbose=verbose,
-                with_cuda=with_cuda)
+            with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
+                if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
+                    hipify_python.hipify(
+                        project_directory=build_directory,
+                        output_directory=build_directory,
+                        includes=os.path.join(build_directory, '*'),
+                        extra_files=[os.path.abspath(s) for s in sources],
+                        show_detailed=verbose,
+                        is_pytorch_extension=True,
+                        clean_ctx=clean_ctx
+                    )
+                _write_ninja_file_and_build_library(
+                    name=name,
+                    sources=sources,
+                    extra_cflags=extra_cflags or [],
+                    extra_cuda_cflags=extra_cuda_cflags or [],
+                    extra_ldflags=extra_ldflags or [],
+                    extra_include_paths=extra_include_paths or [],
+                    build_directory=build_directory,
+                    verbose=verbose,
+                    with_cuda=with_cuda)
         finally:
             baton.release()
     else:

@@ -1231,10 +1252,10 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose):
             extra_ldflags.append('-L{}'.format(lib_path))
         extra_ldflags.append('-lc10')
         if with_cuda:
-            extra_ldflags.append('-lc10_cuda')
+            extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda')
         extra_ldflags.append('-ltorch_cpu')
         if with_cuda:
-            extra_ldflags.append('-ltorch_cuda')
+            extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda')
         extra_ldflags.append('-ltorch')
         extra_ldflags.append('-ltorch_python')


@@ -1465,7 +1486,15 @@ def _write_ninja_file_to_build_library(path,
     else:
         cflags = common_cflags + ['-fPIC', '-std=c++14'] + extra_cflags

-    if with_cuda:
+    if with_cuda and IS_HIP_EXTENSION:
+        cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIPCC_FLAGS
+        cuda_flags += extra_cuda_cflags
+        cuda_flags += _get_rocm_arch_flags(cuda_flags)
+        sources = [s if not _is_cuda_file(s) else
+                   os.path.abspath(os.path.join(
+                       path, get_hip_file_path(os.path.relpath(s, path))))
+                   for s in sources]
+    elif with_cuda:
         cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
         if IS_WINDOWS:
             for flag in COMMON_MSVC_FLAGS:

@@ -1568,7 +1597,11 @@ def sanitize_flags(flags):
     config = ['ninja_required_version = 1.3']
     config.append('cxx = {}'.format(compiler))
     if with_cuda:
-        config.append('nvcc = {}'.format(_join_cuda_home('bin', 'nvcc')))
+        if IS_HIP_EXTENSION:
+            nvcc = _join_rocm_home('bin', 'hipcc')
+        else:
+            nvcc = _join_cuda_home('bin', 'nvcc')
+        config.append('nvcc = {}'.format(nvcc))

     flags = ['cflags = {}'.format(' '.join(cflags))]
     flags.append('post_cflags = {}'.format(' '.join(post_cflags)))
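
To make the end-to-end flow concrete, here is a minimal sketch (not part of the commit) of what these changes enable: an inline CUDA extension built on a ROCm installation, where _jit_compile hipifies the source in the build directory, links c10_hip/torch_hip instead of c10_cuda/torch_cuda, and drives hipcc instead of nvcc. The kernel, function, and extension names below are invented for illustration; only the load_inline/keep_intermediates interface comes from the diff.

    import torch
    from torch.utils.cpp_extension import load_inline

    # Hypothetical inline sources; on ROCm the CUDA source is hipified before hipcc runs.
    cuda_source = """
    __global__ void add_two_kernel(const float* x, float* out, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            out[i] = x[i] + 2.0f;
        }
    }

    torch::Tensor add_two(torch::Tensor x) {
        auto out = torch::zeros_like(x);
        const int threads = 256;
        const int blocks = (x.numel() + threads - 1) / threads;
        add_two_kernel<<<blocks, threads>>>(
            x.data_ptr<float>(), out.data_ptr<float>(), x.numel());
        return out;
    }
    """

    cpp_source = "torch::Tensor add_two(torch::Tensor x);"

    module = load_inline(
        name="add_two_extension",   # hypothetical extension name
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        functions="add_two",
        extra_cuda_cflags=["-O2"],
        verbose=True,
        keep_intermediates=False,   # new in this commit: drop hipify outputs after the build
    )

    x = torch.zeros(100, device="cuda", dtype=torch.float32)
    print(module.add_two(x))        # tensor of twos

With keep_intermediates left at its default of True, the hipified files stay in the build directory, so unchanged sources are not re-converted on the next load.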

torch/utils/hipify/hipify_python.py

Lines changed: 55 additions & 6 deletions

@@ -72,6 +72,45 @@ class bcolors:
     BOLD = '\033[1m'
     UNDERLINE = '\033[4m'

+# To the programmer, the output of hipify most likely are intermediates.
+# This class allows users of hipify to ask for a cleanup by running the
+# hipify and compilation in a with instantiating this context manager class
+# with keep_intermediates=False.
+# The main usecase is the cpp_extensions, specifically the load method.
+# It is a good idea to keep intermediates (in case of errors or to
+# not recompile unchanged files), but in cases where you don't want to
+# keep them (e.g. in the CI), this can be used to remove files.
+class GeneratedFileCleaner:
+    """Context Manager to clean up generated files"""
+    def __init__(self, keep_intermediates=False):
+        self.keep_intermediates = keep_intermediates
+        self.files_to_clean = set()
+        self.dirs_to_clean = []
+
+    def __enter__(self):
+        return self
+
+    def open(self, fn, *args):
+        if not os.path.exists(fn):
+            self.files_to_clean.add(os.path.abspath(fn))
+        return open(fn, *args)
+
+    def makedirs(self, dn, exist_ok=False):
+        parent, n = os.path.split(dn)
+        if not n:
+            parent, n = os.path.split(parent)
+        if parent and n and not os.path.exists(parent):
+            self.makedirs(parent, exist_ok=True)
+        if not os.path.isdir(dn) or not exist_ok:
+            os.mkdir(dn)
+            self.dirs_to_clean.append(os.path.abspath(dn))
+
+    def __exit__(self, type, value, traceback):
+        if not self.keep_intermediates:
+            for f in self.files_to_clean:
+                os.unlink(f)
+            for d in self.dirs_to_clean[::-1]:
+                os.rmdir(d)

 def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), out_of_place_only=False):
     def _fnmatch(filepath, patterns):

@@ -120,19 +159,24 @@ def preprocess(
         show_detailed=False,
         show_progress=True,
         hip_clang_launch=False,
-        is_pytorch_extension=False):
+        is_pytorch_extension=False,
+        clean_ctx=None):
     """
     Call preprocessor on selected files.

    Arguments)
        show_detailed - Show a detailed summary of the transpilation process.
    """

+    if clean_ctx is None:
+        clean_ctx = GeneratedFileCleaner(keep_intermediates=True)
+
     # Preprocessing statistics.
     stats = {"unsupported_calls": [], "kernel_launches": []}

     for filepath in all_files:
-        result = preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension)
+        result = preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension, clean_ctx)
+
         # Show what happened
         if show_progress:
             print(

@@ -606,15 +650,15 @@ def pattern(self):
 RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
 RE_CU_SUFFIX = re.compile(r'\.cu\b')  # be careful not to pick up .cuh

-def preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension):
+def preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension, clean_ctx):
     """ Executes the CUDA -> HIP conversion on the specified file. """
     fin_path = os.path.join(output_directory, filepath)
     with open(fin_path, 'r') as fin:
         output_source = fin.read()

     fout_path = os.path.join(output_directory, get_hip_file_path(filepath))
     if not os.path.exists(os.path.dirname(fout_path)):
-        os.makedirs(os.path.dirname(fout_path))
+        clean_ctx.makedirs(os.path.dirname(fout_path))

     # unsupported_calls statistics reporting is broken atm
     def pt_repl(m):

@@ -675,7 +719,7 @@ def repl(m):
         with open(fout_path, 'r') as fout_old:
             do_write = fout_old.read() != output_source
     if do_write:
-        with open(fout_path, 'w') as fout:
+        with clean_ctx.open(fout_path, 'w') as fout:
             fout.write(output_source)
         return "ok"
     else:

@@ -776,11 +820,13 @@ def hipify(
     extensions=(".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
     output_directory="",
     includes=(),
+    extra_files=(),
     out_of_place_only=False,
     ignores=(),
     show_progress=True,
     hip_clang_launch=False,
     is_pytorch_extension=False,
+    clean_ctx=None
 ):
     if project_directory == "":
         project_directory = os.getcwd()

@@ -802,6 +848,8 @@
     all_files = list(matched_files_iter(output_directory, includes=includes,
                                         ignores=ignores, extensions=extensions,
                                         out_of_place_only=out_of_place_only))
+    all_files_set = set(all_files)
+    all_files += [f for f in extra_files if f not in all_files_set]

     # Start Preprocessor
     preprocess(

@@ -810,4 +858,5 @@
         show_detailed=show_detailed,
         show_progress=show_progress,
         hip_clang_launch=hip_clang_launch,
-        is_pytorch_extension=is_pytorch_extension)
+        is_pytorch_extension=is_pytorch_extension,
+        clean_ctx=clean_ctx)
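
As a usage note (not part of the diff), GeneratedFileCleaner tracks files opened through clean_ctx.open and directories created through clean_ctx.makedirs, and removes them when the with-block exits unless keep_intermediates=True. A minimal sketch against a made-up scratch directory:

    import os
    from torch.utils.hipify.hipify_python import GeneratedFileCleaner

    with GeneratedFileCleaner(keep_intermediates=False) as clean_ctx:
        clean_ctx.makedirs("scratch/hip", exist_ok=True)           # directory is tracked
        with clean_ctx.open("scratch/hip/kernel.hip", "w") as f:   # newly created file is tracked
            f.write("// hipified output would land here\n")
        assert os.path.exists("scratch/hip/kernel.hip")

    # On exit, the tracked file and directories have been removed again.
    assert not os.path.exists("scratch")

This mirrors how _jit_compile uses it above: with keep_intermediates=True (the default for load/load_inline) nothing is deleted, so unchanged hipified files do not force a rebuild; in CI, passing keep_intermediates=False cleans the build directory.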
