Skip to content

Commit 86b6dea

Browse files
SunMarcMatthew Hoffman
andauthored
Fix access error for torch.mps when using torch==1.13.1 on macOS (huggingface#2806)
* Fix access error for torch.mps when using torch==1.13.1 * Add missing parentheses * add min_version --------- Co-authored-by: Matthew Hoffman <[email protected]>
1 parent b24a0ef commit 86b6dea

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

src/accelerate/test_utils/testing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,10 @@ def get_backend():
6363
return "xla", torch.cuda.device_count(), torch.cuda.memory_allocated
6464
elif is_cuda_available():
6565
return "cuda", torch.cuda.device_count(), torch.cuda.memory_allocated
66-
elif is_mps_available():
66+
elif is_mps_available(min_version="2.0"):
6767
return "mps", 1, torch.mps.current_allocated_memory()
68+
elif is_mps_available():
69+
return "mps", 1, 0
6870
elif is_mlu_available():
6971
return "mlu", torch.mlu.device_count(), torch.mlu.memory_allocated
7072
elif is_npu_available():

src/accelerate/utils/imports.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,10 @@ def is_mlflow_available():
302302
return False
303303

304304

305-
def is_mps_available():
306-
return is_torch_version(">=", "1.12") and torch.backends.mps.is_available() and torch.backends.mps.is_built()
305+
def is_mps_available(min_version="1.12"):
306+
# With torch 1.12, you can use torch.backends.mps
307+
# With torch 2.0.0, you can use torch.mps
308+
return is_torch_version(">=", min_version) and torch.backends.mps.is_available() and torch.backends.mps.is_built()
307309

308310

309311
def is_ipex_available():

src/accelerate/utils/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def release_memory(*objects):
5959
torch.mlu.empty_cache()
6060
elif is_npu_available():
6161
torch.npu.empty_cache()
62-
elif is_mps_available():
62+
elif is_mps_available(min_version="2.0"):
6363
torch.mps.empty_cache()
6464
else:
6565
torch.cuda.empty_cache()

0 commit comments

Comments
 (0)