File tree Expand file tree Collapse file tree 3 files changed +8
-4
lines changed Expand file tree Collapse file tree 3 files changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -63,8 +63,10 @@ def get_backend():
63
63
return "xla" , torch .cuda .device_count (), torch .cuda .memory_allocated
64
64
elif is_cuda_available ():
65
65
return "cuda" , torch .cuda .device_count (), torch .cuda .memory_allocated
66
- elif is_mps_available ():
66
+ elif is_mps_available (min_version = "2.0" ):
67
67
return "mps" , 1 , torch .mps .current_allocated_memory ()
68
+ elif is_mps_available ():
69
+ return "mps" , 1 , 0
68
70
elif is_mlu_available ():
69
71
return "mlu" , torch .mlu .device_count (), torch .mlu .memory_allocated
70
72
elif is_npu_available ():
Original file line number Diff line number Diff line change @@ -302,8 +302,10 @@ def is_mlflow_available():
302
302
return False
303
303
304
304
305
- def is_mps_available ():
306
- return is_torch_version (">=" , "1.12" ) and torch .backends .mps .is_available () and torch .backends .mps .is_built ()
305
+ def is_mps_available (min_version = "1.12" ):
306
+ # With torch 1.12, you can use torch.backends.mps
307
+ # With torch 2.0.0, you can use torch.mps
308
+ return is_torch_version (">=" , min_version ) and torch .backends .mps .is_available () and torch .backends .mps .is_built ()
307
309
308
310
309
311
def is_ipex_available ():
Original file line number Diff line number Diff line change @@ -59,7 +59,7 @@ def release_memory(*objects):
59
59
torch .mlu .empty_cache ()
60
60
elif is_npu_available ():
61
61
torch .npu .empty_cache ()
62
- elif is_mps_available ():
62
+ elif is_mps_available (min_version = "2.0" ):
63
63
torch .mps .empty_cache ()
64
64
else :
65
65
torch .cuda .empty_cache ()
You can’t perform that action at this time.
0 commit comments