@@ -532,9 +532,15 @@ def main():
     )
     accelerator.register_for_checkpointing(lr_scheduler)
 
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     # Move vae and unet to device
-    vae.to(accelerator.device)
-    unet.to(accelerator.device)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
 
     # Keep vae and unet in eval model as we don't train these
     vae.eval()
@@ -600,11 +606,11 @@ def main():
 
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
@@ -616,7 +622,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
 
                 # Predict the noise residual
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -629,7 +635,7 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
                 accelerator.backward(loss)
 
                 optimizer.step()
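
Note: the change above follows the usual mixed-precision pattern in the diffusers training scripts: derive a weight dtype from the accelerator's mixed_precision setting, move the frozen VAE and UNet to that dtype, cast their inputs to match, and cast the predictions back to float32 before the MSE loss. The snippet below is a minimal standalone sketch of that pattern using plain torch; the mixed_precision string and the tiny nn.Linear are placeholders standing in for accelerator.mixed_precision and the frozen models, not the script's actual objects.

# Minimal sketch of the mixed-precision pattern from the diff (placeholders only:
# `mixed_precision` stands in for accelerator.mixed_precision, and the Linear
# layer stands in for the frozen VAE/UNet).
import torch
import torch.nn.functional as F

mixed_precision = "bf16"  # e.g. read from accelerator.mixed_precision

# 1) Pick the dtype used for frozen weights and their inputs.
weight_dtype = torch.float32
if mixed_precision == "fp16":
    weight_dtype = torch.float16
elif mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

device = "cuda" if torch.cuda.is_available() else "cpu"

# 2) The frozen module lives on the device in the reduced dtype; inputs are
#    cast to the same dtype before the forward pass.
frozen = torch.nn.Linear(8, 8).to(device, dtype=weight_dtype)
frozen.eval()
x = torch.randn(4, 8, device=device).to(dtype=weight_dtype)

with torch.no_grad():
    pred = frozen(x)
target = torch.randn_like(pred)

# 3) Compute the loss in float32 even when activations are fp16/bf16,
#    mirroring model_pred.float() / target.float() in the diff.
loss = F.mse_loss(pred.float(), target.float(), reduction="mean")
print(loss.dtype, loss.item())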