@@ -54,6 +54,7 @@
 from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
             text_input_ids = text_input_ids_list[i]

         prompt_embeds = text_encoder(
-            text_input_ids.to(text_encoder.device),
-            output_hidden_states=True,
+            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
         )

         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds.hidden_states[-2]
+        prompt_embeds = prompt_embeds[-1][-2]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
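The indexing change in this hunk follows from `return_dict=False`: the text encoder now returns a plain tuple instead of a `ModelOutput`, and with `output_hidden_states=True` the per-layer hidden states sit at the tuple's last position, so `[-1][-2]` picks the same penultimate hidden state that `.hidden_states[-2]` did before. A minimal sketch of that tuple layout, using a standalone CLIP text encoder checkpoint purely for illustration (the checkpoint and variable names are assumptions, not part of this script):

    import torch
    from transformers import CLIPTextModelWithProjection, CLIPTokenizer

    # Any CLIP text checkpoint works for the demonstration; this one is just an example.
    name = "openai/clip-vit-large-patch14"
    tokenizer = CLIPTokenizer.from_pretrained(name)
    encoder = CLIPTextModelWithProjection.from_pretrained(name).eval()

    ids = tokenizer("a photo of a cat", return_tensors="pt").input_ids
    with torch.no_grad():
        out = encoder(ids, output_hidden_states=True, return_dict=False)

    pooled = out[0]                  # projected/pooled embedding, i.e. prompt_embeds[0] above
    hidden_states = out[-1]          # tuple of per-layer hidden states
    penultimate = hidden_states[-2]  # what `.hidden_states[-2]` returned before this change
    print(pooled.shape, penultimate.shape)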
@@ -637,6 +637,11 @@ def main(args):
         # only upcast trainable parameters (LoRA) into fp32
         cast_training_params(models, dtype=torch.float32)

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
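The `unwrap_model` helper introduced above first removes the accelerate wrapper, then reaches past `torch.compile`'s wrapper via `_orig_mod`, which is what `is_compiled_module` detects. A minimal, self-contained sketch of that second step (PyTorch 2.x assumed; the `unwrap` helper below is illustrative, not the script's function):

    import torch
    import torch.nn as nn

    base = nn.Linear(4, 4)
    compiled = torch.compile(base)

    print(type(compiled).__name__)     # OptimizedModule
    print(compiled._orig_mod is base)  # True

    # is_compiled_module() amounts to this kind of check, so the training
    # script's unwrap_model() can always hand back the plain nn.Module.
    def unwrap(model):
        return model._orig_mod if hasattr(model, "_orig_mod") else model

    assert unwrap(compiled) is base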
@@ -647,13 +652,13 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                     text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
@@ -678,11 +683,11 @@ def load_model_hook(models, input_dir):
         while len(models) > 0:
             model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                 text_encoder_one_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                 text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1031,8 +1036,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 )
                 unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
                 model_pred = unet(
-                    noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
-                ).sample
+                    noisy_model_input,
+                    timesteps,
+                    prompt_embeds,
+                    added_cond_kwargs=unet_added_conditions,
+                    return_dict=False,
+                )[0]

                 # Get the target for loss depending on the prediction type
                 if args.prediction_type is not None:
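As with the text encoders, `return_dict=False` makes the UNet return a plain tuple, so `[0]` takes the place of `.sample` and the per-step `ModelOutput` construction is skipped. A small sketch showing the two call styles yield the same tensor; the tiny random config below is an assumption chosen only to keep the example light, not the SDXL configuration this script trains:

    import torch
    from diffusers import UNet2DConditionModel

    # Tiny, randomly initialized UNet; the config exists only to keep the example small.
    unet = UNet2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        layers_per_block=1,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
    ).eval()

    x = torch.randn(1, 4, 32, 32)  # stand-in for noisy_model_input
    t = torch.tensor([10])         # stand-in for timesteps
    enc = torch.randn(1, 77, 32)   # stand-in for prompt_embeds (width matches cross_attention_dim)

    with torch.no_grad():
        as_tuple = unet(x, t, encoder_hidden_states=enc, return_dict=False)
        as_output = unet(x, t, encoder_hidden_states=enc)

    # return_dict=False yields (sample,), so [0] is the same tensor as .sample
    print(torch.allclose(as_tuple[0], as_output.sample))  # True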
@@ -1125,9 +1134,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 pipeline = StableDiffusionXLPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
-                    text_encoder=accelerator.unwrap_model(text_encoder_one),
-                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-                    unet=accelerator.unwrap_model(unet),
+                    text_encoder=unwrap_model(text_encoder_one),
+                    text_encoder_2=unwrap_model(text_encoder_two),
+                    unet=unwrap_model(unet),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
@@ -1166,12 +1175,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

         if args.train_text_encoder:
-            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+            text_encoder_one = unwrap_model(text_encoder_one)
+            text_encoder_two = unwrap_model(text_encoder_two)

             text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
             text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))