
Commit 02a8d66

bghira and sayakpaul authored
Min-SNR Gamma: correct the fix for SNR weighted loss in v-prediction (huggingface#5238)

Correct the fix for the SNR-weighted loss in v-prediction by adding 1 to the SNR values rather than to the resulting loss weights.

Co-authored-by: bghira <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent e6faf60 commit 02a8d66

File tree: 10 files changed (+40, -154 lines)

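Every file below receives the same correction. The old code computed the Min-SNR weights as min(SNR, gamma) / SNR and then added 1 to the resulting weights for v-prediction, which mis-weighted the loss and required an extra patch setting the weight to 1 wherever SNR == 0 (zero-terminal SNR) to avoid Inf/NaN. The corrected code adds 1 to the SNR values before the min and the division, so the v-prediction weights become min(SNR + 1, gamma) / (SNR + 1), and the SNR == 0 timestep gets a finite weight by construction. Here is a minimal self-contained sketch of the corrected weighting; the helper name `min_snr_loss_weights` is hypothetical (the training scripts inline this logic), and the SNR tensor stands in for the output of `compute_snr`:

```python
import torch


def min_snr_loss_weights(snr: torch.Tensor, snr_gamma: float, prediction_type: str) -> torch.Tensor:
    """Min-SNR gamma weights: min(snr, snr_gamma) / snr, one value per timestep.

    For "v_prediction", 1 is added to the SNR *before* the min and the division,
    which also leaves a finite weight of min(1, snr_gamma) at SNR == 0
    (zero-terminal SNR), so no special-case patch is needed.
    Hypothetical helper for illustration only.
    """
    if prediction_type == "v_prediction":
        snr = snr + 1
    # Same stack/min/divide pattern as the training scripts below.
    return torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr


snr = torch.tensor([0.0, 0.5, 5.0, 20.0])
print(min_snr_loss_weights(snr, snr_gamma=5.0, prediction_type="v_prediction"))
# tensor([1.0000, 1.0000, 0.8333, 0.2381]) -- finite at SNR == 0
print(min_snr_loss_weights(snr, snr_gamma=5.0, prediction_type="epsilon"))
# tensor([nan, 1.0000, 1.0000, 0.2500]) -- epsilon still divides by the raw SNR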

examples/controlnet/train_controlnet_flax.py

Lines changed: 3 additions & 10 deletions
```diff
@@ -907,17 +907,10 @@ def compute_loss(params, minibatch, sample_rng):
 
         if args.snr_gamma is not None:
             snr = jnp.array(compute_snr(timesteps))
-            base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
             if noise_scheduler.config.prediction_type == "v_prediction":
-                snr_loss_weights = base_weights + 1
-            else:
-                # Epsilon and sample prediction use the base weights.
-                snr_loss_weights = base_weights
-            # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-            # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-            snr_loss_weights[snr == 0] = 1.0
-
+                # Velocity objective requires that we add one to SNR values before we divide by them.
+                snr = snr + 1
+            snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
             loss = loss * snr_loss_weights
 
         loss = loss.mean()
```

examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -781,25 +781,13 @@ def collate_fn(examples):
             # Since we predict the noise instead of x_0, the original formulation is slightly changed.
             # This is discussed in Section 4.2 of the same paper.
             snr = compute_snr(noise_scheduler, timesteps)
-            base_weight = (
+            if noise_scheduler.config.prediction_type == "v_prediction":
+                # Velocity objective requires that we add one to SNR values before we divide by them.
+                snr = snr + 1
+            mse_loss_weights = (
                 torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
             )
 
-            if noise_scheduler.config.prediction_type == "v_prediction":
-                # Velocity objective needs to be floored to an SNR weight of one.
-                mse_loss_weights = base_weight + 1
-            else:
-                # Epsilon and sample both use the same loss weights.
-                mse_loss_weights = base_weight
-
-            # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-            # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-            mse_loss_weights[snr == 0] = 1.0
-
-            # We first calculate the original loss. Then we mean over the non-batch dimensions and
-            # rebalance the sample-wise losses with their respective loss weights.
-            # Finally, we take the mean of the rebalanced loss.
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
             loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
             loss = loss.mean()
```
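The reduction that consumes these weights, left unchanged by the commit and identical in every PyTorch script here, is worth reading in order: the elementwise MSE is first averaged over the non-batch dimensions to get one loss per sample, each sample's loss is then scaled by its timestep's weight, and only then is the batch mean taken. A small self-contained sketch with assumed toy shapes:

```python
import torch
import torch.nn.functional as F

# Assumed toy shapes; in the scripts these come from the model's latent predictions.
batch, channels, height, width = 4, 4, 8, 8
model_pred = torch.randn(batch, channels, height, width)
target = torch.randn(batch, channels, height, width)
mse_loss_weights = torch.tensor([1.0, 1.0, 0.8333, 0.2381])  # one weight per sample's timestep

# Elementwise MSE first ...
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
# ... then a per-sample mean over all non-batch dimensions (result shape: [batch]) ...
loss = loss.mean(dim=list(range(1, len(loss.shape))))
# ... so each sample is rebalanced by its own SNR weight before the final batch mean.
loss = (loss * mse_loss_weights).mean()
print(loss)  # scalar training loss
```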

examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -631,25 +631,13 @@ def collate_fn(examples):
             # Since we predict the noise instead of x_0, the original formulation is slightly changed.
             # This is discussed in Section 4.2 of the same paper.
             snr = compute_snr(noise_scheduler, timesteps)
-            base_weight = (
+            if noise_scheduler.config.prediction_type == "v_prediction":
+                # Velocity objective requires that we add one to SNR values before we divide by them.
+                snr = snr + 1
+            mse_loss_weights = (
                 torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
             )
 
-            if noise_scheduler.config.prediction_type == "v_prediction":
-                # Velocity objective needs to be floored to an SNR weight of one.
-                mse_loss_weights = base_weight + 1
-            else:
-                # Epsilon and sample both use the same loss weights.
-                mse_loss_weights = base_weight
-
-            # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-            # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-            mse_loss_weights[snr == 0] = 1.0
-
-            # We first calculate the original loss. Then we mean over the non-batch dimensions and
-            # rebalance the sample-wise losses with their respective loss weights.
-            # Finally, we take the mean of the rebalanced loss.
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
             loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
             loss = loss.mean()
```

examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -664,25 +664,13 @@ def collate_fn(examples):
             # Since we predict the noise instead of x_0, the original formulation is slightly changed.
             # This is discussed in Section 4.2 of the same paper.
             snr = compute_snr(noise_scheduler, timesteps)
-            base_weight = (
+            if noise_scheduler.config.prediction_type == "v_prediction":
+                # Velocity objective requires that we add one to SNR values before we divide by them.
+                snr = snr + 1
+            mse_loss_weights = (
                 torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
             )
 
-            if noise_scheduler.config.prediction_type == "v_prediction":
-                # Velocity objective needs to be floored to an SNR weight of one.
-                mse_loss_weights = base_weight + 1
-            else:
-                # Epsilon and sample both use the same loss weights.
-                mse_loss_weights = base_weight
-
-            # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-            # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-            mse_loss_weights[snr == 0] = 1.0
-
-            # We first calculate the original loss. Then we mean over the non-batch dimensions and
-            # rebalance the sample-wise losses with their respective loss weights.
-            # Finally, we take the mean of the rebalanced loss.
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
             loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
             loss = loss.mean()
```

examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -811,25 +811,13 @@ def collate_fn(examples):
             # Since we predict the noise instead of x_0, the original formulation is slightly changed.
             # This is discussed in Section 4.2 of the same paper.
             snr = compute_snr(noise_scheduler, timesteps)
-            base_weight = (
+            if noise_scheduler.config.prediction_type == "v_prediction":
+                # Velocity objective requires that we add one to SNR values before we divide by them.
+                snr = snr + 1
+            mse_loss_weights = (
                 torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
             )
 
-            if noise_scheduler.config.prediction_type == "v_prediction":
-                # Velocity objective needs to be floored to an SNR weight of one.
-                mse_loss_weights = base_weight + 1
-            else:
-                # Epsilon and sample both use the same loss weights.
-                mse_loss_weights = base_weight
-
-            # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-            # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-            mse_loss_weights[snr == 0] = 1.0
-
-            # We first calculate the original loss. Then we mean over the non-batch dimensions and
-            # rebalance the sample-wise losses with their respective loss weights.
-            # Finally, we take the mean of the rebalanced loss.
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
             loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
             loss = loss.mean()
```

examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py

Lines changed: 5 additions & 16 deletions
```diff
@@ -848,24 +848,13 @@ def collate_fn(examples):
             # Since we predict the noise instead of x_0, the original formulation is slightly changed.
             # This is discussed in Section 4.2 of the same paper.
             snr = compute_snr(noise_scheduler, timesteps)
-            base_weight = (
+            if noise_scheduler.config.prediction_type == "v_prediction":
+                # Velocity objective requires that we add one to SNR values before we divide by them.
+                snr = snr + 1
+            mse_loss_weights = (
                 torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
             )
-            if noise_scheduler.config.prediction_type == "v_prediction":
-                # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                mse_loss_weights = base_weight + 1
-            else:
-                # Epsilon and sample prediction use the base weights.
-                mse_loss_weights = base_weight
-
-            # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-            # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-            mse_loss_weights[snr == 0] = 1.0
-
-            # We first calculate the original loss. Then we mean over the non-batch dimensions and
-            # rebalance the sample-wise losses with their respective loss weights.
-            # Finally, we take the mean of the rebalanced loss.
+
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
             loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
             loss = loss.mean()
```

examples/text_to_image/train_text_to_image.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -929,25 +929,13 @@ def collate_fn(examples):
             # Since we predict the noise instead of x_0, the original formulation is slightly changed.
             # This is discussed in Section 4.2 of the same paper.
             snr = compute_snr(noise_scheduler, timesteps)
-            base_weight = (
+            if noise_scheduler.config.prediction_type == "v_prediction":
+                # Velocity objective requires that we add one to SNR values before we divide by them.
+                snr = snr + 1
+            mse_loss_weights = (
                 torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
             )
 
-            if noise_scheduler.config.prediction_type == "v_prediction":
-                # Velocity objective needs to be floored to an SNR weight of one.
-                mse_loss_weights = base_weight + 1
-            else:
-                # Epsilon and sample both use the same loss weights.
-                mse_loss_weights = base_weight
-
-            # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-            # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-            mse_loss_weights[snr == 0] = 1.0
-
-            # We first calculate the original loss. Then we mean over the non-batch dimensions and
-            # rebalance the sample-wise losses with their respective loss weights.
-            # Finally, we take the mean of the rebalanced loss.
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
             loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
             loss = loss.mean()
```

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -759,25 +759,13 @@ def collate_fn(examples):
             # Since we predict the noise instead of x_0, the original formulation is slightly changed.
             # This is discussed in Section 4.2 of the same paper.
             snr = compute_snr(noise_scheduler, timesteps)
-            base_weight = (
+            if noise_scheduler.config.prediction_type == "v_prediction":
+                # Velocity objective requires that we add one to SNR values before we divide by them.
+                snr = snr + 1
+            mse_loss_weights = (
                 torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
             )
 
-            if noise_scheduler.config.prediction_type == "v_prediction":
-                # Velocity objective needs to be floored to an SNR weight of one.
-                mse_loss_weights = base_weight + 1
-            else:
-                # Epsilon and sample both use the same loss weights.
-                mse_loss_weights = base_weight
-
-            # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-            # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-            mse_loss_weights[snr == 0] = 1.0
-
-            # We first calculate the original loss. Then we mean over the non-batch dimensions and
-            # rebalance the sample-wise losses with their respective loss weights.
-            # Finally, we take the mean of the rebalanced loss.
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
             loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
             loss = loss.mean()
```

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -1050,25 +1050,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
             # Since we predict the noise instead of x_0, the original formulation is slightly changed.
             # This is discussed in Section 4.2 of the same paper.
             snr = compute_snr(noise_scheduler, timesteps)
-            base_weight = (
+            if noise_scheduler.config.prediction_type == "v_prediction":
+                # Velocity objective requires that we add one to SNR values before we divide by them.
+                snr = snr + 1
+            mse_loss_weights = (
                 torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
             )
 
-            if noise_scheduler.config.prediction_type == "v_prediction":
-                # Velocity objective needs to be floored to an SNR weight of one.
-                mse_loss_weights = base_weight + 1
-            else:
-                # Epsilon and sample both use the same loss weights.
-                mse_loss_weights = base_weight
-
-            # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-            # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-            mse_loss_weights[snr == 0] = 1.0
-
-            # We first calculate the original loss. Then we mean over the non-batch dimensions and
-            # rebalance the sample-wise losses with their respective loss weights.
-            # Finally, we take the mean of the rebalanced loss.
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
             loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
             loss = loss.mean()
```

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -1067,25 +1067,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
             # Since we predict the noise instead of x_0, the original formulation is slightly changed.
             # This is discussed in Section 4.2 of the same paper.
             snr = compute_snr(noise_scheduler, timesteps)
-            base_weight = (
+            if noise_scheduler.config.prediction_type == "v_prediction":
+                # Velocity objective requires that we add one to SNR values before we divide by them.
+                snr = snr + 1
+            mse_loss_weights = (
                 torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
             )
 
-            if noise_scheduler.config.prediction_type == "v_prediction":
-                # Velocity objective needs to be floored to an SNR weight of one.
-                mse_loss_weights = base_weight + 1
-            else:
-                # Epsilon and sample both use the same loss weights.
-                mse_loss_weights = base_weight
-
-            # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
-            # When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
-            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
-            mse_loss_weights[snr == 0] = 1.0
-
-            # We first calculate the original loss. Then we mean over the non-batch dimensions and
-            # rebalance the sample-wise losses with their respective loss weights.
-            # Finally, we take the mean of the rebalanced loss.
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
             loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
             loss = loss.mean()
```
