Skip to content

Commit f8325cf

Browse files
[MPS] Make sure it doesn't break torch < 1.12 (huggingface#425)
* [MPS] Make sure it doesn't break torch < 1.12 * up
1 parent 8d9c4a5 commit f8325cf

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/diffusers/testing_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55

66
import torch
77

8+
from packaging import version
9+
810

911
global_rng = random.Random()
1012
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
11-
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
13+
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
14+
15+
if is_torch_higher_equal_than_1_12:
16+
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
1217

1318

1419
def parse_flag_from_env(key, default=False):

0 commit comments

Comments
 (0)