 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 if is_wandb_available():
@@ -833,6 +834,12 @@ def collate_fn(examples):
         tracker_config.pop("validation_prompts")
         accelerator.init_trackers(args.tracker_project_name, tracker_config)
 
+    # Function for unwrapping if model was compiled with `torch.compile`.
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
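For context on the new helper: `torch.compile` wraps a module in an `OptimizedModule` that keeps the original module on its `_orig_mod` attribute, and `accelerate` adds its own wrapper in `prepare()`, so both layers have to be peeled off before the raw model is recovered. A minimal sketch of the compile wrapper's behavior, assuming PyTorch 2.x (the `Linear` module is just a stand-in, not from the training script):

```python
import torch

net = torch.nn.Linear(4, 4)  # stand-in for the UNet
compiled = torch.compile(net)

# torch.compile returns an OptimizedModule that keeps the original,
# un-compiled module on `_orig_mod` -- exactly what `is_compiled_module`
# checks for and what `unwrap_model` reaches through.
assert isinstance(compiled, torch._dynamo.eval_frame.OptimizedModule)
assert compiled._orig_mod is net
```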
@@ -912,7 +919,7 @@ def collate_fn(examples):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
 
                 # Get the target for loss depending on the prediction type
                 if args.prediction_type is not None:
@@ -927,7 +934,7 @@ def collate_fn(examples):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                 # Predict the noise residual and compute loss
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
 
                 if args.snr_gamma is None:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
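Both `return_dict=False` changes swap the default `ModelOutput` dataclass return for a plain tuple, so indexing with `[0]` replaces the `.sample` attribute access; tuples skip the dataclass construction and trace more cleanly under `torch.compile`. A self-contained sketch with a tiny randomly initialized UNet (the config values are arbitrary, chosen only to keep the model small):

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny, randomly initialized UNet -- the config is illustrative only.
unet = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    layers_per_block=1,
    cross_attention_dim=32,
)

sample = torch.randn(1, 4, 8, 8)       # (batch, channels, height, width)
timesteps = torch.tensor([10])
hidden_states = torch.randn(1, 4, 32)  # (batch, seq_len, cross_attention_dim)

with torch.no_grad():
    # Default: a UNet2DConditionOutput dataclass with a `.sample` field.
    pred_dataclass = unet(sample, timesteps, hidden_states).sample
    # With return_dict=False: a plain tuple, so `.sample` becomes `[0]`.
    pred_tuple = unet(sample, timesteps, hidden_states, return_dict=False)[0]

assert torch.equal(pred_dataclass, pred_tuple)
```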
@@ -1023,7 +1030,7 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())
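One practical reason for unwrapping before the save, rather than serializing the compiled wrapper directly, is that the wrapper's `state_dict()` keys carry an `_orig_mod.` prefix that would end up baked into the checkpoint. A small sketch of the key-prefix issue, again with a stand-in module:

```python
import torch

net = torch.nn.Linear(2, 2)
compiled = torch.compile(net)

# The compiled wrapper registers the original module as a submodule
# named `_orig_mod`, so every key is prefixed; the unwrapped module
# has the clean keys a pipeline checkpoint expects.
print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']
print(list(net.state_dict()))       # ['weight', 'bias']
```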