Skip to content

Commit e001fed

Browse files
authored
Fix dreambooth loss type with prior_preservation and fp16 (open-mmlab#826)
Fix dreambooth loss type with prior preservation
1 parent 0a09af2 commit e001fed

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def collate_fn(examples):
544544
noise, noise_prior = torch.chunk(noise, 2, dim=0)
545545

546546
# Compute instance loss
547-
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
547+
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
548548

549549
# Compute prior loss
550550
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

0 commit comments

Comments
 (0)