Skip to content

Commit 9bca402

Browse files
authored
[MPS] fix mps failing tests (huggingface#934)
fix mps failing tests
1 parent 2fdd094 commit 9bca402

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

tests/test_models_vae.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,21 @@ def test_output_pretrained(self):
106106

107107
# Since the VAE Gaussian prior's generator is seeded on the appropriate device,
108108
# the expected output slices are not the same for CPU and GPU.
109-
if torch_device in ("mps", "cpu"):
109+
if torch_device == "mps":
110+
expected_output_slice = torch.tensor(
111+
[
112+
-4.0078e-01,
113+
-3.8323e-04,
114+
-1.2681e-01,
115+
-1.1462e-01,
116+
2.0095e-01,
117+
1.0893e-01,
118+
-8.8247e-02,
119+
-3.0361e-01,
120+
-9.8644e-03,
121+
]
122+
)
123+
elif torch_device == "cpu":
110124
expected_output_slice = torch.tensor(
111125
[-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026]
112126
)

tests/test_pipelines.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,13 @@ def test_components(self):
455455
text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
456456

457457
prompt = "A painting of a squirrel eating a burger"
458-
generator = torch.Generator(device=torch_device).manual_seed(0)
458+
459+
# Device type MPS is not supported for torch.Generator() api.
460+
if torch_device == "mps":
461+
generator = torch.manual_seed(0)
462+
else:
463+
generator = torch.Generator(device=torch_device).manual_seed(0)
464+
459465
image_inpaint = inpaint(
460466
[prompt],
461467
generator=generator,

0 commit comments

Comments
 (0)