
Commit 4a06c74

bghira and sayakpaul authored
Min-SNR Gamma: follow-up fix for zero-terminal SNR models on v-prediction or epsilon (huggingface#5177)
* merge with main

* fix flax example

* fix onnx example

---------

Co-authored-by: bghira <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 89d8f84 commit 4a06c74

File tree

10 files changed: +121 -18 lines changed

examples/controlnet/train_controlnet_flax.py

Lines changed: 10 additions & 3 deletions
@@ -907,10 +907,17 @@ def compute_loss(params, minibatch, sample_rng):
 
         if args.snr_gamma is not None:
             snr = jnp.array(compute_snr(timesteps))
-            snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
+            base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
             if noise_scheduler.config.prediction_type == "v_prediction":
-                # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                snr_loss_weights = snr_loss_weights + 1
+                snr_loss_weights = base_weights + 1
+            else:
+                # Epsilon and sample prediction use the base weights.
+                snr_loss_weights = base_weights
+            # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+            # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+            # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+            snr_loss_weights = snr_loss_weights.at[snr == 0].set(1.0)  # jnp arrays are immutable; item assignment would fail here
+
             loss = loss * snr_loss_weights
 
         loss = loss.mean()
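
One Flax-specific note on the hunk above: jnp arrays are immutable, so the zero-SNR guard needs a functional update rather than the item assignment the PyTorch scripts below use. A minimal standalone sketch of the same weighting on toy values (snr_gamma and the SNR entries here are illustrative, not from the commit):

import jax.numpy as jnp

snr_gamma = 5.0
snr = jnp.array([0.0, 1.0, 10.0])  # the 0.0 entry models a zero-terminal-SNR timestep

# min(SNR, gamma) / SNR, as in the hunk above; the 0/0 entry is non-finite.
base_weights = jnp.where(snr < snr_gamma, snr, jnp.ones_like(snr) * snr_gamma) / snr

# Functional update: .at[].set() or jnp.where both work where item assignment would not.
snr_loss_weights = jnp.where(snr == 0, 1.0, base_weights)
print(snr_loss_weights)  # [1.  1.  0.5]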

examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py

Lines changed: 14 additions & 1 deletion
@@ -801,9 +801,22 @@ def collate_fn(examples):
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+
+                if noise_scheduler.config.prediction_type == "v_prediction":
+                    # Velocity objective needs to be floored to an SNR weight of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample both use the same loss weights.
+                    mse_loss_weights = base_weight
+
+                # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
+
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
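
The remaining PyTorch scripts below repeat this same change. As a reading aid, here is the logic pulled out into a standalone, runnable sketch; the helper name min_snr_weights and the toy values are ours, not part of the commit:

import torch

def min_snr_weights(snr: torch.Tensor, snr_gamma: float, prediction_type: str) -> torch.Tensor:
    # Min-SNR-gamma weights, min(SNR, gamma) / SNR, computed as in the hunks above.
    base_weight = (
        torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
    )
    if prediction_type == "v_prediction":
        # Velocity objective needs to be floored to an SNR weight of one.
        weights = base_weight + 1
    else:
        # Epsilon and sample prediction use the base weights.
        weights = base_weight
    # Zero-terminal SNR guard: SNR == 0 gives 0/0 = NaN above, so force the weight to 1.
    weights[snr == 0] = 1.0
    return weights

# Toy check: the 0.0 entry models the terminal timestep of a zero-SNR schedule.
print(min_snr_weights(torch.tensor([0.0, 1.0, 10.0]), snr_gamma=5.0, prediction_type="epsilon"))
# tensor([1.0000, 1.0000, 0.5000])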

examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py

Lines changed: 14 additions & 1 deletion
@@ -654,9 +654,22 @@ def collate_fn(examples):
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+
+                if noise_scheduler.config.prediction_type == "v_prediction":
+                    # Velocity objective needs to be floored to an SNR weight of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample both use the same loss weights.
+                    mse_loss_weights = base_weight
+
+                # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
+
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.

examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py

Lines changed: 14 additions & 1 deletion
@@ -685,9 +685,22 @@ def collate_fn(examples):
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+
+                if noise_scheduler.config.prediction_type == "v_prediction":
+                    # Velocity objective needs to be floored to an SNR weight of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample both use the same loss weights.
+                    mse_loss_weights = base_weight
+
+                # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
+
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.

examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py

Lines changed: 14 additions & 1 deletion
@@ -833,9 +833,22 @@ def collate_fn(examples):
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+
+                if noise_scheduler.config.prediction_type == "v_prediction":
+                    # Velocity objective needs to be floored to an SNR weight of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample both use the same loss weights.
+                    mse_loss_weights = base_weight
+
+                # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
+
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.

examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py

Lines changed: 11 additions & 2 deletions
@@ -872,12 +872,21 @@ def collate_fn(examples):
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
                 if noise_scheduler.config.prediction_type == "v_prediction":
                     # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                    mse_loss_weights = mse_loss_weights + 1
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample prediction use the base weights.
+                    mse_loss_weights = base_weight
+
+                # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
+
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.

examples/text_to_image/train_text_to_image.py

Lines changed: 13 additions & 3 deletions
@@ -952,12 +952,22 @@ def collate_fn(examples):
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+
                 if noise_scheduler.config.prediction_type == "v_prediction":
-                    # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                    mse_loss_weights = mse_loss_weights + 1
+                    # Velocity objective needs to be floored to an SNR weight of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample both use the same loss weights.
+                    mse_loss_weights = base_weight
+
+                # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
+
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
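
In formula form, the weighting these hunks implement is (our summary; $\gamma$ is args.snr_gamma and $\mathrm{SNR}(t)$ comes from compute_snr):

$$
w(t) =
\begin{cases}
\min(\mathrm{SNR}(t), \gamma)/\mathrm{SNR}(t) + 1, & \text{v-prediction} \\
\min(\mathrm{SNR}(t), \gamma)/\mathrm{SNR}(t), & \text{epsilon or sample}
\end{cases}
\qquad \text{with } w(t) \leftarrow 1 \text{ wherever } \mathrm{SNR}(t) = 0.
$$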

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 13 additions & 3 deletions
@@ -783,12 +783,22 @@ def collate_fn(examples):
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+
                 if noise_scheduler.config.prediction_type == "v_prediction":
-                    # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                    mse_loss_weights = mse_loss_weights + 1
+                    # Velocity objective needs to be floored to an SNR weight of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample both use the same loss weights.
+                    mse_loss_weights = base_weight
+
+                # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
+
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 13 additions & 3 deletions
@@ -1072,12 +1072,22 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(timesteps)
-                mse_loss_weights = (
+                base_weight = (
                     torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                 )
+
                 if noise_scheduler.config.prediction_type == "v_prediction":
-                    # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                    mse_loss_weights = mse_loss_weights + 1
+                    # Velocity objective needs to be floored to an SNR weight of one.
+                    mse_loss_weights = base_weight + 1
+                else:
+                    # Epsilon and sample both use the same loss weights.
+                    mse_loss_weights = base_weight
+
+                # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
+
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 5 additions & 0 deletions
@@ -1100,6 +1100,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     # Epsilon and sample both use the same loss weights.
                     mse_loss_weights = base_weight
 
+                # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                mse_loss_weights[snr == 0] = 1.0
+
                 # We first calculate the original loss. Then we mean over the non-batch dimensions and
                 # rebalance the sample-wise losses with their respective loss weights.
                 # Finally, we take the mean of the rebalanced loss.
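
For completeness, the "rebalance" step that the trailing comments in each script describe looks like this; a sketch with dummy tensors and illustrative shapes, mirroring the comments rather than quoting any one file:

import torch
import torch.nn.functional as F

model_pred = torch.randn(4, 3, 8, 8)  # dummy prediction, (batch, C, H, W)
target = torch.randn(4, 3, 8, 8)      # dummy target
mse_loss_weights = torch.tensor([1.0, 1.0, 0.5, 0.25])  # one weight per sample

# Per-element loss first, then mean over the non-batch dimensions...
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
# ...and finally the mean of the rebalanced sample-wise losses.
loss = loss.mean()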
