Skip to content

Commit b4077af

Browse files
[bug fix] using snr gamma and prior preservation loss in the dreambooth lora sdxl training scripts (huggingface#6356)
* change timesteps used to calculate snr when --with_prior_preservation is enabled * change timesteps used to calculate snr when --with_prior_preservation is enabled (canonical script) * style * revert canonical script to before snr gamma change --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 9f2bff5 commit b4077af

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,9 +1819,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
18191819
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
18201820
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
18211821
# This is discussed in Section 4.2 of the same paper.
1822-
snr = compute_snr(noise_scheduler, timesteps)
1822+
1823+
if args.with_prior_preservation:
1824+
# if we're using prior preservation, we calc snr for instance loss only -
1825+
# and hence only need timesteps corresponding to instance images
1826+
snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0)
1827+
else:
1828+
snr_timesteps = timesteps
1829+
1830+
snr = compute_snr(noise_scheduler, snr_timesteps)
18231831
base_weight = (
1824-
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1832+
torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr
18251833
)
18261834

18271835
if noise_scheduler.config.prediction_type == "v_prediction":

0 commit comments

Comments
 (0)