     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import resolve_interpolation_mode
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
@@ -165,6 +166,7 @@ def __init__(
         global_batch_size: int,
         num_workers: int,
         resolution: int = 512,
+        interpolation_type: str = "bilinear",
         shuffle_buffer_size: int = 1000,
         pin_memory: bool = False,
         persistent_workers: bool = False,
@@ -174,10 +176,12 @@ def __init__(
         # flatten list using itertools
         train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
 
+        interpolation_mode = resolve_interpolation_mode(interpolation_type)
+
         def transform(example):
             # resize image
             image = example["image"]
-            image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+            image = TF.resize(image, resolution, interpolation=interpolation_mode)
 
             # get crop coordinates and crop image
             c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
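Note: `resolve_interpolation_mode` is imported from `diffusers.training_utils` above and turns the `interpolation_type` string into a torchvision `InterpolationMode` once, outside the per-sample transform. As a rough, hypothetical sketch of the kind of lookup such a helper performs (names and error handling assumed, not copied from diffusers):

    from torchvision import transforms

    _MODES = {
        "bilinear": transforms.InterpolationMode.BILINEAR,
        "bicubic": transforms.InterpolationMode.BICUBIC,
        "box": transforms.InterpolationMode.BOX,
        "nearest": transforms.InterpolationMode.NEAREST,
        "nearest_exact": transforms.InterpolationMode.NEAREST_EXACT,
        "hamming": transforms.InterpolationMode.HAMMING,
        "lanczos": transforms.InterpolationMode.LANCZOS,
    }

    def resolve_interpolation_mode_sketch(interpolation_type: str):
        # Map the CLI string to a torchvision InterpolationMode, failing fast on typos.
        try:
            return _MODES[interpolation_type]
        except KeyError:
            raise ValueError(f"Unsupported interpolation type: {interpolation_type!r}")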
@@ -353,8 +357,9 @@ def append_dims(x, target_dims):
 
 # From LCMScheduler.get_scalings_for_boundary_condition_discrete
 def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
-    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
-    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
     return c_skip, c_out
 
 
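The rewrite above makes the `timestep_scaling` argument effective: the old expression hard-coded `timestep / 0.1` (equivalent to a fixed scaling of 10), so passing a different value had no effect. A quick sanity check, using the `scalings_for_boundary_conditions` defined above with plain floats for illustration, that the consistency boundary condition c_skip(0) = 1, c_out(0) = 0 holds for any scaling:

    # At timestep 0 the boundary condition must give c_skip == 1 and c_out == 0,
    # regardless of the timestep_scaling value.
    for timestep_scaling in (1.0, 10.0, 100.0):
        c_skip, c_out = scalings_for_boundary_conditions(0.0, timestep_scaling=timestep_scaling)
        assert c_skip == 1.0 and c_out == 0.0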
@@ -572,6 +577,15 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--interpolation_type",
+        type=str,
+        default="bilinear",
+        help=(
+            "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+            " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+        ),
+    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -710,6 +724,50 @@ def parse_args():
         default=64,
         help="The rank of the LoRA projection matrix.",
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=64,
+        help=(
+            "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+            " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+        ),
+    )
+    parser.add_argument(
+        "--lora_dropout",
+        type=float,
+        default=0.0,
+        help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+    )
+    parser.add_argument(
+        "--lora_target_modules",
+        type=str,
+        default=None,
+        help=(
+            "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+            " be used. By default, LoRA will be applied to all conv and linear layers."
+        ),
+    )
+    parser.add_argument(
+        "--vae_encode_batch_size",
+        type=int,
+        default=32,
+        required=False,
+        help=(
+            "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+            " Encoding or decoding the whole batch at once may run into OOM issues."
+        ),
+    )
+    parser.add_argument(
+        "--timestep_scaling_factor",
+        type=float,
+        default=10.0,
+        help=(
+            "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+            " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+            " suffice."
+        ),
+    )
     # ----Mixed Precision----
     parser.add_argument(
         "--mixed_precision",
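On the new `--lora_alpha` flag: in PEFT-style LoRA the weight update is scaled by `lora_alpha / r`, so the default of 64 together with the default `--lora_rank` of 64 gives a scaling of 1 (no scaling), as the help text says. A small illustrative computation (shapes and values assumed, not taken from the script):

    import torch

    lora_rank, lora_alpha = 64, 32
    in_features, out_features = 320, 320

    lora_A = torch.randn(lora_rank, in_features)   # down-projection
    lora_B = torch.zeros(out_features, lora_rank)  # up-projection, zero-initialized so delta_W starts at 0
    scaling = lora_alpha / lora_rank               # 0.5 here; 1.0 when lora_alpha == lora_rank

    # Effective update applied on top of the frozen base weight W:
    delta_W = scaling * (lora_B @ lora_A)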
@@ -915,9 +973,10 @@ def main(args):
     )
 
     # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
-    lora_config = LoraConfig(
-        r=args.lora_rank,
-        target_modules=[
+    if args.lora_target_modules is not None:
+        lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+    else:
+        lora_target_modules = [
             "to_q",
             "to_k",
             "to_v",
@@ -932,7 +991,12 @@ def main(args):
             "downsamplers.0.conv",
             "upsamplers.0.conv",
             "time_emb_proj",
-        ],
+        ]
+    lora_config = LoraConfig(
+        r=args.lora_rank,
+        target_modules=lora_target_modules,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
     )
     unet = get_peft_model(unet, lora_config)
 
@@ -1051,6 +1115,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
         global_batch_size=args.train_batch_size * accelerator.num_processes,
         num_workers=args.dataloader_num_workers,
         resolution=args.resolution,
+        interpolation_type=args.interpolation_type,
         shuffle_buffer_size=1000,
         pin_memory=True,
         persistent_workers=True,
@@ -1162,10 +1227,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
                 if vae.dtype != weight_dtype:
                     vae.to(dtype=weight_dtype)
 
-                # encode pixel values with batch size of at most 32
+                # encode pixel values with batch size of at most args.vae_encode_batch_size
                 latents = []
-                for i in range(0, pixel_values.shape[0], 32):
-                    latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
+                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
+                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
                 latents = torch.cat(latents, dim=0)
 
                 latents = latents * vae.config.scaling_factor
@@ -1181,9 +1246,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
 
                 # 3. Get boundary scalings for start_timesteps and (end) timesteps.
-                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
-                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=args.timestep_scaling_factor
+                )
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
 
                 # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each