Skip to content

Commit 22c1ba5

Browse files
Fix k_dpm_2 & k_dpm_2_a on MPS (huggingface#2241)
Needed to convert `timesteps` to `float32` a bit sooner. Fixes huggingface#1537
1 parent 7386e77 commit 22c1ba5

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,16 @@ def set_timesteps(
161161
# standard deviation of the initial noise distribution
162162
self.init_noise_sigma = self.sigmas.max()
163163

164-
timesteps = torch.from_numpy(timesteps).to(device)
165-
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
166-
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
167-
timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
168-
169164
if str(device).startswith("mps"):
170165
# mps does not support float64
171-
self.timesteps = timesteps.to(device, dtype=torch.float32)
166+
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
172167
else:
173-
self.timesteps = timesteps
168+
timesteps = torch.from_numpy(timesteps).to(device)
169+
170+
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
171+
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
172+
173+
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
174174

175175
self.sample = None
176176

src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,18 +149,17 @@ def set_timesteps(
149149
# standard deviation of the initial noise distribution
150150
self.init_noise_sigma = self.sigmas.max()
151151

152-
timesteps = torch.from_numpy(timesteps).to(device)
152+
if str(device).startswith("mps"):
153+
# mps does not support float64
154+
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
155+
else:
156+
timesteps = torch.from_numpy(timesteps).to(device)
153157

154158
# interpolate timesteps
155159
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
156160
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
157-
timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
158161

159-
if str(device).startswith("mps"):
160-
# mps does not support float64
161-
self.timesteps = timesteps.to(torch.float32)
162-
else:
163-
self.timesteps = timesteps
162+
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
164163

165164
self.sample = None
166165

0 commit comments

Comments
 (0)