Skip to content

Commit 544710e

Browse files
bghirabghirasayakpaul
authored
diffusers#7426 fix stable diffusion xl inference on MPS when dtypes shift unexpectedly due to pytorch bugs (huggingface#7446)
* mps: fix XL pipeline inference at training time due to upstream pytorch bug * Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py Co-authored-by: Sayak Paul <[email protected]> * apply the safe-guarding logic elsewhere. --------- Co-authored-by: bghira <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 443aa14 commit 544710e

File tree

4 files changed

+68
-0
lines changed

4 files changed

+68
-0
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,16 @@ def __call__(
11931193
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
11941194

11951195
# compute the previous noisy sample x_t -> x_t-1
1196+
latents_dtype = latents.dtype
11961197
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1198+
if latents.dtype != latents_dtype:
1199+
if torch.backends.mps.is_available():
1200+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1201+
latents = latents.to(latents_dtype)
1202+
else:
1203+
raise ValueError(
1204+
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
1205+
)
11971206

11981207
if callback_on_step_end is not None:
11991208
callback_kwargs = {}
@@ -1228,6 +1237,14 @@ def __call__(
12281237
if needs_upcasting:
12291238
self.upcast_vae()
12301239
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1240+
elif latents.dtype != self.vae.dtype:
1241+
if torch.backends.mps.is_available():
1242+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1243+
self.vae = self.vae.to(latents.dtype)
1244+
else:
1245+
raise ValueError(
1246+
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
1247+
)
12311248

12321249
# unscale/denormalize the latents
12331250
# denormalize with the mean and std if available and not None

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,7 +1370,16 @@ def denoising_value_valid(dnv):
13701370
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
13711371

13721372
# compute the previous noisy sample x_t -> x_t-1
1373+
latents_dtype = latents.dtype
13731374
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1375+
if latents.dtype != latents_dtype:
1376+
if torch.backends.mps.is_available():
1377+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1378+
latents = latents.to(latents_dtype)
1379+
else:
1380+
raise ValueError(
1381+
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
1382+
)
13741383

13751384
if callback_on_step_end is not None:
13761385
callback_kwargs = {}
@@ -1405,6 +1414,14 @@ def denoising_value_valid(dnv):
14051414
if needs_upcasting:
14061415
self.upcast_vae()
14071416
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1417+
elif latents.dtype != self.vae.dtype:
1418+
if torch.backends.mps.is_available():
1419+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1420+
self.vae = self.vae.to(latents.dtype)
1421+
else:
1422+
raise ValueError(
1423+
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
1424+
)
14081425

14091426
# unscale/denormalize the latents
14101427
# denormalize with the mean and std if available and not None

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,7 +1720,16 @@ def denoising_value_valid(dnv):
17201720
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
17211721

17221722
# compute the previous noisy sample x_t -> x_t-1
1723+
latents_dtype = latents.dtype
17231724
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1725+
if latents.dtype != latents_dtype:
1726+
if torch.backends.mps.is_available():
1727+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1728+
latents = latents.to(latents_dtype)
1729+
else:
1730+
raise ValueError(
1731+
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
1732+
)
17241733

17251734
if num_channels_unet == 4:
17261735
init_latents_proper = image_latents
@@ -1772,6 +1781,14 @@ def denoising_value_valid(dnv):
17721781
if needs_upcasting:
17731782
self.upcast_vae()
17741783
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1784+
elif latents.dtype != self.vae.dtype:
1785+
if torch.backends.mps.is_available():
1786+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1787+
self.vae = self.vae.to(latents.dtype)
1788+
else:
1789+
raise ValueError(
1790+
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
1791+
)
17751792

17761793
# unscale/denormalize the latents
17771794
# denormalize with the mean and std if available and not None

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,16 @@ def __call__(
918918
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
919919

920920
# compute the previous noisy sample x_t -> x_t-1
921+
latents_dtype = latents.dtype
921922
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
923+
if latents.dtype != latents_dtype:
924+
if torch.backends.mps.is_available():
925+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
926+
latents = latents.to(latents_dtype)
927+
else:
928+
raise ValueError(
929+
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
930+
)
922931

923932
# call the callback, if provided
924933
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -937,6 +946,14 @@ def __call__(
937946
if needs_upcasting:
938947
self.upcast_vae()
939948
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
949+
elif latents.dtype != self.vae.dtype:
950+
if torch.backends.mps.is_available():
951+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
952+
self.vae = self.vae.to(latents.dtype)
953+
else:
954+
raise ValueError(
955+
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
956+
)
940957

941958
# unscale/denormalize the latents
942959
# denormalize with the mean and std if available and not None

0 commit comments

Comments
 (0)