
Commit 0feb21a

[Tests] Fix mps+generator fast tests (huggingface#1230)

* [Tests] Fix mps+generator fast tests
* mps for Euler
* retry
* warmup issue again?
* fix reproducible initial noise
* Revert "fix reproducible initial noise" (reverts commit f300d05)
* fix reproducible initial noise
* fix device

1 parent 187de44 commit 0feb21a

File tree

5 files changed, +44/-20 lines changed

.github/workflows/pr_tests.yml

Lines changed: 1 addition & 1 deletion

@@ -136,7 +136,7 @@ jobs:
       - name: Run fast PyTorch tests on M1 (MPS)
         shell: arch -arch arm64 bash {0}
         run: |
-          ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
+          ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
 
       - name: Failure short reports
         if: ${{ failure() }}
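
Note on this change: with pytest-xdist, -n 1 still launches a single worker subprocess, whereas -n 0 disables distribution and runs the suite in the main process, which sidesteps the MPS warmup flakiness mentioned in the commit message. A minimal sketch of an equivalent invocation from Python (assuming pytest and pytest-xdist are installed; --make-reports is a custom option defined in diffusers' own conftest.py):

# Minimal sketch: run the fast test suite in-process (no xdist workers).
import subprocess

subprocess.run(
    [
        "python", "-m", "pytest",
        "-n", "0",           # disable xdist workers; run in the main process
        "-s", "-v",
        "--make-reports=tests_torch_mps",  # custom option from diffusers' conftest.py
        "tests/",
    ],
    check=True,
)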

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 8 additions & 6 deletions

@@ -78,7 +78,7 @@ def __call__(
         if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
             message = (
                 f"The `generator` device is `{generator.device}` and does not match the pipeline "
-                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f"device `{self.device}`, so the `generator` will be ignored. "
                 f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
             )
             deprecate(
@@ -89,11 +89,13 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(self.device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
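
The branch added above is the portable recipe for seed-reproducible initial noise: on mps, torch.randn with a generator was not reproducible at the time, so the noise is drawn on the CPU and then moved to the device. A minimal standalone sketch of the same idea:

# Minimal sketch of the MPS-reproducibility workaround used above:
# draw the initial noise on the CPU with a seeded generator, then
# move the tensor to the target device.
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
generator = torch.manual_seed(0)  # CPU generator, reproducible everywhere

image_shape = (1, 3, 32, 32)  # (batch, channels, height, width)
if device.type == "mps":
    # randn is not seed-reproducible directly on mps, so sample on CPU first
    image = torch.randn(image_shape, generator=generator).to(device)
else:
    image = torch.randn(image_shape, generator=generator, device=device)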

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 8 additions & 6 deletions

@@ -83,7 +83,7 @@ def __call__(
         if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
             message = (
                 f"The `generator` device is `{generator.device}` and does not match the pipeline "
-                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f"device `{self.device}`, so the `generator` will be ignored. "
                 f'Please use `torch.Generator(device="{self.device}")` instead.'
             )
             deprecate(
@@ -94,11 +94,13 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(self.device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
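
For context, a hypothetical usage sketch of the code path behind the reworded message (the model id and devices are placeholders, not from this commit): a generator whose device does not match the pipeline's triggers the deprecation warning and is then ignored.

# Hypothetical usage sketch: a CPU generator passed to a CUDA pipeline
# triggers the deprecation message above, and the generator is ignored.
import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256").to("cuda")

gen = torch.Generator(device="cpu").manual_seed(0)   # device mismatch: warns
image = pipe(generator=gen, num_inference_steps=2).images[0]

gen = torch.Generator(device="cuda").manual_seed(0)  # the recommended fix
image = pipe(generator=gen, num_inference_steps=2).images[0]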

tests/pipelines/ddpm/test_ddpm.py

Lines changed: 6 additions & 2 deletions

@@ -81,10 +81,14 @@ def test_inference_predict_epsilon(self):
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = generator.manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
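
Note that the second seeding now reuses the existing object via generator.manual_seed(0) rather than constructing a new Generator, which works identically for both branches. The device dispatch itself recurs across these tests; a minimal helper capturing the pattern (the name get_generator is hypothetical, not part of the repo):

# Hypothetical helper (not in the repo) capturing the repeated pattern:
import torch

def get_generator(torch_device: str, seed: int = 0) -> torch.Generator:
    if torch_device == "mps":
        # torch.Generator(device="mps") was unsupported, so fall back to
        # the global CPU generator returned by torch.manual_seed
        return torch.manual_seed(seed)
    return torch.Generator(device=torch_device).manual_seed(seed)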

tests/test_scheduler.py

Lines changed: 21 additions & 5 deletions

@@ -1281,7 +1281,11 @@ def test_full_loop_no_noise(self):
 
         scheduler.set_timesteps(self.num_inference_steps)
 
-        generator = torch.Generator(torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
 
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -1308,7 +1312,11 @@ def test_full_loop_device(self):
 
         scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
 
-        generator = torch.Generator(torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
 
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -1364,7 +1372,11 @@ def test_full_loop_no_noise(self):
 
         scheduler.set_timesteps(self.num_inference_steps)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
 
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -1381,7 +1393,7 @@ def test_full_loop_no_noise(self):
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
 
-        if str(torch_device).startswith("cpu"):
+        if torch_device in ["cpu", "mps"]:
             assert abs(result_sum.item() - 152.3192) < 1e-2
             assert abs(result_mean.item() - 0.1983) < 1e-3
         else:
@@ -1396,7 +1408,11 @@ def test_full_loop_device(self):
 
         scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
 
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
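
These full-loop tests drive a scheduler end to end from seeded noise and compare summary statistics per device. A minimal self-contained sketch of that pattern, assuming EulerDiscreteScheduler (these hunks sit in the Euler scheduler tests, matching the "mps for Euler" commit note); the random model_output stands in for a real UNet:

# Minimal sketch of the full-loop test pattern, assuming
# EulerDiscreteScheduler; random model_output stands in for a UNet.
import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(10)

generator = torch.manual_seed(0)  # CPU generator: reproducible on every device
sample = torch.ones((1, 3, 8, 8)) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    scaled = scheduler.scale_model_input(sample, t)  # Euler schedulers rescale inputs
    model_output = torch.randn(scaled.shape, generator=generator)
    sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample

print(torch.sum(torch.abs(sample)), torch.mean(torch.abs(sample)))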
