We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8d9c4a5 commit f8325cfCopy full SHA for f8325cf
src/diffusers/testing_utils.py
@@ -5,10 +5,15 @@
5
6
import torch
7
8
+from packaging import version
9
+
10
11
global_rng = random.Random()
12
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_device = "mps" if torch.backends.mps.is_available() else torch_device
13
+is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
14
15
+if is_torch_higher_equal_than_1_12:
16
+ torch_device = "mps" if torch.backends.mps.is_available() else torch_device
17
18
19
def parse_flag_from_env(key, default=False):
0 commit comments