@@ -77,16 +77,17 @@ def _check_capability():
7777 The minimum cuda capability that we support is 3.5.
7878 """
7979
80- CUDA_VERSION = torch ._C ._cuda_getCompiledVersion ()
81- for d in range (device_count ()):
82- capability = get_device_capability (d )
83- major = capability [0 ]
84- minor = capability [1 ]
85- name = get_device_name (d )
86- if capability == (3 , 0 ) or major < 3 :
87- warnings .warn (old_gpu_warn % (d , name , major , capability [1 ]))
88- elif CUDA_VERSION <= 9000 and major >= 7 and minor >= 5 :
89- warnings .warn (incorrect_binary_warn % (d , name , 10000 , CUDA_VERSION ))
80+ if torch .version .cuda is not None : # on ROCm we don't want this check
81+ CUDA_VERSION = torch ._C ._cuda_getCompiledVersion ()
82+ for d in range (device_count ()):
83+ capability = get_device_capability (d )
84+ major = capability [0 ]
85+ minor = capability [1 ]
86+ name = get_device_name (d )
87+ if capability == (3 , 0 ) or major < 3 :
88+ warnings .warn (old_gpu_warn % (d , name , major , capability [1 ]))
89+ elif CUDA_VERSION <= 9000 and major >= 7 and minor >= 5 :
90+ warnings .warn (incorrect_binary_warn % (d , name , 10000 , CUDA_VERSION ))
9091
9192
9293def is_initialized ():
0 commit comments