Skip to content

Commit d558811

Browse files
bghirabghirasayakpaul
authored
Min-SNR gamma support for Dreambooth training (huggingface#5107)
* min-SNR gamma for Dreambooth training * Align the mse_loss_weights style with SDXL training example --------- Co-authored-by: bghira <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 157c901 commit d558811

File tree

1 file changed

+53
-5
lines changed

1 file changed

+53
-5
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,30 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
224224
raise ValueError(f"{model_class} is not supported.")
225225

226226

227+
def compute_snr(timesteps, noise_scheduler):
228+
"""
229+
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
230+
"""
231+
alphas_cumprod = noise_scheduler.alphas_cumprod
232+
sqrt_alphas_cumprod = alphas_cumprod**0.5
233+
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
234+
# Expand the tensors.
235+
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
236+
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
237+
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
238+
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
239+
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
240+
241+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
242+
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
243+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
244+
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
245+
246+
# Compute SNR
247+
snr = (alpha / sigma) ** 2
248+
return snr
249+
250+
227251
def parse_args(input_args=None):
228252
parser = argparse.ArgumentParser(description="Simple example of a training script.")
229253
parser.add_argument(
@@ -524,6 +548,13 @@ def parse_args(input_args=None):
524548
" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
525549
),
526550
)
551+
parser.add_argument(
552+
"--snr_gamma",
553+
type=float,
554+
default=None,
555+
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
556+
"More details here: https://arxiv.org/abs/2303.09556.",
557+
)
527558
parser.add_argument(
528559
"--pre_compute_text_embeddings",
529560
action="store_true",
@@ -1261,17 +1292,34 @@ def compute_text_embeddings(prompt):
12611292
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
12621293
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
12631294
target, target_prior = torch.chunk(target, 2, dim=0)
1295+
# Compute prior loss
1296+
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
12641297

1265-
# Compute instance loss
1298+
# Compute instance loss
1299+
if args.snr_gamma is None:
12661300
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1301+
else:
1302+
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1303+
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
1304+
# This is discussed in Section 4.2 of the same paper.
1305+
snr = compute_snr(timesteps, noise_scheduler)
1306+
base_weight = (
1307+
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1308+
)
12671309

1268-
# Compute prior loss
1269-
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1310+
if noise_scheduler.config.prediction_type == "v_prediction":
1311+
# Velocity objective needs to be floored to an SNR weight of one.
1312+
mse_loss_weights = base_weight + 1
1313+
else:
1314+
# Epsilon and sample both use the same loss weights.
1315+
mse_loss_weights = base_weight
1316+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1317+
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1318+
loss = loss.mean()
12701319

1320+
if args.with_prior_preservation:
12711321
# Add the prior loss to the instance loss.
12721322
loss = loss + args.prior_loss_weight * prior_loss
1273-
else:
1274-
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
12751323

12761324
accelerator.backward(loss)
12771325
if accelerator.sync_gradients:

0 commit comments

Comments
 (0)