Skip to content

Commit c5f04d4

Browse files
authored
apply amp bf16 on textual inversion (huggingface#1465)
* add conf.yaml
* enable bf16: enable amp bf16 for unet forward, fix style, fix readme, remove useless file
* change amp to full bf16
* align
* make style
* fix format
1 parent 61dec53 commit c5f04d4

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

examples/textual_inversion/textual_inversion.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -532,9 +532,15 @@ def main():
532532
)
533533
accelerator.register_for_checkpointing(lr_scheduler)
534534

535+
weight_dtype = torch.float32
536+
if accelerator.mixed_precision == "fp16":
537+
weight_dtype = torch.float16
538+
elif accelerator.mixed_precision == "bf16":
539+
weight_dtype = torch.bfloat16
540+
535541
# Move vae and unet to device
536-
vae.to(accelerator.device)
537-
unet.to(accelerator.device)
542+
unet.to(accelerator.device, dtype=weight_dtype)
543+
vae.to(accelerator.device, dtype=weight_dtype)
538544

539545
# Keep vae and unet in eval model as we don't train these
540546
vae.eval()
@@ -600,11 +606,11 @@ def main():
600606

601607
with accelerator.accumulate(text_encoder):
602608
# Convert images to latent space
603-
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
609+
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
604610
latents = latents * 0.18215
605611

606612
# Sample noise that we'll add to the latents
607-
noise = torch.randn(latents.shape).to(latents.device)
613+
noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
608614
bsz = latents.shape[0]
609615
# Sample a random timestep for each image
610616
timesteps = torch.randint(
@@ -616,7 +622,7 @@ def main():
616622
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
617623

618624
# Get the text embedding for conditioning
619-
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
625+
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
620626

621627
# Predict the noise residual
622628
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -629,7 +635,7 @@ def main():
629635
else:
630636
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
631637

632-
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
638+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
633639
accelerator.backward(loss)
634640

635641
optimizer.step()

0 commit comments

Comments (0)