@@ -332,15 +332,6 @@ def parse_args(input_args=None):
332332 help = "SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
333333 "More details here: https://arxiv.org/abs/2303.09556." ,
334334 )
335- parser .add_argument (
336- "--force_snr_gamma" ,
337- action = "store_true" ,
338- help = (
339- "When using SNR gamma with rescaled betas for zero terminal SNR, a divide-by-zero error can cause NaN"
340- " condition when computing the SNR with a sigma value of zero. This parameter overrides the check,"
341- " allowing the use of SNR gamma with a terminal SNR model. Use with caution, and closely monitor results."
342- ),
343- )
344335 parser .add_argument ("--use_ema" , action = "store_true" , help = "Whether to use EMA model." )
345336 parser .add_argument (
346337 "--allow_tf32" ,
@@ -554,18 +545,6 @@ def main(args):
554545 # Load scheduler and models
555546 noise_scheduler = DDPMScheduler .from_pretrained (args .pretrained_model_name_or_path , subfolder = "scheduler" )
556547 # Check for terminal SNR in combination with SNR Gamma
557- if (
558- args .snr_gamma
559- and not args .force_snr_gamma
560- and (
561- hasattr (noise_scheduler .config , "rescale_betas_zero_snr" ) and noise_scheduler .config .rescale_betas_zero_snr
562- )
563- ):
564- raise ValueError (
565- f"The selected noise scheduler for the model { args .pretrained_model_name_or_path } uses rescaled betas for zero SNR.\n "
566- "When this configuration is present, the parameter --snr_gamma may not be used without parameter --force_snr_gamma.\n "
567- "This is due to a mathematical incompatibility between our current SNR gamma implementation, and a sigma value of zero."
568- )
569548 text_encoder_one = text_encoder_cls_one .from_pretrained (
570549 args .pretrained_model_name_or_path , subfolder = "text_encoder" , revision = args .revision
571550 )
@@ -1013,9 +992,17 @@ def compute_time_ids(original_size, crops_coords_top_left):
1013992 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1014993 # This is discussed in Section 4.2 of the same paper.
1015994 snr = compute_snr (timesteps )
1016- mse_loss_weights = (
995+ base_weight = (
1017996 torch .stack ([snr , args .snr_gamma * torch .ones_like (timesteps )], dim = 1 ).min (dim = 1 )[0 ] / snr
1018997 )
998+
999+ if noise_scheduler .config .prediction_type == "v_prediction" :
1000+ # Velocity objective needs to be floored to an SNR weight of one.
1001+ mse_loss_weights = base_weight + 1
1002+ else :
1003+ # Epsilon and sample both use the same loss weights.
1004+ mse_loss_weights = base_weight
1005+
10191006 # We first calculate the original loss. Then we mean over the non-batch dimensions and
10201007 # rebalance the sample-wise losses with their respective loss weights.
10211008 # Finally, we take the mean of the rebalanced loss.
0 commit comments