Skip to content

Commit 0ea5162

Browse files
authored
[Core] Fix dtype in InstructPix2Pix SDXL while computing image_latents (huggingface#5013)
* check out dtypes. * check out dtypes. * check out dtypes. * check out dtypes. * check out dtypes. * check out dtypes. * check out dtypes. * potential fix * check out dtypes. * check out dtypes. * working?
1 parent 6d6a08f commit 0ea5162

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,8 @@ def prepare_image_latents(
495495
image_latents = image
496496
else:
497497
# make sure the VAE is in float32 mode, as it overflows in float16
498-
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
498+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
499+
if needs_upcasting:
499500
self.upcast_vae()
500501
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
501502

@@ -511,6 +512,10 @@ def prepare_image_latents(
511512
else:
512513
image_latents = self.vae.encode(image).latent_dist.mode()
513514

515+
# cast back to fp16 if needed
516+
if needs_upcasting:
517+
self.vae.to(dtype=torch.float16)
518+
514519
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
515520
# expand image_latents for batch_size
516521
deprecation_message = (
@@ -533,6 +538,9 @@ def prepare_image_latents(
533538
uncond_image_latents = torch.zeros_like(image_latents)
534539
image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
535540

541+
if image_latents.dtype != self.vae.dtype:
542+
image_latents = image_latents.to(dtype=self.vae.dtype)
543+
536544
return image_latents
537545

538546
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids

0 commit comments

Comments
 (0)